summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index d2ef578..4be32c6 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -618,6 +618,35 @@ class LMTFunction(SKLearnRegressionFunction):
def get_max_depth(self):
return max(map(len, self.regressor._leaves.keys())) + 1
+ def to_json(self, feature_names=None, **kwargs):
+ self.feature_names = feature_names
+ ret = super().to_json(**kwargs)
+ ret.update(self.recurse_(self.regressor.summary(), 0))
+ return ret
+
+ def recurse_(self, node_hash, node_index):
+ node = node_hash[node_index]
+ sub_data = dict()
+ if "th" in node:
+ return {
+ "type": "scalarSplit",
+ "paramName": self.feature_names[node["col"]],
+ "paramDecisionValue": node["th"],
+ "left": self.recurse_(node_hash, node["children"][0]),
+ "right": self.recurse_(node_hash, node["children"][1]),
+ }
+ model = node["models"]
+ fs = "0 + regression_arg(0)"
+ for i, coef in enumerate(model.coef_):
+ if coef:
+ fs += f" + regression_arg({i+1}) * parameter({self.feature_names[i]})"
+ return {
+ "type": "analytic",
+ "functionStr": fs,
+ "parameterNames": self.feature_names,
+ "regressionModel": [model.intercept_] + list(model.coef_),
+ }
+
class XGBoostFunction(SKLearnRegressionFunction):
def to_json(self, feature_names=None, **kwargs):