diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 15:27:35 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 15:27:35 +0100 |
commit | 66d9d29ab699f01a96cfd1c62d9697c10b88b3d9 (patch) | |
tree | 8f45bface748d863416d72f8538c22c21fa04c57 | |
parent | 80c92831efa7c13c7dab629e58068d8b4855e5db (diff) |
CART.to_json: support function arguments
-rw-r--r-- | lib/functions.py | 24 | ||||
-rw-r--r-- | lib/parameters.py | 1 |
2 files changed, 22 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py index 23bf997..25851c5 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -710,9 +710,27 @@ class CARTFunction(SKLearnRegressionFunction): if left_child != self.leaf_id or right_child != self.leaf_id: # sub_data["paramName"] = "X[" + str(self.regressor.tree_.feature[left_child_id]) + "]" # sub_data["paramIndex"] = int(self.regressor.tree_.feature[left_child_id]) - sub_data["paramName"] = self.feature_names[ - self.regressor.tree_.feature[node_id] - ] + try: + sub_data["paramName"] = self.feature_names[ + self.regressor.tree_.feature[node_id] + ] + sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"]) + except IndexError: + sub_data["paramName"] = "arg" + str( + self.regressor.tree_.feature[node_id] - len(self.feature_names) + ) + sub_data["paramIndex"] = ( + len(self.param_names) + + self.regressor.tree_.feature[node_id] + - len(self.feature_names) + ) + except ValueError: + sub_data["paramIndex"] = ( + len(self.param_names) + + self.regressor.tree_.feature[node_id] + - len(self.feature_names) + ) + sub_data["threshold"] = tree.threshold[node_id] sub_data["type"] = "scalarSplit" diff --git a/lib/parameters.py b/lib/parameters.py index ad23780..ae4fffb 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1062,6 +1062,7 @@ class ModelAttribute: category_to_index, ignore_index, n_samples=len(data), + param_names=self.param_names, ) logger.debug("Fitted sklearn CART") return |