summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-09 08:17:27 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-09 08:17:27 +0100
commitc14e6a92e2239d852d52338af41a4152988e7960 (patch)
treef274ec55b7153222a3012814aab776ae56bd937c /lib
parent5c38f8cca410464cb24e45a52e64409debd81769 (diff)
LMT export: set paramIndex; use correct parameterNames in analytic nodes
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py7
-rw-r--r--lib/parameters.py7
2 files changed, 12 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 8df15de..23bf997 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -589,7 +589,11 @@ class SKLearnRegressionFunction(ModelFunction):
has_eval_arr = True
def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs):
+ # Needed for JSON export
+ self.param_names = kwargs.pop("param_names", None)
+
super().__init__(value, **kwargs)
+
self.regressor = regressor
self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
@@ -763,6 +767,7 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "scalarSplit",
"paramName": self.feature_names[node["col"]],
+ "paramIndex": self.param_names.index(self.feature_names[node["col"]]),
"threshold": node["th"],
"left": self.recurse_(node_hash, node["children"][0]),
"right": self.recurse_(node_hash, node["children"][1]),
@@ -775,7 +780,7 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "analytic",
"functionStr": fs,
- "parameterNames": self.feature_names,
+ "parameterNames": self.param_names,
"regressionModel": [model.intercept_] + list(model.coef_),
}
diff --git a/lib/parameters.py b/lib/parameters.py
index 210cbe3..ad23780 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1240,7 +1240,12 @@ class ModelAttribute:
return
logger.debug("Fitted LMT")
self.model_function = df.LMTFunction(
- np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data)
+ np.mean(data),
+ lmt,
+ category_to_index,
+ ignore_index,
+ n_samples=len(data),
+ param_names=self.param_names,
)
return