summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py50
1 files changed, 18 insertions, 32 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 1b6cb06..3da6fce 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -124,17 +124,17 @@ def print_staticinfo(prefix, info):
print(f"{prefix}: {info.value}")
-def print_cartinfo(prefix, info, feature_names):
- _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names)
+def print_cartinfo(prefix, info):
+ _print_cartinfo(prefix, info.to_json())
-def print_xgbinfo(prefix, info, feature_names):
- for i, tree in enumerate(info.to_json(feature_names=feature_names)):
- _print_cartinfo(prefix + f"tree{i:03d} :", tree, feature_names)
+def print_xgbinfo(prefix, info):
+ for i, tree in enumerate(info.to_json()):
+ _print_cartinfo(prefix + f"tree{i:03d} :", tree)
-def print_lmtinfo(prefix, info, feature_names):
- _print_lmtinfo(prefix, info.to_json(feature_names=feature_names))
+def print_lmtinfo(prefix, info):
+ _print_lmtinfo(prefix, info.to_json())
def _print_lmtinfo(prefix, model):
@@ -157,41 +157,27 @@ def _print_lmtinfo(prefix, model):
print(f"{prefix}: {model_function}")
-def _print_cartinfo(prefix, model, feature_names):
+def _print_cartinfo(prefix, model):
if model["type"] == "static":
print(f"""{prefix}: {model["value"]}""")
else:
_print_cartinfo(
f"""{prefix} {model["paramName"]}≤{model["threshold"]} """,
model["left"],
- feature_names,
)
_print_cartinfo(
f"""{prefix} {model["paramName"]}>{model["threshold"]} """,
model["right"],
- feature_names,
)
-def print_splitinfo(param_names, info, prefix=""):
+def print_splitinfo(info, prefix=""):
if type(info) is df.SplitFunction:
for k, v in info.child.items():
- if info.param_index < len(param_names):
- param_name = param_names[info.param_index]
- else:
- param_name = f"arg{info.param_index - len(param_names)}"
- print_splitinfo(param_names, v, f"{prefix} {param_name}={k}")
+ print_splitinfo(v, f"{prefix} {info.param_name}={k}")
elif type(info) is df.ScalarSplitFunction:
- if info.param_index < len(param_names):
- param_name = param_names[info.param_index]
- else:
- param_name = f"arg{info.param_index - len(param_names)}"
- print_splitinfo(
- param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}"
- )
- print_splitinfo(
- param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}"
- )
+ print_splitinfo(info.child_le, f"{prefix} {info.param_name}≤{info.threshold}")
+ print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}")
elif type(info) is df.AnalyticFunction:
print_analyticinfo(prefix, info)
elif type(info) is df.StaticFunction:
@@ -200,7 +186,7 @@ def print_splitinfo(param_names, info, prefix=""):
print(f"{prefix}: UNKNOWN {type(info)}")
-def print_model(prefix, info, feature_names):
+def print_model(prefix, info):
if type(info) is df.StaticFunction:
print_staticinfo(prefix, info)
elif type(info) is df.AnalyticFunction:
@@ -208,15 +194,15 @@ def print_model(prefix, info, feature_names):
elif type(info) is df.FOLFunction:
print_analyticinfo(prefix, info)
elif type(info) is df.CARTFunction:
- print_cartinfo(prefix, info, feature_names)
+ print_cartinfo(prefix, info)
elif type(info) is df.SplitFunction:
- print_splitinfo(feature_names, info, prefix)
+ print_splitinfo(info, prefix)
elif type(info) is df.ScalarSplitFunction:
- print_splitinfo(feature_names, info, prefix)
+ print_splitinfo(info, prefix)
elif type(info) is df.LMTFunction:
- print_lmtinfo(prefix, info, feature_names)
+ print_lmtinfo(prefix, info)
elif type(info) is df.XGBoostFunction:
- print_xgbinfo(prefix, info, feature_names)
+ print_xgbinfo(prefix, info)
else:
print(f"{prefix}: {type(info)} UNIMPLEMENTED")