diff options
Diffstat (limited to 'lib/validation.py')
-rw-r--r-- | lib/validation.py | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/lib/validation.py b/lib/validation.py index 95815ac..8552ca5 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -2,6 +2,7 @@ import logging import numpy as np +from .utils import regression_measures logger = logging.getLogger(__name__) @@ -222,23 +223,23 @@ class CrossValidator: ret[name] = dict() for attribute in self.by_name[name]["attributes"]: ret[name][attribute] = { - "mae_list": list(), - "rmsd_list": list(), - "mape_list": list(), - "smape_list": list(), + "groundTruth": list(), + "modelOutput": list(), } for training_and_validation_by_name in training_and_validation_sets: - model, res = self._single_xv( + model, (res, raw) = self._single_xv( model_getter, training_and_validation_by_name, static=static ) models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: - for measure in ("mae", "rmsd", "mape", "smape"): - ret[name][attribute][f"{measure}_list"].append( - res[name][attribute][measure] - ) + ret[name][attribute]["groundTruth"].extend( + raw[name]["attribute"][attribute]["groundTruth"] + ) + ret[name][attribute]["modelOutput"].extend( + raw[name]["attribute"][attribute]["modelOutput"] + ) if self.export_filename: import json @@ -248,10 +249,12 @@ class CrossValidator: for name in self.names: for attribute in self.by_name[name]["attributes"]: - for measure in ("mae", "rmsd", "mape", "smape"): - ret[name][attribute][measure] = np.mean( - ret[name][attribute][f"{measure}_list"] + ret[name][attribute].update( + regression_measures( + np.array(ret[name][attribute]["modelOutput"]), + np.array(ret[name][attribute]["groundTruth"]), ) + ) return ret, models @@ -299,4 +302,4 @@ class CrossValidator: validation, self.parameters, *self.args, **kwargs ) - return training_data, validation_data.assess(training_model) + return training_data, validation_data.assess(training_model, return_raw=True) |