summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index a176e88..91f2add 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -51,6 +51,25 @@ def print_splitinfo(param_names, info, prefix=""):
print(f"{prefix}: UNKNOWN")
+def print_model_size(model):
+ for name in model.names:
+ for attribute in model.attributes(name):
+ try:
+ num_nodes = model.attr_by_name[name][
+ attribute
+ ].model_function.get_number_of_nodes()
+ max_depth = model.attr_by_name[name][
+ attribute
+ ].model_function.get_max_depth()
+ print(
+ f"{name:15s} {attribute:20s}: {num_nodes:6d} nodes @ {max_depth:3d} max depth"
+ )
+ except AttributeError:
+ print(
+ f"{name:15s} {attribute:20s}: {model.attr_by_name[name][attribute].model_function}"
+ )
+
+
def format_quality_measures(result):
if "smape" in result:
return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"])
@@ -104,6 +123,11 @@ def add_standard_arguments(parser):
help="Export model and model quality to LaTeX dataref file",
)
parser.add_argument(
+ "--show-model-size",
+ action="store_true",
+ help="Show model size (e.g. regression tree height and node count)",
+ )
+ parser.add_argument(
"--cross-validate",
metavar="<method>:<count>",
type=str,