diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py index 23bf997..25851c5 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -710,9 +710,27 @@ class CARTFunction(SKLearnRegressionFunction): if left_child != self.leaf_id or right_child != self.leaf_id: # sub_data["paramName"] = "X[" + str(self.regressor.tree_.feature[left_child_id]) + "]" # sub_data["paramIndex"] = int(self.regressor.tree_.feature[left_child_id]) - sub_data["paramName"] = self.feature_names[ - self.regressor.tree_.feature[node_id] - ] + try: + sub_data["paramName"] = self.feature_names[ + self.regressor.tree_.feature[node_id] + ] + sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"]) + except IndexError: + sub_data["paramName"] = "arg" + str( + self.regressor.tree_.feature[node_id] - len(self.feature_names) + ) + sub_data["paramIndex"] = ( + len(self.param_names) + + self.regressor.tree_.feature[node_id] + - len(self.feature_names) + ) + except ValueError: + sub_data["paramIndex"] = ( + len(self.param_names) + + self.regressor.tree_.feature[node_id] + - len(self.feature_names) + ) + sub_data["threshold"] = tree.threshold[node_id] sub_data["type"] = "scalarSplit" |