diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-10 10:55:56 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-10 10:55:56 +0100 |
commit | 0c3f350a577cfb1b36d45707ae3f36c2fe0d46ba (patch) | |
tree | 36ca0e745fdd7fcd4d44a94ecb89cabbb9b24268 /lib/cli.py | |
parent | eff2256fc529245e302b45844c651ff403c025bf (diff) |
refactor --show-model=param into lib/cli.py
Diffstat (limited to 'lib/cli.py')
-rw-r--r-- | lib/cli.py | 36 |
1 files changed, 23 insertions, 13 deletions
@@ -1,11 +1,6 @@ #!/usr/bin/env python3 -from dfatool.functions import ( - SplitFunction, - AnalyticFunction, - StaticFunction, - FOLFunction, -) +import dfatool.functions as df import dfatool.plotter import logging import numpy as np @@ -115,21 +110,36 @@ def _print_cartinfo(prefix, model, feature_names): def print_splitinfo(param_names, info, prefix=""): - if type(info) is SplitFunction: + if type(info) is df.SplitFunction: for k, v in info.child.items(): if info.param_index < len(param_names): param_name = param_names[info.param_index] else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") - elif type(info) is AnalyticFunction: + elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) - elif type(info) is StaticFunction: + elif type(info) is df.StaticFunction: print(f"{prefix}: {info.value}") else: print(f"{prefix}: UNKNOWN") +def print_model(prefix, info, feature_names): + if type(info) is df.StaticFunction: + print_staticinfo(prefix, info) + elif type(info) is df.AnalyticFunction: + print_analyticinfo(prefix, info) + elif type(info) is df.FOLFunction: + print_analyticinfo(prefix, info) + elif type(info) is df.CARTFunction: + print_cartinfo(prefix, info, feature_names) + elif type(info) is df.SplitFunction: + print_splitinfo(feature_names, info, prefix) + else: + print(f"{prefix}: {type(info)} UNIMPLEMENTED") + + def print_model_size(model): for name in model.names: for attribute in model.attributes(name): @@ -196,13 +206,13 @@ def model_quality_table( info is None or ( key != "energy_Pt" - and type(info(key, attr)) is not StaticFunction + and type(info(key, attr)) is not df.StaticFunction ) or ( key == "energy_Pt" and ( - type(info(key, "power")) is not StaticFunction - or type(info(key, "duration")) is not StaticFunction + type(info(key, "power")) is not df.StaticFunction + or type(info(key, "duration")) is not df.StaticFunction ) ) ): @@ -210,7 +220,7 @@ def model_quality_table( buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ - if type(model_info(key, attr)) is not StaticFunction: + if type(model_info(key, attr)) is not df.StaticFunction: if model[key][attr]["mae"] > static[key][attr]["mae"]: buf += " :-(" elif ( |