diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-16 13:37:16 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-16 13:37:16 +0100 |
commit | fa147a81097d3629afc879aeba13729a8b302ae4 (patch) | |
tree | 264709b1bb67f63038ff1b053823ade83106a2ae | |
parent | e654b3dd34c32cb5ae0320e1064bf7934d500e85 (diff) |
Implement --show-model=param for LMT
-rw-r--r-- | lib/cli.py | 26 | ||||
-rw-r--r-- | lib/functions.py | 29 |
2 files changed, 55 insertions, 0 deletions
@@ -103,6 +103,30 @@ def print_cartinfo(prefix, info, feature_names): _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names) +def print_lmtinfo(prefix, info, feature_names): + _print_lmtinfo(prefix, info.to_json(feature_names=feature_names)) + + +def _print_lmtinfo(prefix, model): + if model["type"] == "static": + print(f"""{prefix}: {model["value"]}""") + elif model["type"] == "scalarSplit": + _print_lmtinfo( + f"""{prefix} {model["paramName"]}≤{model["paramDecisionValue"]} """, + model["left"], + ) + _print_lmtinfo( + f"""{prefix} {model["paramName"]}>{model["paramDecisionValue"]} """, + model["right"], + ) + else: + model_function = model["functionStr"].removeprefix("0 + ") + for i, coef in enumerate(model["regressionModel"]): + model_function = model_function.replace(f"regression_arg({i})", str(coef)) + model_function = model_function.replace("+ -", "- ") + print(f"{prefix}: {model_function}") + + def _print_cartinfo(prefix, model, feature_names): if model["type"] == "static": print(f"""{prefix}: {model["value"]}""") @@ -146,6 +170,8 @@ def print_model(prefix, info, feature_names): print_cartinfo(prefix, info, feature_names) elif type(info) is df.SplitFunction: print_splitinfo(feature_names, info, prefix) + elif type(info) is df.LMTFunction: + print_lmtinfo(prefix, info, feature_names) else: print(f"{prefix}: {type(info)} UNIMPLEMENTED") diff --git a/lib/functions.py b/lib/functions.py index d2ef578..4be32c6 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -618,6 +618,35 @@ class LMTFunction(SKLearnRegressionFunction): def get_max_depth(self): return max(map(len, self.regressor._leaves.keys())) + 1 + def to_json(self, feature_names=None, **kwargs): + self.feature_names = feature_names + ret = super().to_json(**kwargs) + ret.update(self.recurse_(self.regressor.summary(), 0)) + return ret + + def recurse_(self, node_hash, node_index): + node = node_hash[node_index] + sub_data = dict() + if "th" in node: + return { + "type": "scalarSplit", + "paramName": self.feature_names[node["col"]], + "paramDecisionValue": node["th"], + "left": self.recurse_(node_hash, node["children"][0]), + "right": self.recurse_(node_hash, node["children"][1]), + } + model = node["models"] + fs = "0 + regression_arg(0)" + for i, coef in enumerate(model.coef_): + if coef: + fs += f" + regression_arg({i+1}) * parameter({self.feature_names[i]})" + return { + "type": "analytic", + "functionStr": fs, + "parameterNames": self.feature_names, + "regressionModel": [model.intercept_] + list(model.coef_), + } + class XGBoostFunction(SKLearnRegressionFunction): def to_json(self, feature_names=None, **kwargs): |