diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py index 8df15de..23bf997 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -589,7 +589,11 @@ class SKLearnRegressionFunction(ModelFunction): has_eval_arr = True def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs): + # Needed for JSON export + self.param_names = kwargs.pop("param_names", None) + super().__init__(value, **kwargs) + self.regressor = regressor self.categorial_to_index = categorial_to_index self.ignore_index = ignore_index @@ -763,6 +767,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], + "paramIndex": self.param_names.index(self.feature_names[node["col"]]), "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), @@ -775,7 +780,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "analytic", "functionStr": fs, - "parameterNames": self.feature_names, + "parameterNames": self.param_names, "regressionModel": [model.intercept_] + list(model.coef_), } |