diff options
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 71 |
1 files changed, 30 insertions, 41 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index e6d5561..9a0171b 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -688,57 +688,46 @@ class ModelAttribute: graph = pydot.Dot("Regression Model Tree", graph_type="graph") self.model_function.to_dot(pydot, graph, self.param_names) return graph - if type(self.model_function) == df.CARTFunction: - import sklearn.tree - feature_names = list( - map( - lambda i: self.param_names[i], - filter( - lambda i: not self.model_function.ignore_index[i], - range(len(self.param_names)), - ), - ) + feature_names = list( + map( + lambda i: self.param_names[i], + filter( + lambda i: not self.model_function.ignore_index[i], + range(len(self.param_names)), + ), ) - feature_names += list( - map( - lambda i: f"arg{i-len(self.param_names)}", - filter( - lambda i: not self.model_function.ignore_index[i], - range( - len(self.param_names), - len(self.param_names) + self.arg_count, - ), + ) + feature_names += list( + map( + lambda i: f"arg{i-len(self.param_names)}", + filter( + lambda i: not self.model_function.ignore_index[i], + range( + len(self.param_names), + len(self.param_names) + self.arg_count, ), - ) + ), ) + ) + + if type(self.model_function) == df.CARTFunction: + import sklearn.tree + return sklearn.tree.export_graphviz( self.model_function.regressor, out_file=None, feature_names=feature_names, ) + if type(self.model_function) == df.XGBoostFunction: + import xgboost + + self.model_function.regressor.get_booster().feature_names = feature_names + return [ + xgboost.to_graphviz(self.model_function.regressor, num_trees=i) + for i in range(self.model_function.regressor.n_estimators) + ] if type(self.model_function) == df.LMTFunction: - feature_names = list( - map( - lambda i: self.param_names[i], - filter( - lambda i: not self.model_function.ignore_index[i], - range(len(self.param_names)), - ), - ) - ) - feature_names += list( - map( - lambda i: f"arg{i-len(self.param_names)}", - filter( - lambda i: not self.model_function.ignore_index[i], - range( - len(self.param_names), - len(self.param_names) + self.arg_count, - ), - ), - ) - ) return self.model_function.regressor.model_to_dot( feature_names=feature_names ) |