diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-07 14:32:57 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-07 14:32:57 +0100 |
commit | 2dfdb0e9839ccbd9b0f93025774480a089100dba (patch) | |
tree | b7eaee7aa339e73ef962ced866873c6ec8887675 | |
parent | b9695e80314fe91b891c9a2e3f792805b9c729bd (diff) |
xv: calculate measures from samples rather than averaging intermediates
-rw-r--r-- | lib/model.py | 8 | ||||
-rw-r--r-- | lib/validation.py | 29 | ||||
-rwxr-xr-x | test/test_ptamodel.py | 4 |
3 files changed, 22 insertions, 19 deletions
diff --git a/lib/model.py b/lib/model.py index 427b5ec..71d3367 100644 --- a/lib/model.py +++ b/lib/model.py @@ -430,8 +430,8 @@ class AnalyticModel: detailed_results[name][attribute] = measures if return_raw: raw_results[name]["attribute"][attribute] = { - "groundTruth": list(elem[attribute]), - "modelOutput": list(predicted_data), + "groundTruth": elem[attribute], + "modelOutput": predicted_data, } if return_raw: @@ -1150,8 +1150,8 @@ class PTAModel(AnalyticModel): detailed_results[name]["energy_Pt"] = measures if return_raw: raw_results[name]["attribute"]["energy_Pt"] = { - "groundTruth": list(elem["power"] * elem["duration"]), - "modelOutput": list(predicted_data), + "groundTruth": elem["power"] * elem["duration"], + "modelOutput": predicted_data, } if return_raw: diff --git a/lib/validation.py b/lib/validation.py index 95815ac..8552ca5 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -2,6 +2,7 @@ import logging import numpy as np +from .utils import regression_measures logger = logging.getLogger(__name__) @@ -222,23 +223,23 @@ class CrossValidator: ret[name] = dict() for attribute in self.by_name[name]["attributes"]: ret[name][attribute] = { - "mae_list": list(), - "rmsd_list": list(), - "mape_list": list(), - "smape_list": list(), + "groundTruth": list(), + "modelOutput": list(), } for training_and_validation_by_name in training_and_validation_sets: - model, res = self._single_xv( + model, (res, raw) = self._single_xv( model_getter, training_and_validation_by_name, static=static ) models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: - for measure in ("mae", "rmsd", "mape", "smape"): - ret[name][attribute][f"{measure}_list"].append( - res[name][attribute][measure] - ) + ret[name][attribute]["groundTruth"].extend( + raw[name]["attribute"][attribute]["groundTruth"] + ) + ret[name][attribute]["modelOutput"].extend( + raw[name]["attribute"][attribute]["modelOutput"] + ) if self.export_filename: import json @@ -248,10 +249,12 @@ class CrossValidator: for name in self.names: for attribute in self.by_name[name]["attributes"]: - for measure in ("mae", "rmsd", "mape", "smape"): - ret[name][attribute][measure] = np.mean( - ret[name][attribute][f"{measure}_list"] + ret[name][attribute].update( + regression_measures( + np.array(ret[name][attribute]["modelOutput"]), + np.array(ret[name][attribute]["groundTruth"]), ) + ) return ret, models @@ -299,4 +302,4 @@ class CrossValidator: validation, self.parameters, *self.args, **kwargs ) - return training_data, validation_data.assess(training_model) + return training_data, validation_data.assess(training_model, return_raw=True) diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py index b40289c..cd774c3 100755 --- a/test/test_ptamodel.py +++ b/test/test_ptamodel.py @@ -330,8 +330,8 @@ class TestSynthetic(unittest.TestCase): # the Root Mean Square Deviation must not be greater the scale (i.e., standard deviation) of the normal distribution # Low Mean Absolute Error (< 2) self.assertTrue(static_quality["raw_state_1"]["duration"]["mae"] < 2) - # Low Root Mean Square Deviation (< scale == 2) - self.assertTrue(static_quality["raw_state_1"]["duration"]["rmsd"] < 2) + # Low Root Mean Square Deviation (< scale+eps == 2.02) + self.assertTrue(static_quality["raw_state_1"]["duration"]["rmsd"] < 2.02) # Relatively low error percentage (~~ MAE * 100% / s1_duration_base) self.assertAlmostEqual( static_quality["raw_state_1"]["duration"]["smape"], |