summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-12 13:01:02 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-12 13:01:02 +0100
commit863fd68c0835076264011a3422e40142472484aa (patch)
tree69e36b27d27717139cc341df3c3b7137d01373f7 /lib
parent27697c537df10f764750602d81b3807523236b23 (diff)
CART, LMT export: handle function args
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 81a94b4..63523a8 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -594,6 +594,9 @@ class SKLearnRegressionFunction(ModelFunction):
# Needed for JSON export
self.param_names = kwargs.pop("param_names")
self.arg_count = kwargs.pop("arg_count")
+ self.param_names_and_args = self.param_names + list(
+ map(lambda i: f"arg{i}", range(self.arg_count))
+ )
super().__init__(value, **kwargs)
@@ -763,7 +766,9 @@ class CARTFunction(SKLearnRegressionFunction):
sub_data["paramName"] = self.feature_names[
self.regressor.tree_.feature[node_id]
]
- sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"])
+ sub_data["paramIndex"] = self.param_names_and_args.index(
+ sub_data["paramName"]
+ )
except IndexError:
sub_data["paramName"] = "arg" + str(
self.regressor.tree_.feature[node_id] - len(self.feature_names)
@@ -833,7 +838,9 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "scalarSplit",
"paramName": self.feature_names[node["col"]],
- "paramIndex": self.param_names.index(self.feature_names[node["col"]]),
+ "paramIndex": self.param_names_and_args.index(
+ self.feature_names[node["col"]]
+ ),
"threshold": node["th"],
"left": self.recurse_(node_hash, node["children"][0]),
"right": self.recurse_(node_hash, node["children"][1]),