summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 8df15de..23bf997 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -589,7 +589,11 @@ class SKLearnRegressionFunction(ModelFunction):
has_eval_arr = True
def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs):
+ # Needed for JSON export
+ self.param_names = kwargs.pop("param_names", None)
+
super().__init__(value, **kwargs)
+
self.regressor = regressor
self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
@@ -763,6 +767,7 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "scalarSplit",
"paramName": self.feature_names[node["col"]],
+ "paramIndex": self.param_names.index(self.feature_names[node["col"]]),
"threshold": node["th"],
"left": self.recurse_(node_hash, node["children"][0]),
"right": self.recurse_(node_hash, node["children"][1]),
@@ -775,7 +780,7 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "analytic",
"functionStr": fs,
- "parameterNames": self.feature_names,
+ "parameterNames": self.param_names,
"regressionModel": [model.intercept_] + list(model.coef_),
}