diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index d2ef578..4be32c6 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -618,6 +618,35 @@ class LMTFunction(SKLearnRegressionFunction): def get_max_depth(self): return max(map(len, self.regressor._leaves.keys())) + 1 + def to_json(self, feature_names=None, **kwargs): + self.feature_names = feature_names + ret = super().to_json(**kwargs) + ret.update(self.recurse_(self.regressor.summary(), 0)) + return ret + + def recurse_(self, node_hash, node_index): + node = node_hash[node_index] + sub_data = dict() + if "th" in node: + return { + "type": "scalarSplit", + "paramName": self.feature_names[node["col"]], + "paramDecisionValue": node["th"], + "left": self.recurse_(node_hash, node["children"][0]), + "right": self.recurse_(node_hash, node["children"][1]), + } + model = node["models"] + fs = "0 + regression_arg(0)" + for i, coef in enumerate(model.coef_): + if coef: + fs += f" + regression_arg({i+1}) * parameter({self.feature_names[i]})" + return { + "type": "analytic", + "functionStr": fs, + "parameterNames": self.feature_names, + "regressionModel": [model.intercept_] + list(model.coef_), + } + class XGBoostFunction(SKLearnRegressionFunction): def to_json(self, feature_names=None, **kwargs): |