diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 08:17:27 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 08:17:27 +0100 |
commit | c14e6a92e2239d852d52338af41a4152988e7960 (patch) | |
tree | f274ec55b7153222a3012814aab776ae56bd937c /lib | |
parent | 5c38f8cca410464cb24e45a52e64409debd81769 (diff) |
LMT export: set paramIndex; use correct parameterNames in analytic nodes
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 7 | ||||
-rw-r--r-- | lib/parameters.py | 7 |
2 files changed, 12 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py index 8df15de..23bf997 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -589,7 +589,11 @@ class SKLearnRegressionFunction(ModelFunction): has_eval_arr = True def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs): + # Needed for JSON export + self.param_names = kwargs.pop("param_names", None) + super().__init__(value, **kwargs) + self.regressor = regressor self.categorial_to_index = categorial_to_index self.ignore_index = ignore_index @@ -763,6 +767,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], + "paramIndex": self.param_names.index(self.feature_names[node["col"]]), "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), @@ -775,7 +780,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "analytic", "functionStr": fs, - "parameterNames": self.feature_names, + "parameterNames": self.param_names, "regressionModel": [model.intercept_] + list(model.coef_), } diff --git a/lib/parameters.py b/lib/parameters.py index 210cbe3..ad23780 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1240,7 +1240,12 @@ class ModelAttribute: return logger.debug("Fitted LMT") self.model_function = df.LMTFunction( - np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data) + np.mean(data), + lmt, + category_to_index, + ignore_index, + n_samples=len(data), + param_names=self.param_names, ) return |