summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-03-02 21:40:24 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-03-02 21:40:24 +0100
commit4c5b7b2f07b3e1f661dc85828f8fee5c7c88ac72 (patch)
treec6245cc56f882605c2b4b45013d5287f6747c5d3
parentca4e50cb46fd19222a1fbc399dac217b79f25e08 (diff)
CARTFunction: Add to_json method (adapted from Lennart Kaiser)
-rw-r--r--lib/functions.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 69788df..aeaf48e 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -487,6 +487,50 @@ class CARTFunction(SKLearnRegressionFunction):
def get_max_depth(self):
return self.regressor.get_depth()
+ def to_json(self, feature_names=None, **kwargs):
+ import sklearn.tree
+
+ self.leaf_id = sklearn.tree._tree.TREE_LEAF
+ self.feature_names = feature_names
+
+ ret = super().to_json(**kwargs)
+ ret.update(self.recurse_(self.regressor.tree_, 0))
+ return ret
+
+ # recursive function for all nodes:
+ def recurse_(self, tree, node_id, depth=0):
+ left_child = tree.children_left[node_id]
+ right_child = tree.children_right[node_id]
+
+ # basic leaf with standard values
+ # conversion because of numpy
+ sub_data = {
+ "functionError": None,
+ "type": "static",
+ "value": float(tree.value[node_id]),
+ "valueError": float(tree.impurity[node_id]),
+ # "samples": int(tree.n_node_samples[node_id])
+ }
+
+ # if has childs / not a leaf:
+ if left_child != self.leaf_id or right_child != self.leaf_id:
+ # sub_data["paramName"] = "X[" + str(self.regressor.tree_.feature[left_child_id]) + "]"
+ # sub_data["paramIndex"] = int(self.regressor.tree_.feature[left_child_id])
+ sub_data["paramName"] = self.feature_names[
+ self.regressor.tree_.feature[node_id]
+ ]
+ sub_data["paramDecisionValue"] = tree.threshold[node_id]
+ sub_data["type"] = "scalarSplit"
+ sub_data["child"] = {}
+
+ # child value
+ if left_child != self.leaf_id:
+ sub_data["child"]["<="] = self.recurse_(tree, left_child, depth=depth + 1)
+ if right_child != self.leaf_id:
+ sub_data["child"][">"] = self.recurse_(tree, right_child, depth=depth + 1)
+
+ return sub_data
+
class LMTFunction(SKLearnRegressionFunction):
def get_number_of_nodes(self):