summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-02-18 15:24:07 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-02-18 15:24:07 +0100
commite1033f50fd5eabf20d2d0e0abef672d9c67e8b59 (patch)
treeab759698706311ce2e976e8e15311906d9976dc0
parent2e12c7d6a1c23d134743e3dbe22cb740348625b0 (diff)
add show_model_size option
-rwxr-xr-xbin/analyze-archive.py3
-rwxr-xr-xbin/analyze-kconfig.py3
-rw-r--r--lib/cli.py24
3 files changed, 30 insertions, 0 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 5cc01f6..88b6802 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -1048,6 +1048,9 @@ if __name__ == "__main__":
]
)
+ if args.show_model_size:
+ dfatool.cli.print_model_size(model)
+
if args.plot_param:
for kv in args.plot_param.split(";"):
try:
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index 8e60472..b9f071f 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -307,6 +307,9 @@ def main():
param_values = model.distinct_param_values_by_name[name][i]
print(f" Parameter {param} ∈ {param_values}")
+ if args.show_model_size:
+ dfatool.cli.print_model_size(model)
+
if args.export_model:
with open("nfpkeys.json", "r") as f:
nfpkeys = json.load(f)
diff --git a/lib/cli.py b/lib/cli.py
index a176e88..91f2add 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -51,6 +51,25 @@ def print_splitinfo(param_names, info, prefix=""):
print(f"{prefix}: UNKNOWN")
+def print_model_size(model):
+ for name in model.names:
+ for attribute in model.attributes(name):
+ try:
+ num_nodes = model.attr_by_name[name][
+ attribute
+ ].model_function.get_number_of_nodes()
+ max_depth = model.attr_by_name[name][
+ attribute
+ ].model_function.get_max_depth()
+ print(
+ f"{name:15s} {attribute:20s}: {num_nodes:6d} nodes @ {max_depth:3d} max depth"
+ )
+ except AttributeError:
+ print(
+ f"{name:15s} {attribute:20s}: {model.attr_by_name[name][attribute].model_function}"
+ )
+
+
def format_quality_measures(result):
if "smape" in result:
return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"])
@@ -104,6 +123,11 @@ def add_standard_arguments(parser):
help="Export model and model quality to LaTeX dataref file",
)
parser.add_argument(
+ "--show-model-size",
+ action="store_true",
+ help="Show model size (e.g. regression tree height and node count)",
+ )
+ parser.add_argument(
"--cross-validate",
metavar="<method>:<count>",
type=str,