From a0933fef969c4555452fcbf70e6183eddf141956 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 24 Oct 2024 12:17:43 +0200 Subject: add --add-total-observation for behaviour model evaluation --- lib/cli.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'lib/cli.py') diff --git a/lib/cli.py b/lib/cli.py index dfe0e34..b94bdbb 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -293,7 +293,11 @@ def model_quality_table( buf = f"{key:>{key_len}s} {attr:>{attr_len}s}" for results, info in ((lut, None), (model, model_info), (static, None)): buf += " " - if results is not None and ( + + # special case for "TOTAL" (--add-total-observation) + if attr == "TOTAL" and attr not in results[key]: + buf += f"""{"----":>7s} """ + elif results is not None and ( info is None or ( attr != "energy_Pt" @@ -311,7 +315,7 @@ def model_quality_table( buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ - if type(model_info(key, attr)) is not df.StaticFunction: + if attr != "TOTAL" and type(model_info(key, attr)) is not df.StaticFunction: if model[key][attr]["mae"] > static[key][attr]["mae"]: buf += " :-(" elif ( @@ -628,6 +632,11 @@ def add_standard_arguments(parser): action="store_true", help="Show model complexity score and details (e.g. regression tree height and node count)", ) + parser.add_argument( + "--add-total-observation", + action="store_true", + help="Add a TOTAL observation for each that consists of the sums of its entries. This allows for cross-validation of behaviour models vs. non-behaviour-aware models.", + ) parser.add_argument( "--cross-validate", metavar=":", -- cgit v1.2.3