summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-25 09:04:22 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-25 09:04:22 +0100
commit9e8f9774c42cac0904d56d8edfd4abd4b2b717d1 (patch)
treee6059818c644b87dd0619bbaeed3ed75a47e62e7
parent680809705799c7b12c10236be7182cad52263a68 (diff)
add hyper-parameter export for CART and LMT
-rw-r--r--lib/functions.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index a23a995..32f777b 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -575,6 +575,17 @@ class CARTFunction(SKLearnRegressionFunction):
ret.update(self.recurse_(self.regressor.tree_, 0))
return ret
+ def hyper_to_dref(self):
+ return {
+ "cart/max depth": self.regressor.max_depth or "infty",
+ "cart/min samples split": self.regressor.min_samples_split,
+ "cart/min samples leaf": self.regressor.min_samples_leaf,
+ "cart/min impurity decrease": self.regressor.min_impurity_decrease,
+ "cart/max leaf nodes": self.regressor.max_leaf_nodes or "infty",
+ "cart/criterion": self.regressor.criterion,
+ "cart/splitter": self.regressor.splitter,
+ }
+
# recursive function for all nodes:
def recurse_(self, tree, node_id, depth=0):
left_child = tree.children_left[node_id]
@@ -635,6 +646,15 @@ class LMTFunction(SKLearnRegressionFunction):
ret.update(self.recurse_(self.regressor.summary(), 0))
return ret
+ def hyper_to_dref(self):
+ return {
+ "lmt/max depth": self.regressor.max_depth,
+ "lmt/max bins": self.regressor.max_bins,
+ "lmt/min samples split": self.regressor.min_samples_split,
+ "lmt/min samples leaf": self.regressor.min_samples_leaf,
+ "lmt/criterion": self.regressor.criterion,
+ }
+
def recurse_(self, node_hash, node_index):
node = node_hash[node_index]
sub_data = dict()