diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-02-21 12:26:29 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-02-21 12:26:29 +0100 |
commit | 37c3c6ec5382052743a99d881298255b7b7ccc50 (patch) | |
tree | a0394d9e84f993ef625cd1fb2431a2a33103e8f7 /lib/parameters.py | |
parent | 32213cdb49a08bd0f2012baaefb031d0726dc4bf (diff) |
add dtree graphviz/dot export via --export-dot / to_dot()
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 50a7ae8..45db489 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -645,6 +645,69 @@ class ModelAttribute: return ret + def to_dot(self): + if type(self.model_function) == df.SplitFunction: + import pydot + + graph = pydot.Dot("Regression Model Tree", graph_type="graph") + self.model_function.to_dot(pydot, graph, self.param_names) + return graph + if type(self.model_function) == df.CARTFunction: + import sklearn.tree + + feature_names = list( + map( + lambda i: self.param_names[i], + filter( + lambda i: not self.model_function.ignore_index[i], + range(len(self.param_names)), + ), + ) + ) + feature_names += list( + map( + lambda i: f"arg{i-len(self.param_names)}", + filter( + lambda i: not self.model_function.ignore_index[i], + range( + len(self.param_names), + len(self.param_names) + self.arg_count, + ), + ), + ) + ) + return sklearn.tree.export_graphviz( + self.model_function.regressor, + out_file=None, + feature_names=feature_names, + ) + if type(self.model_function) == df.LMTFunction: + feature_names = list( + map( + lambda i: self.param_names[i], + filter( + lambda i: not self.model_function.ignore_index[i], + range(len(self.param_names)), + ), + ) + ) + feature_names += list( + map( + lambda i: f"arg{i-len(self.param_names)}", + filter( + lambda i: not self.model_function.ignore_index[i], + range( + len(self.param_names), + len(self.param_names) + self.arg_count, + ), + ), + ) + ) + return self.model_function.regressor.model_to_dot( + feature_names=feature_names + ) + return None + def webconf_function_map(self): return self.model_function.webconf_function_map() |