From 4c5b7b2f07b3e1f661dc85828f8fee5c7c88ac72 Mon Sep 17 00:00:00 2001
From: Daniel Friesel
Date: Wed, 2 Mar 2022 21:40:24 +0100
Subject: CARTFunction: Add to_json method (adapted from Lennart Kaiser)

---
Review changes relative to the original commit (the commit ID above and the
index line below still refer to the original blobs):
 - use "X[<index>]" as parameter name when feature_names is None instead of
   failing on self.feature_names[...]
 - extract the node value via item() rather than calling float() on an array
 - read the split feature from the tree passed to recurse_
 - add docstrings, remove commented-out code

 lib/functions.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)

(limited to 'lib')

diff --git a/lib/functions.py b/lib/functions.py
index 69788df..aeaf48e 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -487,6 +487,79 @@ class CARTFunction(SKLearnRegressionFunction):
     def get_max_depth(self):
         return self.regressor.get_depth()
 
+    def to_json(self, feature_names=None, **kwargs):
+        """
+        Return the parent class's to_json dict, extended by the decision tree.
+
+        :param feature_names: list of feature names, indexed by the feature
+            numbers stored in the regressor's tree. If None, features are
+            reported as "X[<index>]".
+        :param kwargs: passed through to the parent class's to_json method.
+        :returns: dict from super().to_json, updated with the entries of the
+            tree's root node (see recurse_).
+        """
+        import sklearn.tree
+
+        self.leaf_id = sklearn.tree._tree.TREE_LEAF
+        self.feature_names = feature_names
+
+        ret = super().to_json(**kwargs)
+        ret.update(self.recurse_(self.regressor.tree_, 0))
+        return ret
+
+    def recurse_(self, tree, node_id, depth=0):
+        """
+        Return the sub-tree rooted at node_id as a nested dict.
+
+        Expects self.leaf_id and self.feature_names to be set (to_json does this).
+
+        :param tree: tree object providing children_left, children_right, value,
+            impurity, feature, and threshold, each indexed by node ID.
+        :param node_id: index of the node to export.
+        :param depth: depth of node_id; only passed on to recursive calls.
+        :returns: dict of type "static" for a leaf; otherwise of type
+            "scalarSplit" with additional "paramName", "paramDecisionValue",
+            and "child" entries.
+        """
+        left_child = tree.children_left[node_id]
+        right_child = tree.children_right[node_id]
+
+        # Basic leaf with standard values; numpy types are converted to plain
+        # Python floats. tree.value[node_id] is presumably a single-element
+        # array (TODO confirm for multi-output regressors). Its element is
+        # extracted via item(), as float() on an array with ndim > 0 is
+        # deprecated since NumPy 1.25.
+        sub_data = {
+            "functionError": None,
+            "type": "static",
+            "value": float(tree.value[node_id].item()),
+            "valueError": float(tree.impurity[node_id]),
+        }
+
+        # if the node has children / is not a leaf:
+        if left_child != self.leaf_id or right_child != self.leaf_id:
+            feature_index = int(tree.feature[node_id])
+            if self.feature_names is None:
+                # no names given: fall back to a generic name
+                sub_data["paramName"] = "X[{}]".format(feature_index)
+            else:
+                sub_data["paramName"] = self.feature_names[feature_index]
+            sub_data["paramDecisionValue"] = float(tree.threshold[node_id])
+            sub_data["type"] = "scalarSplit"
+            sub_data["child"] = {}
+
+            # child values
+            if left_child != self.leaf_id:
+                sub_data["child"]["<="] = self.recurse_(
+                    tree, left_child, depth=depth + 1
+                )
+            if right_child != self.leaf_id:
+                sub_data["child"][">"] = self.recurse_(
+                    tree, right_child, depth=depth + 1
+                )
+
+        return sub_data
+
 
 class LMTFunction(SKLearnRegressionFunction):
     def get_number_of_nodes(self):
-- 
cgit v1.2.3