From 863fd68c0835076264011a3422e40142472484aa Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Mon, 12 Feb 2024 13:01:02 +0100 Subject: CART, LMT export: handle function args --- lib/functions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 81a94b4..63523a8 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -594,6 +594,9 @@ class SKLearnRegressionFunction(ModelFunction): # Needed for JSON export self.param_names = kwargs.pop("param_names") self.arg_count = kwargs.pop("arg_count") + self.param_names_and_args = self.param_names + list( + map(lambda i: f"arg{i}", range(self.arg_count)) + ) super().__init__(value, **kwargs) @@ -763,7 +766,9 @@ class CARTFunction(SKLearnRegressionFunction): sub_data["paramName"] = self.feature_names[ self.regressor.tree_.feature[node_id] ] - sub_data["paramIndex"] = self.param_names.index(sub_data["paramName"]) + sub_data["paramIndex"] = self.param_names_and_args.index( + sub_data["paramName"] + ) except IndexError: sub_data["paramName"] = "arg" + str( self.regressor.tree_.feature[node_id] - len(self.feature_names) @@ -833,7 +838,9 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], - "paramIndex": self.param_names.index(self.feature_names[node["col"]]), + "paramIndex": self.param_names_and_args.index( + self.feature_names[node["col"]] + ), "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), -- cgit v1.2.3