From c14e6a92e2239d852d52338af41a4152988e7960 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 9 Feb 2024 08:17:27 +0100 Subject: LMT export: set paramIndex; use correct parameterNames in analytic nodes --- lib/functions.py | 7 ++++++- lib/parameters.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 8df15de..23bf997 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -589,7 +589,11 @@ class SKLearnRegressionFunction(ModelFunction): has_eval_arr = True def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs): + # Needed for JSON export + self.param_names = kwargs.pop("param_names", None) + super().__init__(value, **kwargs) + self.regressor = regressor self.categorial_to_index = categorial_to_index self.ignore_index = ignore_index @@ -763,6 +767,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], + "paramIndex": self.param_names.index(self.feature_names[node["col"]]), "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), @@ -775,7 +780,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "analytic", "functionStr": fs, - "parameterNames": self.feature_names, + "parameterNames": self.param_names, "regressionModel": [model.intercept_] + list(model.coef_), } diff --git a/lib/parameters.py b/lib/parameters.py index 210cbe3..ad23780 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1240,7 +1240,12 @@ class ModelAttribute: return logger.debug("Fitted LMT") self.model_function = df.LMTFunction( - np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data) + np.mean(data), + lmt, + category_to_index, + ignore_index, + n_samples=len(data), + param_names=self.param_names, ) return -- cgit v1.2.3