summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py8
-rw-r--r--lib/validation.py29
-rwxr-xr-xtest/test_ptamodel.py4
3 files changed, 22 insertions, 19 deletions
diff --git a/lib/model.py b/lib/model.py
index 427b5ec..71d3367 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -430,8 +430,8 @@ class AnalyticModel:
detailed_results[name][attribute] = measures
if return_raw:
raw_results[name]["attribute"][attribute] = {
- "groundTruth": list(elem[attribute]),
- "modelOutput": list(predicted_data),
+ "groundTruth": elem[attribute],
+ "modelOutput": predicted_data,
}
if return_raw:
@@ -1150,8 +1150,8 @@ class PTAModel(AnalyticModel):
detailed_results[name]["energy_Pt"] = measures
if return_raw:
raw_results[name]["attribute"]["energy_Pt"] = {
- "groundTruth": list(elem["power"] * elem["duration"]),
- "modelOutput": list(predicted_data),
+ "groundTruth": elem["power"] * elem["duration"],
+ "modelOutput": predicted_data,
}
if return_raw:
diff --git a/lib/validation.py b/lib/validation.py
index 95815ac..8552ca5 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -2,6 +2,7 @@
import logging
import numpy as np
+from .utils import regression_measures
logger = logging.getLogger(__name__)
@@ -222,23 +223,23 @@ class CrossValidator:
ret[name] = dict()
for attribute in self.by_name[name]["attributes"]:
ret[name][attribute] = {
- "mae_list": list(),
- "rmsd_list": list(),
- "mape_list": list(),
- "smape_list": list(),
+ "groundTruth": list(),
+ "modelOutput": list(),
}
for training_and_validation_by_name in training_and_validation_sets:
- model, res = self._single_xv(
+ model, (res, raw) = self._single_xv(
model_getter, training_and_validation_by_name, static=static
)
models.append(model)
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
- for measure in ("mae", "rmsd", "mape", "smape"):
- ret[name][attribute][f"{measure}_list"].append(
- res[name][attribute][measure]
- )
+ ret[name][attribute]["groundTruth"].extend(
+ raw[name]["attribute"][attribute]["groundTruth"]
+ )
+ ret[name][attribute]["modelOutput"].extend(
+ raw[name]["attribute"][attribute]["modelOutput"]
+ )
if self.export_filename:
import json
@@ -248,10 +249,12 @@ class CrossValidator:
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
- for measure in ("mae", "rmsd", "mape", "smape"):
- ret[name][attribute][measure] = np.mean(
- ret[name][attribute][f"{measure}_list"]
+ ret[name][attribute].update(
+ regression_measures(
+ np.array(ret[name][attribute]["modelOutput"]),
+ np.array(ret[name][attribute]["groundTruth"]),
)
+ )
return ret, models
@@ -299,4 +302,4 @@ class CrossValidator:
validation, self.parameters, *self.args, **kwargs
)
- return training_data, validation_data.assess(training_model)
+ return training_data, validation_data.assess(training_model, return_raw=True)
diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py
index b40289c..cd774c3 100755
--- a/test/test_ptamodel.py
+++ b/test/test_ptamodel.py
@@ -330,8 +330,8 @@ class TestSynthetic(unittest.TestCase):
# the Root Mean Square Deviation must not be greater the scale (i.e., standard deviation) of the normal distribution
# Low Mean Absolute Error (< 2)
self.assertTrue(static_quality["raw_state_1"]["duration"]["mae"] < 2)
- # Low Root Mean Square Deviation (< scale == 2)
- self.assertTrue(static_quality["raw_state_1"]["duration"]["rmsd"] < 2)
+ # Low Root Mean Square Deviation (< scale+eps == 2.02)
+ self.assertTrue(static_quality["raw_state_1"]["duration"]["rmsd"] < 2.02)
# Relatively low error percentage (~~ MAE * 100% / s1_duration_base)
self.assertAlmostEqual(
static_quality["raw_state_1"]["duration"]["smape"],