summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index b934e79..8e25ad3 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -276,6 +276,9 @@ class StaticFunction(ModelFunction):
ret.update({"type": "static", "value": self.value})
return ret
+ def to_dot(self, pydot, graph, feature_names, parent=None):
+ graph.add_node(pydot.Node(str(id(self)), label=self.value, shape="rectangle"))
+
@classmethod
def from_json(cls, data):
assert data["type"] == "static"
@@ -352,6 +355,16 @@ class SplitFunction(ModelFunction):
ret.append(v.get_max_depth())
return 1 + max(ret)
+ def to_dot(self, pydot, graph, feature_names, parent=None):
+ try:
+ label = feature_names[self.param_index]
+ except IndexError:
+ label = f"param{self.param_index}"
+ graph.add_node(pydot.Node(str(id(self)), label=label))
+ for key, child in self.child.items():
+ child.to_dot(pydot, graph, feature_names, str(id(self)))
+ graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key))
+
@classmethod
def from_json(cls, data):
assert data["type"] == "split"
@@ -737,6 +750,16 @@ class AnalyticFunction(ModelFunction):
)
return ret
+ def to_dot(self, pydot, graph, feature_names, parent=None):
+ model_function = self.model_function
+ for i, arg in enumerate(self.model_args):
+ model_function = model_function.replace(
+ f"regression_arg({i})", f"{arg:.2f}"
+ )
+ graph.add_node(
+ pydot.Node(str(id(self)), label=model_function, shape="rectangle")
+ )
+
@classmethod
def from_json(cls, data):
assert data["type"] == "analytic"