summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-10-24 12:17:43 +0200
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-10-24 12:17:43 +0200
commita0933fef969c4555452fcbf70e6183eddf141956 (patch)
tree5246a776b4672e78bd0dda6d9da473a312c3a08e /lib
parentb8519f00d9c30a7726435aac6989455a7ba91afe (diff)
add --add-total-observation for behaviour model evaluation
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py13
-rw-r--r--lib/validation.py22
2 files changed, 28 insertions, 7 deletions
diff --git a/lib/cli.py b/lib/cli.py
index dfe0e34..b94bdbb 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -293,7 +293,11 @@ def model_quality_table(
buf = f"{key:>{key_len}s} {attr:>{attr_len}s}"
for results, info in ((lut, None), (model, model_info), (static, None)):
buf += " "
- if results is not None and (
+
+ # special case for "TOTAL" (--add-total-observation)
+ if attr == "TOTAL" and attr not in results[key]:
+ buf += f"""{"----":>7s} """
+ elif results is not None and (
info is None
or (
attr != "energy_Pt"
@@ -311,7 +315,7 @@ def model_quality_table(
buf += format_quality_measures(result, error_metric=error_metric)
else:
buf += f"""{"----":>7s} """
- if type(model_info(key, attr)) is not df.StaticFunction:
+ if attr != "TOTAL" and type(model_info(key, attr)) is not df.StaticFunction:
if model[key][attr]["mae"] > static[key][attr]["mae"]:
buf += " :-("
elif (
@@ -629,6 +633,11 @@ def add_standard_arguments(parser):
help="Show model complexity score and details (e.g. regression tree height and node count)",
)
parser.add_argument(
+ "--add-total-observation",
+ action="store_true",
+ help="Add a TOTAL observation for each <key> that consists of the sums of its <attribute> entries. This allows for cross-validation of behaviour models vs. non-behaviour-aware models.",
+ )
+ parser.add_argument(
"--cross-validate",
metavar="<method>:<count>",
type=str,
diff --git a/lib/validation.py b/lib/validation.py
index 958a9e0..bf6764d 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -109,7 +109,7 @@ class CrossValidator:
self.args = args
self.kwargs = kwargs
- def kfold(self, model_getter, k=10, static=False):
+ def kfold(self, model_getter, k=10, static=False, with_sum=False):
"""
Perform k-fold cross-validation and return average model quality.
@@ -161,10 +161,10 @@ class CrossValidator:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
return self._generic_xv(
- model_getter, training_and_validation_sets, static=static
+ model_getter, training_and_validation_sets, static=static, with_sum=with_sum
)
- def montecarlo(self, model_getter, count=200, static=False):
+ def montecarlo(self, model_getter, count=200, static=False, with_sum=False):
"""
Perform Monte Carlo cross-validation and return average model quality.
@@ -211,10 +211,12 @@ class CrossValidator:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
return self._generic_xv(
- model_getter, training_and_validation_sets, static=static
+ model_getter, training_and_validation_sets, static=static, with_sum=with_sum
)
- def _generic_xv(self, model_getter, training_and_validation_sets, static=False):
+ def _generic_xv(
+ self, model_getter, training_and_validation_sets, static=False, with_sum=False
+ ):
ret = dict()
models = list()
@@ -268,6 +270,16 @@ class CrossValidator:
)
)
+ if with_sum:
+ for name in self.names:
+ attr_0 = self.by_name[name]["attributes"][0]
+ gt_sum = np.zeros(len(ret[name][attr_0]["groundTruth"]))
+ mo_sum = np.zeros(len(ret[name][attr_0]["modelOutput"]))
+ for attribute in self.by_name[name]["attributes"]:
+ gt_sum += np.array(ret[name][attribute]["groundTruth"])
+ mo_sum += np.array(ret[name][attribute]["modelOutput"])
+ ret[name]["TOTAL"] = regression_measures(mo_sum, gt_sum)
+
return ret, models
def _single_xv(self, model_getter, tv_set_dict, static=False):