summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 09:41:50 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 09:41:50 +0100
commit5adc39c8de1fbda6e1ba9abcaaaf17516eb046ae (patch)
treef52dbe13e5cf9fa70f38a8d0e8d75020856ce781 /lib
parent72fc17a4dcff42bedb21b456f7da63065a72fd9f (diff)
--show-model=param: add CART support
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 314c6dd..a1655c6 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -94,6 +94,26 @@ def print_staticinfo(prefix, info):
print(f"{prefix}: {info.value}")
+def print_cartinfo(prefix, info, feature_names):
+ _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names)
+
+
+def _print_cartinfo(prefix, model, feature_names):
+ if model["type"] == "static":
+ print(f"""{prefix}: {model["value"]}""")
+ else:
+ _print_cartinfo(
+ f"""{prefix} {model["paramName"]}<{model["paramDecisionValue"]} """,
+ model["left"],
+ feature_names,
+ )
+ _print_cartinfo(
+ f"""{prefix} {model["paramName"]}≥{model["paramDecisionValue"]} """,
+ model["right"],
+ feature_names,
+ )
+
+
def print_splitinfo(param_names, info, prefix=""):
if type(info) is SplitFunction:
for k, v in info.child.items():