diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-12 13:01:02 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-12 13:01:02 +0100 |
commit | 863fd68c0835076264011a3422e40142472484aa (patch) | |
tree | 69e36b27d27717139cc341df3c3b7137d01373f7 /lib | |
parent | 27697c537df10f764750602d81b3807523236b23 (diff) |
CART, LMT export: handle function args
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py index 81a94b4..63523a8 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -594,6 +594,9 @@ class SKLearnRegressionFunction(ModelFunction): # Needed for JSON export self.param_names = kwargs.pop("param_names") self.arg_count = kwargs.pop("arg_count") + self.param_names_and_args = self.param_names + list( + map(lambda i: f"arg{i}", range(self.arg_count)) + ) super().__init__(value, **kwargs) @@ -763,7 +766,9 @@ class CARTFunction(SKLearnRegressionFunction): sub_data["paramName"] = self.feature_names[ self.regressor.tree_.feature[node_id] ] - sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"]) + sub_data["paramIndex"] = self.param_names_and_args.index( + sub_data["paramName"] + ) except IndexError: sub_data["paramName"] = "arg" + str( self.regressor.tree_.feature[node_id] - len(self.feature_names) @@ -833,7 +838,9 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], - "paramIndex": self.param_names.index(self.feature_names[node["col"]]), + "paramIndex": self.param_names_and_args.index( + self.feature_names[node["col"]] + ), "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), |