diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-10-24 12:17:43 +0200 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-10-24 12:17:43 +0200 |
commit | a0933fef969c4555452fcbf70e6183eddf141956 (patch) | |
tree | 5246a776b4672e78bd0dda6d9da473a312c3a08e /lib | |
parent | b8519f00d9c30a7726435aac6989455a7ba91afe (diff) |
add --add-total-observation for behaviour model evaluation
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 13 | ||||
-rw-r--r-- | lib/validation.py | 22 |
2 files changed, 28 insertions, 7 deletions
@@ -293,7 +293,11 @@ def model_quality_table( buf = f"{key:>{key_len}s} {attr:>{attr_len}s}" for results, info in ((lut, None), (model, model_info), (static, None)): buf += " " - if results is not None and ( + + # special case for "TOTAL" (--add-total-observation) + if attr == "TOTAL" and attr not in results[key]: + buf += f"""{"----":>7s} """ + elif results is not None and ( info is None or ( attr != "energy_Pt" @@ -311,7 +315,7 @@ def model_quality_table( buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ - if type(model_info(key, attr)) is not df.StaticFunction: + if attr != "TOTAL" and type(model_info(key, attr)) is not df.StaticFunction: if model[key][attr]["mae"] > static[key][attr]["mae"]: buf += " :-(" elif ( @@ -629,6 +633,11 @@ def add_standard_arguments(parser): help="Show model complexity score and details (e.g. regression tree height and node count)", ) parser.add_argument( + "--add-total-observation", + action="store_true", + help="Add a TOTAL observation for each <key> that consists of the sums of its <attribute> entries. This allows for cross-validation of behaviour models vs. non-behaviour-aware models.", + ) + parser.add_argument( "--cross-validate", metavar="<method>:<count>", type=str, diff --git a/lib/validation.py b/lib/validation.py index 958a9e0..bf6764d 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -109,7 +109,7 @@ class CrossValidator: self.args = args self.kwargs = kwargs - def kfold(self, model_getter, k=10, static=False): + def kfold(self, model_getter, k=10, static=False, with_sum=False): """ Perform k-fold cross-validation and return average model quality. @@ -161,10 +161,10 @@ class CrossValidator: training_and_validation_sets[i][name] = subsets_by_name[name][i] return self._generic_xv( - model_getter, training_and_validation_sets, static=static + model_getter, training_and_validation_sets, static=static, with_sum=with_sum ) - def montecarlo(self, model_getter, count=200, static=False): + def montecarlo(self, model_getter, count=200, static=False, with_sum=False): """ Perform Monte Carlo cross-validation and return average model quality. @@ -211,10 +211,12 @@ class CrossValidator: training_and_validation_sets[i][name] = subsets_by_name[name][i] return self._generic_xv( - model_getter, training_and_validation_sets, static=static + model_getter, training_and_validation_sets, static=static, with_sum=with_sum ) - def _generic_xv(self, model_getter, training_and_validation_sets, static=False): + def _generic_xv( + self, model_getter, training_and_validation_sets, static=False, with_sum=False + ): ret = dict() models = list() @@ -268,6 +270,16 @@ class CrossValidator: ) ) + if with_sum: + for name in self.names: + attr_0 = self.by_name[name]["attributes"][0] + gt_sum = np.zeros(len(ret[name][attr_0]["groundTruth"])) + mo_sum = np.zeros(len(ret[name][attr_0]["modelOutput"])) + for attribute in self.by_name[name]["attributes"]: + gt_sum += np.array(ret[name][attribute]["groundTruth"]) + mo_sum += np.array(ret[name][attribute]["modelOutput"]) + ret[name]["TOTAL"] = regression_measures(mo_sum, gt_sum) + return ret, models def _single_xv(self, model_getter, tv_set_dict, static=False): |