summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-06 09:09:25 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-06 09:09:25 +0100
commitb04085507479962ccd40f962b62e795ad54998f0 (patch)
tree813e8fc5420a79cd9ca2a0c08ae2f40ee1a9bb90 /lib/cli.py
parent87841177a29211c70130b177e4c6f42a3b9cd5ca (diff)
Add ScalarSplitFunction support for manual scalar splits in RMT
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py41
1 files changed, 33 insertions, 8 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 7997482..0a5f79a 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -39,15 +39,27 @@ def print_static(model, static_model, name, attribute, with_dependence=False):
unit = "µs"
elif attribute == "substate_count":
unit = "su"
- print(
- "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format(
- name,
- attribute,
- static_model(name, attribute),
- unit,
- model.attr_by_name[name][attribute].stats.generic_param_dependence_ratio(),
+ if model.attr_by_name[name][attribute].stats:
+ print(
+ "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format(
+ name,
+ attribute,
+ static_model(name, attribute),
+ unit,
+ model.attr_by_name[name][
+ attribute
+ ].stats.generic_param_dependence_ratio(),
+ )
+ )
+ else:
+ print(
+ "{:10s}: {:28s} : {:.2f} {:s}".format(
+ name,
+ attribute,
+ static_model(name, attribute),
+ unit,
+ )
)
- )
if with_dependence:
for param in model.parameters:
print(
@@ -164,6 +176,17 @@ def print_splitinfo(param_names, info, prefix=""):
else:
param_name = f"arg{info.param_index - len(param_names)}"
print_splitinfo(param_names, v, f"{prefix} {param_name}={k}")
+ if type(info) is df.ScalarSplitFunction:
+ if info.param_index < len(param_names):
+ param_name = param_names[info.param_index]
+ else:
+ param_name = f"arg{info.param_index - len(param_names)}"
+ print_splitinfo(
+ param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}"
+ )
+ print_splitinfo(
+ param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}"
+ )
elif type(info) is df.AnalyticFunction:
print_analyticinfo(prefix, info)
elif type(info) is df.StaticFunction:
@@ -183,6 +206,8 @@ def print_model(prefix, info, feature_names):
print_cartinfo(prefix, info, feature_names)
elif type(info) is df.SplitFunction:
print_splitinfo(feature_names, info, prefix)
+ elif type(info) is df.ScalarSplitFunction:
+ print_splitinfo(feature_names, info, prefix)
elif type(info) is df.LMTFunction:
print_lmtinfo(prefix, info, feature_names)
else: