diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-03-02 21:40:24 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-03-02 21:40:24 +0100 |
commit | 4c5b7b2f07b3e1f661dc85828f8fee5c7c88ac72 (patch) | |
tree | c6245cc56f882605c2b4b45013d5287f6747c5d3 /lib | |
parent | ca4e50cb46fd19222a1fbc399dac217b79f25e08 (diff) |
CARTFunction: Add to_json method (adapted from Lennart Kaiser)
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 69788df..aeaf48e 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -487,6 +487,50 @@ class CARTFunction(SKLearnRegressionFunction): def get_max_depth(self): return self.regressor.get_depth() + def to_json(self, feature_names=None, **kwargs): + import sklearn.tree + + self.leaf_id = sklearn.tree._tree.TREE_LEAF + self.feature_names = feature_names + + ret = super().to_json(**kwargs) + ret.update(self.recurse_(self.regressor.tree_, 0)) + return ret + + # recursive function for all nodes: + def recurse_(self, tree, node_id, depth=0): + left_child = tree.children_left[node_id] + right_child = tree.children_right[node_id] + + # basic leaf with standard values + # conversion because of numpy + sub_data = { + "functionError": None, + "type": "static", + "value": float(tree.value[node_id]), + "valueError": float(tree.impurity[node_id]), + # "samples": int(tree.n_node_samples[node_id]) + } + + # if has childs / not a leaf: + if left_child != self.leaf_id or right_child != self.leaf_id: + # sub_data["paramName"] = "X[" + str(self.regressor.tree_.feature[left_child_id]) + "]" + # sub_data["paramIndex"] = int(self.regressor.tree_.feature[left_child_id]) + sub_data["paramName"] = self.feature_names[ + self.regressor.tree_.feature[node_id] + ] + sub_data["paramDecisionValue"] = tree.threshold[node_id] + sub_data["type"] = "scalarSplit" + sub_data["child"] = {} + + # child value + if left_child != self.leaf_id: + sub_data["child"]["<="] = self.recurse_(tree, left_child, depth=depth + 1) + if right_child != self.leaf_id: + sub_data["child"][">"] = self.recurse_(tree, right_child, depth=depth + 1) + + return sub_data + class LMTFunction(SKLearnRegressionFunction): def get_number_of_nodes(self): |