diff options
Diffstat (limited to 'lib/cli.py')
-rw-r--r-- | lib/cli.py | 50 |
1 files changed, 18 insertions, 32 deletions
@@ -124,17 +124,17 @@ def print_staticinfo(prefix, info): print(f"{prefix}: {info.value}") -def print_cartinfo(prefix, info, feature_names): - _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names) +def print_cartinfo(prefix, info): + _print_cartinfo(prefix, info.to_json()) -def print_xgbinfo(prefix, info, feature_names): - for i, tree in enumerate(info.to_json(feature_names=feature_names)): - _print_cartinfo(prefix + f"tree{i:03d} :", tree, feature_names) +def print_xgbinfo(prefix, info): + for i, tree in enumerate(info.to_json()): + _print_cartinfo(prefix + f"tree{i:03d} :", tree) -def print_lmtinfo(prefix, info, feature_names): - _print_lmtinfo(prefix, info.to_json(feature_names=feature_names)) +def print_lmtinfo(prefix, info): + _print_lmtinfo(prefix, info.to_json()) def _print_lmtinfo(prefix, model): @@ -157,41 +157,27 @@ def _print_lmtinfo(prefix, model): print(f"{prefix}: {model_function}") -def _print_cartinfo(prefix, model, feature_names): +def _print_cartinfo(prefix, model): if model["type"] == "static": print(f"""{prefix}: {model["value"]}""") else: _print_cartinfo( f"""{prefix} {model["paramName"]}≤{model["threshold"]} """, model["left"], - feature_names, ) _print_cartinfo( f"""{prefix} {model["paramName"]}>{model["threshold"]} """, model["right"], - feature_names, ) -def print_splitinfo(param_names, info, prefix=""): +def print_splitinfo(info, prefix=""): if type(info) is df.SplitFunction: for k, v in info.child.items(): - if info.param_index < len(param_names): - param_name = param_names[info.param_index] - else: - param_name = f"arg{info.param_index - len(param_names)}" - print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") + print_splitinfo(v, f"{prefix} {info.param_name}={k}") elif type(info) is df.ScalarSplitFunction: - if info.param_index < len(param_names): - param_name = param_names[info.param_index] - else: - param_name = f"arg{info.param_index - len(param_names)}" - print_splitinfo( - param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}" - ) - print_splitinfo( - param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}" - ) + print_splitinfo(info.child_le, f"{prefix} {info.param_name}≤{info.threshold}") + print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}") elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) elif type(info) is df.StaticFunction: @@ -200,7 +186,7 @@ def print_splitinfo(param_names, info, prefix=""): print(f"{prefix}: UNKNOWN {type(info)}") -def print_model(prefix, info, feature_names): +def print_model(prefix, info): if type(info) is df.StaticFunction: print_staticinfo(prefix, info) elif type(info) is df.AnalyticFunction: @@ -208,15 +194,15 @@ def print_model(prefix, info, feature_names): elif type(info) is df.FOLFunction: print_analyticinfo(prefix, info) elif type(info) is df.CARTFunction: - print_cartinfo(prefix, info, feature_names) + print_cartinfo(prefix, info) elif type(info) is df.SplitFunction: - print_splitinfo(feature_names, info, prefix) + print_splitinfo(info, prefix) elif type(info) is df.ScalarSplitFunction: - print_splitinfo(feature_names, info, prefix) + print_splitinfo(info, prefix) elif type(info) is df.LMTFunction: - print_lmtinfo(prefix, info, feature_names) + print_lmtinfo(prefix, info) elif type(info) is df.XGBoostFunction: - print_xgbinfo(prefix, info, feature_names) + print_xgbinfo(prefix, info) else: print(f"{prefix}: {type(info)} UNIMPLEMENTED") |