diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index b934e79..8e25ad3 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -276,6 +276,9 @@ class StaticFunction(ModelFunction): ret.update({"type": "static", "value": self.value}) return ret + def to_dot(self, pydot, graph, feature_names, parent=None): + graph.add_node(pydot.Node(str(id(self)), label=self.value, shape="rectangle")) + @classmethod def from_json(cls, data): assert data["type"] == "static" @@ -352,6 +355,16 @@ class SplitFunction(ModelFunction): ret.append(v.get_max_depth()) return 1 + max(ret) + def to_dot(self, pydot, graph, feature_names, parent=None): + try: + label = feature_names[self.param_index] + except IndexError: + label = f"param{self.param_index}" + graph.add_node(pydot.Node(str(id(self)), label=label)) + for key, child in self.child.items(): + child.to_dot(pydot, graph, feature_names, str(id(self))) + graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + @classmethod def from_json(cls, data): assert data["type"] == "split" @@ -737,6 +750,16 @@ class AnalyticFunction(ModelFunction): ) return ret + def to_dot(self, pydot, graph, feature_names, parent=None): + model_function = self.model_function + for i, arg in enumerate(self.model_args): + model_function = model_function.replace( + f"regression_arg({i})", f"{arg:.2f}" + ) + graph.add_node( + pydot.Node(str(id(self)), label=model_function, shape="rectangle") + ) + @classmethod def from_json(cls, data): assert data["type"] == "analytic" |