summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Birte Kristina Friesel <birte.friesel@uos.de> 2024-01-16 13:37:16 +0100
committer Birte Kristina Friesel <birte.friesel@uos.de> 2024-01-16 13:37:16 +0100
commit fa147a81097d3629afc879aeba13729a8b302ae4 (patch)
tree 264709b1bb67f63038ff1b053823ade83106a2ae
parent e654b3dd34c32cb5ae0320e1064bf7934d500e85 (diff)
Implement --show-model=param for LMT
-rw-r--r-- lib/cli.py 26
-rw-r--r-- lib/functions.py 29
2 files changed, 55 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index f9341f8..7eaaa61 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -103,6 +103,30 @@ def print_cartinfo(prefix, info, feature_names):
_print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names)
+def print_lmtinfo(prefix, info, feature_names):  # print `info` via its to_json() dict
+    _print_lmtinfo(prefix, info.to_json(feature_names=feature_names))
+
+
+def _print_lmtinfo(prefix, model):
+    """Recursively print one line per leaf of an LMT model dict.
+
+    "scalarSplit" nodes append their condition to the prefix and recurse into
+    "left" (≤) and "right" (>); other nodes print a value or a function.
+    """
+    if model["type"] == "static":
+        print(f"""{prefix}: {model["value"]}""")
+    elif model["type"] == "scalarSplit":
+        name, value = model["paramName"], model["paramDecisionValue"]
+        _print_lmtinfo(f"{prefix} {name}≤{value} ", model["left"])
+        _print_lmtinfo(f"{prefix} {name}>{value} ", model["right"])
+    else:
+        model_function = model["functionStr"].removeprefix("0 + ")
+        for i, coef in enumerate(model["regressionModel"]):
+            model_function = model_function.replace(f"regression_arg({i})", str(coef))
+        model_function = model_function.replace("+ -", "- ")
+        print(f"{prefix}: {model_function}")
+
+
def _print_cartinfo(prefix, model, feature_names):
if model["type"] == "static":
print(f"""{prefix}: {model["value"]}""")
@@ -146,6 +170,8 @@ def print_model(prefix, info, feature_names):
print_cartinfo(prefix, info, feature_names)
elif type(info) is df.SplitFunction:
print_splitinfo(feature_names, info, prefix)
+ elif type(info) is df.LMTFunction:
+ print_lmtinfo(prefix, info, feature_names)
else:
print(f"{prefix}: {type(info)} UNIMPLEMENTED")
diff --git a/lib/functions.py b/lib/functions.py
index d2ef578..4be32c6 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -618,6 +618,35 @@ class LMTFunction(SKLearnRegressionFunction):
def get_max_depth(self):
return max(map(len, self.regressor._leaves.keys())) + 1
+    def to_json(self, feature_names=None, **kwargs):
+        self.feature_names = feature_names  # recurse_() indexes this for names
+        ret = super().to_json(**kwargs)
+        ret.update(self.recurse_(self.regressor.summary(), 0))
+        return ret
+
+    def recurse_(self, node_hash, node_index):
+        """Convert node `node_index` of `node_hash` (and its subtree) to a dict."""
+        node = node_hash[node_index]
+        if "th" in node:
+            return {
+                "type": "scalarSplit",
+                "paramName": self.feature_names[node["col"]],
+                "paramDecisionValue": node["th"],
+                "left": self.recurse_(node_hash, node["children"][0]),
+                "right": self.recurse_(node_hash, node["children"][1]),
+            }
+        model = node["models"]
+        fs = "0 + regression_arg(0)"  # arg 0 is the intercept (see regressionModel)
+        for i, coef in enumerate(model.coef_):
+            if coef:  # zero coefficients get no term; arg index is i+1 (after intercept)
+                fs += f" + regression_arg({i+1}) * parameter({self.feature_names[i]})"
+        return {
+            "type": "analytic",
+            "functionStr": fs,
+            "parameterNames": self.feature_names,
+            "regressionModel": [model.intercept_] + list(model.coef_),
+        }
+
class XGBoostFunction(SKLearnRegressionFunction):
def to_json(self, feature_names=None, **kwargs):