summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py24
1 files changed, 21 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 23bf997..25851c5 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -710,9 +710,27 @@ class CARTFunction(SKLearnRegressionFunction):
if left_child != self.leaf_id or right_child != self.leaf_id:
# sub_data["paramName"] = "X[" + str(self.regressor.tree_.feature[left_child_id]) + "]"
# sub_data["paramIndex"] = int(self.regressor.tree_.feature[left_child_id])
- sub_data["paramName"] = self.feature_names[
- self.regressor.tree_.feature[node_id]
- ]
+ try:
+ sub_data["paramName"] = self.feature_names[
+ self.regressor.tree_.feature[node_id]
+ ]
+ sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"])
+ except IndexError:
+ sub_data["paramName"] = "arg" + str(
+ self.regressor.tree_.feature[node_id] - len(self.feature_names)
+ )
+ sub_data["paramIndex"] = (
+ len(self.param_names)
+ + self.regressor.tree_.feature[node_id]
+ - len(self.feature_names)
+ )
+ except ValueError:
+ sub_data["paramIndex"] = (
+ len(self.param_names)
+ + self.regressor.tree_.feature[node_id]
+ - len(self.feature_names)
+ )
+
sub_data["threshold"] = tree.threshold[node_id]
sub_data["type"] = "scalarSplit"