From 80c92831efa7c13c7dab629e58068d8b4855e5db Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 9 Feb 2024 14:39:06 +0100 Subject: Implement --show-model=param for XGB --- lib/cli.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'lib') diff --git a/lib/cli.py b/lib/cli.py index 2e797d6..1b6cb06 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -128,6 +128,11 @@ def print_cartinfo(prefix, info, feature_names): _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names) +def print_xgbinfo(prefix, info, feature_names): + for i, tree in enumerate(info.to_json(feature_names=feature_names)): + _print_cartinfo(prefix + f"tree{i:03d} :", tree, feature_names) + + def print_lmtinfo(prefix, info, feature_names): _print_lmtinfo(prefix, info.to_json(feature_names=feature_names)) @@ -210,6 +215,8 @@ def print_model(prefix, info, feature_names): print_splitinfo(feature_names, info, prefix) elif type(info) is df.LMTFunction: print_lmtinfo(prefix, info, feature_names) + elif type(info) is df.XGBoostFunction: + print_xgbinfo(prefix, info, feature_names) else: print(f"{prefix}: {type(info)} UNIMPLEMENTED") -- cgit v1.2.3