diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 14:39:06 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-09 14:39:06 +0100 |
commit | 80c92831efa7c13c7dab629e58068d8b4855e5db (patch) | |
tree | a88e451e6941fbf4e89aa7277695aaee1ace2c19 /lib | |
parent | 33b11ea1e112166425156a748fc26ea676005c71 (diff) |
Implement --show-model=param for XGB
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 7 |
1 files changed, 7 insertions, 0 deletions
@@ -128,6 +128,11 @@ def print_cartinfo(prefix, info, feature_names): _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names) +def print_xgbinfo(prefix, info, feature_names): + for i, tree in enumerate(info.to_json(feature_names=feature_names)): + _print_cartinfo(prefix + f"tree{i:03d} :", tree, feature_names) + + def print_lmtinfo(prefix, info, feature_names): _print_lmtinfo(prefix, info.to_json(feature_names=feature_names)) @@ -210,6 +215,8 @@ def print_model(prefix, info, feature_names): print_splitinfo(feature_names, info, prefix) elif type(info) is df.LMTFunction: print_lmtinfo(prefix, info, feature_names) + elif type(info) is df.XGBoostFunction: + print_xgbinfo(prefix, info, feature_names) else: print(f"{prefix}: {type(info)} UNIMPLEMENTED") |