diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/validation.py | 21 |
1 files changed, 9 insertions, 12 deletions
diff --git a/lib/validation.py b/lib/validation.py index 98d49c1..ee147fe 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -179,6 +179,7 @@ class CrossValidator: for attribute in self.by_name[name]["attributes"]: ret["by_name"][name][attribute] = { "mae_list": list(), + "rmsd_list": list(), "smape_list": list(), } @@ -186,21 +187,17 @@ class CrossValidator: res = self._single_xv(model_getter, training_and_validation_by_name) for name in self.names: for attribute in self.by_name[name]["attributes"]: - ret["by_name"][name][attribute]["mae_list"].append( - res["by_name"][name][attribute]["mae"] - ) - ret["by_name"][name][attribute]["smape_list"].append( - res["by_name"][name][attribute]["smape"] - ) + for measure in ("mae", "rmsd", "smape"): + ret["by_name"][name][attribute][f"{measure}_list"].append( + res["by_name"][name][attribute][measure] + ) for name in self.names: for attribute in self.by_name[name]["attributes"]: - ret["by_name"][name][attribute]["mae"] = np.mean( - ret["by_name"][name][attribute]["mae_list"] - ) - ret["by_name"][name][attribute]["smape"] = np.mean( - ret["by_name"][name][attribute]["smape_list"] - ) + for measure in ("mae", "rmsd", "smape"): + ret["by_name"][name][attribute][measure] = np.mean( + ret["by_name"][name][attribute][f"{measure}_list"] + ) return ret |