summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 10:55:56 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 10:55:56 +0100
commit0c3f350a577cfb1b36d45707ae3f36c2fe0d46ba (patch)
tree36ca0e745fdd7fcd4d44a94ecb89cabbb9b24268 /lib/cli.py
parenteff2256fc529245e302b45844c651ff403c025bf (diff)
refactor --show-model=param into lib/cli.py
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 5accc1f..0ebe42d 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -1,11 +1,6 @@
#!/usr/bin/env python3
-from dfatool.functions import (
- SplitFunction,
- AnalyticFunction,
- StaticFunction,
- FOLFunction,
-)
+import dfatool.functions as df
import dfatool.plotter
import logging
import numpy as np
@@ -115,21 +110,36 @@ def _print_cartinfo(prefix, model, feature_names):
def print_splitinfo(param_names, info, prefix=""):
- if type(info) is SplitFunction:
+ if type(info) is df.SplitFunction:
for k, v in info.child.items():
if info.param_index < len(param_names):
param_name = param_names[info.param_index]
else:
param_name = f"arg{info.param_index - len(param_names)}"
print_splitinfo(param_names, v, f"{prefix} {param_name}={k}")
- elif type(info) is AnalyticFunction:
+ elif type(info) is df.AnalyticFunction:
print_analyticinfo(prefix, info)
- elif type(info) is StaticFunction:
+ elif type(info) is df.StaticFunction:
print(f"{prefix}: {info.value}")
else:
print(f"{prefix}: UNKNOWN")
+def print_model(prefix, info, feature_names):
+ if type(info) is df.StaticFunction:
+ print_staticinfo(prefix, info)
+ elif type(info) is df.AnalyticFunction:
+ print_analyticinfo(prefix, info)
+ elif type(info) is df.FOLFunction:
+ print_analyticinfo(prefix, info)
+ elif type(info) is df.CARTFunction:
+ print_cartinfo(prefix, info, feature_names)
+ elif type(info) is df.SplitFunction:
+ print_splitinfo(feature_names, info, prefix)
+ else:
+ print(f"{prefix}: {type(info)} UNIMPLEMENTED")
+
+
def print_model_size(model):
for name in model.names:
for attribute in model.attributes(name):
@@ -196,13 +206,13 @@ def model_quality_table(
info is None
or (
key != "energy_Pt"
- and type(info(key, attr)) is not StaticFunction
+ and type(info(key, attr)) is not df.StaticFunction
)
or (
key == "energy_Pt"
and (
- type(info(key, "power")) is not StaticFunction
- or type(info(key, "duration")) is not StaticFunction
+ type(info(key, "power")) is not df.StaticFunction
+ or type(info(key, "duration")) is not df.StaticFunction
)
)
):
@@ -210,7 +220,7 @@ def model_quality_table(
buf += format_quality_measures(result, error_metric=error_metric)
else:
buf += f"""{"----":>7s} """
- if type(model_info(key, attr)) is not StaticFunction:
+ if type(model_info(key, attr)) is not df.StaticFunction:
if model[key][attr]["mae"] > static[key][attr]["mae"]:
buf += " :-("
elif (