diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-10 10:55:56 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-10 10:55:56 +0100 |
commit | 0c3f350a577cfb1b36d45707ae3f36c2fe0d46ba (patch) | |
tree | 36ca0e745fdd7fcd4d44a94ecb89cabbb9b24268 | |
parent | eff2256fc529245e302b45844c651ff403c025bf (diff) |
refactor --show-model=param into lib/cli.py
-rwxr-xr-x | bin/analyze-archive.py | 32 | ||||
-rwxr-xr-x | bin/analyze-kconfig.py | 17 | ||||
-rwxr-xr-x | bin/analyze-log.py | 15 | ||||
-rw-r--r-- | lib/cli.py | 36 |
4 files changed, 35 insertions, 65 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index ffc0d67..c91150d 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -679,35 +679,15 @@ if __name__ == "__main__": for state in model.states: for attribute in model.attributes(state): info = param_info(state, attribute) - if type(info) is df.AnalyticFunction: - dfatool.cli.print_analyticinfo(f"{state:10s} {attribute:15s}", info) - elif type(info) is df.CARTFunction: - dfatool.cli.print_cartinfo( - f"{state:10s} {attribute:15s}", info, model.parameters - ) - elif type(info) is df.SplitFunction: - dfatool.cli.print_splitinfo( - model.parameters, info, f"{state:10s} {attribute:15s}" - ) - elif type(info) is df.StaticFunction: - dfatool.cli.print_staticinfo(f"{state:10s} {attribute:15s}", info) - elif type(info) is df.SubstateFunction: - print(f"{state:10s} {attribute:15s}: Substate (TODO)") + dfatool.cli.print_model( + f"{state:10s} {attribute:15s}", info, model.parameters + ) for trans in model.transitions: for attribute in model.attributes(trans): info = param_info(trans, attribute) - if type(info) is df.AnalyticFunction: - dfatool.cli.print_analyticinfo(f"{trans:10s} {attribute:15s}", info) - elif type(info) is df.CARTFunction: - dfatool.cli.print_cartinfo( - f"{trans:10s} {attribute:15s}", info, model.parameters - ) - elif type(info) is df.SplitFunction: - dfatool.cli.print_splitinfo( - model.parameters, info, f"{trans:10s} {attribute:15s}" - ) - elif type(info) is df.SubstateFunction: - print(f"{state:10s} {attribute:15s}: Substate (TODO)") + dfatool.cli.print_model( + f"{trans:10s} {attribute:15s}", info, model.parameters + ) if args.with_substates: for submodel in model.submodel_by_name.values(): sub_param_model, sub_param_info = submodel.get_fitted() diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index baf3cc7..098d6ba 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -528,20 +528,9 @@ def main(): for name in model.names: for attribute in model.attributes(name): info = param_info(name, attribute) - if type(info) is df.AnalyticFunction: - dfatool.cli.print_analyticinfo(f"{name:20s} {attribute:15s}", info) - elif type(info) is df.CARTFunction: - dfatool.cli.print_cartinfo( - f"{name:20s} {attribute:15s}", info, model.parameters - ) - elif type(info) is df.FOLFunction: - dfatool.cli.print_analyticinfo(f"{name:20s} {attribute:15s}", info) - elif type(info) is df.SplitFunction: - dfatool.cli.print_splitinfo( - model.parameters, info, f"{name:20s} {attribute:15s}" - ) - elif type(info) is df.StaticFunction: - dfatool.cli.print_staticinfo(f"{name:10s} {attribute:15s}", info) + dfatool.cli.print_model( + f"{name:20s} {attribute:15s}", info, model.parameters + ) if "table" in args.show_quality or "all" in args.show_quality: if xv_method is not None: diff --git a/bin/analyze-log.py b/bin/analyze-log.py index b1dc6ba..9c24641 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -227,18 +227,9 @@ def main(): for name in sorted(model.names): for attribute in sorted(model.attributes(name)): info = param_info(name, attribute) - if type(info) is df.AnalyticFunction: - dfatool.cli.print_analyticinfo(f"{name:10s} {attribute:15s}", info) - elif type(info) is df.CARTFunction: - dfatool.cli.print_cartinfo( - f"{name:10s} {attribute:15s}", info, model.parameters - ) - elif type(info) is df.SplitFunction: - dfatool.cli.print_splitinfo( - model.parameters, info, f"{name:10s} {attribute:15s}" - ) - elif type(info) is df.StaticFunction: - dfatool.cli.print_staticinfo(f"{state:10s} {attribute:15s}", info) + dfatool.cli.print_model( + f"{name:10s} {attribute:15s}", info, model.parameters + ) if "table" in args.show_quality or "all" in args.show_quality: if xv_method is not None: @@ -1,11 +1,6 @@ #!/usr/bin/env python3 -from dfatool.functions import ( - SplitFunction, - AnalyticFunction, - StaticFunction, - FOLFunction, -) +import dfatool.functions as df import dfatool.plotter import logging import numpy as np @@ -115,21 +110,36 @@ def _print_cartinfo(prefix, model, feature_names): def print_splitinfo(param_names, info, prefix=""): - if type(info) is SplitFunction: + if type(info) is df.SplitFunction: for k, v in info.child.items(): if info.param_index < len(param_names): param_name = param_names[info.param_index] else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") - elif type(info) is AnalyticFunction: + elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) - elif type(info) is StaticFunction: + elif type(info) is df.StaticFunction: print(f"{prefix}: {info.value}") else: print(f"{prefix}: UNKNOWN") +def print_model(prefix, info, feature_names): + if type(info) is df.StaticFunction: + print_staticinfo(prefix, info) + elif type(info) is df.AnalyticFunction: + print_analyticinfo(prefix, info) + elif type(info) is df.FOLFunction: + print_analyticinfo(prefix, info) + elif type(info) is df.CARTFunction: + print_cartinfo(prefix, info, feature_names) + elif type(info) is df.SplitFunction: + print_splitinfo(feature_names, info, prefix) + else: + print(f"{prefix}: {type(info)} UNIMPLEMENTED") + + def print_model_size(model): for name in model.names: for attribute in model.attributes(name): @@ -196,13 +206,13 @@ def model_quality_table( info is None or ( key != "energy_Pt" - and type(info(key, attr)) is not StaticFunction + and type(info(key, attr)) is not df.StaticFunction ) or ( key == "energy_Pt" and ( - type(info(key, "power")) is not StaticFunction - or type(info(key, "duration")) is not StaticFunction + type(info(key, "power")) is not df.StaticFunction + or type(info(key, "duration")) is not df.StaticFunction ) ) ): @@ -210,7 +220,7 @@ def model_quality_table( buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ - if type(model_info(key, attr)) is not StaticFunction: + if type(model_info(key, attr)) is not df.StaticFunction: if model[key][attr]["mae"] > static[key][attr]["mae"]: buf += " :-(" elif ( |