summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-09 14:39:06 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-09 14:39:06 +0100
commit80c92831efa7c13c7dab629e58068d8b4855e5db (patch)
treea88e451e6941fbf4e89aa7277695aaee1ace2c19 /lib/cli.py
parent33b11ea1e112166425156a748fc26ea676005c71 (diff)
Implement --show-model=param for XGB
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 2e797d6..1b6cb06 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -128,6 +128,11 @@ def print_cartinfo(prefix, info, feature_names):
_print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names)
+def print_xgbinfo(prefix, info, feature_names):
+ for i, tree in enumerate(info.to_json(feature_names=feature_names)):
+ _print_cartinfo(prefix + f"tree{i:03d} :", tree, feature_names)
+
+
def print_lmtinfo(prefix, info, feature_names):
_print_lmtinfo(prefix, info.to_json(feature_names=feature_names))
@@ -210,6 +215,8 @@ def print_model(prefix, info, feature_names):
print_splitinfo(feature_names, info, prefix)
elif type(info) is df.LMTFunction:
print_lmtinfo(prefix, info, feature_names)
+ elif type(info) is df.XGBoostFunction:
+ print_xgbinfo(prefix, info, feature_names)
else:
print(f"{prefix}: {type(info)} UNIMPLEMENTED")