summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-02-21 12:26:29 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-02-21 12:26:29 +0100
commit37c3c6ec5382052743a99d881298255b7b7ccc50 (patch)
treea0394d9e84f993ef625cd1fb2431a2a33103e8f7 /lib/parameters.py
parent32213cdb49a08bd0f2012baaefb031d0726dc4bf (diff)
add dtree graphviz/dot export via --export-dot / to_dot()
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 50a7ae8..45db489 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -645,6 +645,69 @@ class ModelAttribute:
return ret
+ def to_dot(self):
+ if type(self.model_function) == df.SplitFunction:
+ import pydot
+
+ graph = pydot.Dot("Regression Model Tree", graph_type="graph")
+ self.model_function.to_dot(pydot, graph, self.param_names)
+ return graph
+ if type(self.model_function) == df.CARTFunction:
+ import sklearn.tree
+
+ feature_names = list(
+ map(
+ lambda i: self.param_names[i],
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(len(self.param_names)),
+ ),
+ )
+ )
+ feature_names += list(
+ map(
+ lambda i: f"arg{i-len(self.param_names)}",
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(
+ len(self.param_names),
+ len(self.param_names) + self.arg_count,
+ ),
+ ),
+ )
+ )
+ return sklearn.tree.export_graphviz(
+ self.model_function.regressor,
+ out_file=None,
+ feature_names=feature_names,
+ )
+ if type(self.model_function) == df.LMTFunction:
+ feature_names = list(
+ map(
+ lambda i: self.param_names[i],
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(len(self.param_names)),
+ ),
+ )
+ )
+ feature_names += list(
+ map(
+ lambda i: f"arg{i-len(self.param_names)}",
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(
+ len(self.param_names),
+ len(self.param_names) + self.arg_count,
+ ),
+ ),
+ )
+ )
+ return self.model_function.regressor.model_to_dot(
+ feature_names=feature_names
+ )
+ return None
+
def webconf_function_map(self):
return self.model_function.webconf_function_map()