From 024e05ed88cf262e4960746aedaaa83aca472769 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 15 Jul 2020 11:20:00 +0200 Subject: CrossValidator: Compute RMSD --- lib/validation.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/lib/validation.py b/lib/validation.py index 98d49c1..ee147fe 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -179,6 +179,7 @@ class CrossValidator: for attribute in self.by_name[name]["attributes"]: ret["by_name"][name][attribute] = { "mae_list": list(), + "rmsd_list": list(), "smape_list": list(), } @@ -186,21 +187,17 @@ class CrossValidator: res = self._single_xv(model_getter, training_and_validation_by_name) for name in self.names: for attribute in self.by_name[name]["attributes"]: - ret["by_name"][name][attribute]["mae_list"].append( - res["by_name"][name][attribute]["mae"] - ) - ret["by_name"][name][attribute]["smape_list"].append( - res["by_name"][name][attribute]["smape"] - ) + for measure in ("mae", "rmsd", "smape"): + ret["by_name"][name][attribute][f"{measure}_list"].append( + res["by_name"][name][attribute][measure] + ) for name in self.names: for attribute in self.by_name[name]["attributes"]: - ret["by_name"][name][attribute]["mae"] = np.mean( - ret["by_name"][name][attribute]["mae_list"] - ) - ret["by_name"][name][attribute]["smape"] = np.mean( - ret["by_name"][name][attribute]["smape_list"] - ) + for measure in ("mae", "rmsd", "smape"): + ret["by_name"][name][attribute][measure] = np.mean( + ret["by_name"][name][attribute][f"{measure}_list"] + ) return ret -- cgit v1.2.3