diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-06 09:09:25 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-06 09:09:25 +0100 |
commit | b04085507479962ccd40f962b62e795ad54998f0 (patch) | |
tree | 813e8fc5420a79cd9ca2a0c08ae2f40ee1a9bb90 /lib/cli.py | |
parent | 87841177a29211c70130b177e4c6f42a3b9cd5ca (diff) |
Add ScalarSplitFunction support for manual scalar splits in RMT
Diffstat (limited to 'lib/cli.py')
-rw-r--r-- | lib/cli.py | 41 |
1 files changed, 33 insertions, 8 deletions
@@ -39,15 +39,27 @@ def print_static(model, static_model, name, attribute, with_dependence=False): unit = "µs" elif attribute == "substate_count": unit = "su" - print( - "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format( - name, - attribute, - static_model(name, attribute), - unit, - model.attr_by_name[name][attribute].stats.generic_param_dependence_ratio(), + if model.attr_by_name[name][attribute].stats: + print( + "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format( + name, + attribute, + static_model(name, attribute), + unit, + model.attr_by_name[name][ + attribute + ].stats.generic_param_dependence_ratio(), + ) + ) + else: + print( + "{:10s}: {:28s} : {:.2f} {:s}".format( + name, + attribute, + static_model(name, attribute), + unit, + ) ) - ) if with_dependence: for param in model.parameters: print( @@ -164,6 +176,17 @@ def print_splitinfo(param_names, info, prefix=""): else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") + if type(info) is df.ScalarSplitFunction: + if info.param_index < len(param_names): + param_name = param_names[info.param_index] + else: + param_name = f"arg{info.param_index - len(param_names)}" + print_splitinfo( + param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}" + ) + print_splitinfo( + param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}" + ) elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) elif type(info) is df.StaticFunction: @@ -183,6 +206,8 @@ def print_model(prefix, info, feature_names): print_cartinfo(prefix, info, feature_names) elif type(info) is df.SplitFunction: print_splitinfo(feature_names, info, prefix) + elif type(info) is df.ScalarSplitFunction: + print_splitinfo(feature_names, info, prefix) elif type(info) is df.LMTFunction: print_lmtinfo(prefix, info, feature_names) else: |