Diffstat (limited to 'lib/validation.py')
-rw-r--r--  lib/validation.py  29
1 file changed, 16 insertions(+), 13 deletions(-)
diff --git a/lib/validation.py b/lib/validation.py
index 95815ac..8552ca5 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -2,6 +2,7 @@
import logging
import numpy as np
+from .utils import regression_measures
logger = logging.getLogger(__name__)
@@ -222,23 +223,23 @@ class CrossValidator:
            ret[name] = dict()
            for attribute in self.by_name[name]["attributes"]:
                ret[name][attribute] = {
-                    "mae_list": list(),
-                    "rmsd_list": list(),
-                    "mape_list": list(),
-                    "smape_list": list(),
+                    "groundTruth": list(),
+                    "modelOutput": list(),
                }

        for training_and_validation_by_name in training_and_validation_sets:
-            model, res = self._single_xv(
+            model, (res, raw) = self._single_xv(
                model_getter, training_and_validation_by_name, static=static
            )
            models.append(model)
            for name in self.names:
                for attribute in self.by_name[name]["attributes"]:
-                    for measure in ("mae", "rmsd", "mape", "smape"):
-                        ret[name][attribute][f"{measure}_list"].append(
-                            res[name][attribute][measure]
-                        )
+                    ret[name][attribute]["groundTruth"].extend(
+                        raw[name]["attribute"][attribute]["groundTruth"]
+                    )
+                    ret[name][attribute]["modelOutput"].extend(
+                        raw[name]["attribute"][attribute]["modelOutput"]
+                    )

        if self.export_filename:
            import json
@@ -248,10 +249,12 @@ class CrossValidator:
        for name in self.names:
            for attribute in self.by_name[name]["attributes"]:
-                for measure in ("mae", "rmsd", "mape", "smape"):
-                    ret[name][attribute][measure] = np.mean(
-                        ret[name][attribute][f"{measure}_list"]
+                ret[name][attribute].update(
+                    regression_measures(
+                        np.array(ret[name][attribute]["modelOutput"]),
+                        np.array(ret[name][attribute]["groundTruth"]),
                    )
+                )

        return ret, models
@@ -299,4 +302,4 @@ class CrossValidator:
        validation_data = model_getter(
            validation, self.parameters, *self.args, **kwargs
        )
-        return training_data, validation_data.assess(training_model)
+        return training_data, validation_data.assess(training_model, return_raw=True)
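
Note on the change: cross-validation no longer averages per-fold MAE/RMSD/MAPE/SMAPE values. Instead, each fold's raw ground-truth and model-output samples are pooled, and the error measures are computed once over the pooled arrays via regression_measures from lib/utils.py. That helper is not part of this diff; the following is a minimal sketch of what such a function plausibly computes. The dictionary keys ("mae", "rmsd", "mape", "smape") are taken from the removed code; the implementation details are assumptions.

import numpy as np

def regression_measures(predicted, actual):
    # Hypothetical sketch, not the lib/utils.py implementation:
    # error measures over pooled model output vs. ground truth.
    deviations = predicted - actual
    measures = {
        "mae": np.mean(np.abs(deviations)),         # mean absolute error
        "rmsd": np.sqrt(np.mean(deviations ** 2)),  # root-mean-square deviation
    }
    # Percentage errors are only defined where the denominator is non-zero.
    if np.all(actual != 0):
        measures["mape"] = np.mean(np.abs(deviations / actual)) * 100
    if np.all(np.abs(predicted) + np.abs(actual) != 0):
        measures["smape"] = (
            np.mean(np.abs(deviations) / ((np.abs(predicted) + np.abs(actual)) / 2)) * 100
        )
    return measures

Computing the measures over pooled samples weights every measurement equally, whereas the previous mean of per-fold means gave each fold equal weight regardless of how many validation samples it contained.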