diff options
-rw-r--r-- | lib/functions.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index a23a995..32f777b 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -575,6 +575,17 @@ class CARTFunction(SKLearnRegressionFunction): ret.update(self.recurse_(self.regressor.tree_, 0)) return ret + def hyper_to_dref(self): + return { + "cart/max depth": self.regressor.max_depth or "infty", + "cart/min samples split": self.regressor.min_samples_split, + "cart/min samples leaf": self.regressor.min_samples_leaf, + "cart/min impurity decrease": self.regressor.min_impurity_decrease, + "cart/max leaf nodes": self.regressor.max_leaf_nodes or "infty", + "cart/criterion": self.regressor.criterion, + "cart/splitter": self.regressor.splitter, + } + # recursive function for all nodes: def recurse_(self, tree, node_id, depth=0): left_child = tree.children_left[node_id] @@ -635,6 +646,15 @@ class LMTFunction(SKLearnRegressionFunction): ret.update(self.recurse_(self.regressor.summary(), 0)) return ret + def hyper_to_dref(self): + return { + "lmt/max depth": self.regressor.max_depth, + "lmt/max bins": self.regressor.max_bins, + "lmt/min samples split": self.regressor.min_samples_split, + "lmt/min samples leaf": self.regressor.min_samples_leaf, + "lmt/criterion": self.regressor.criterion, + } + def recurse_(self, node_hash, node_index): node = node_hash[node_index] sub_data = dict() |