From 2c5bcd77f2c952cc5269ca3e4b6e0a7323ebd085 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 5 Jan 2022 15:23:51 +0100 Subject: cross validation: return intermediate models used for XV These are interesting for statistics, e.g. to determine the average dtree size --- lib/validation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/validation.py b/lib/validation.py index 47395dc..5c65fe3 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -211,6 +211,7 @@ class CrossValidator: def _generic_xv(self, model_getter, training_and_validation_sets): ret = dict() + models = list() for name in self.names: ret[name] = dict() @@ -223,7 +224,8 @@ class CrossValidator: } for training_and_validation_by_name in training_and_validation_sets: - res = self._single_xv(model_getter, training_and_validation_by_name) + model, res = self._single_xv(model_getter, training_and_validation_by_name) + models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: for measure in ("mae", "rmsd", "mape", "smape"): @@ -238,7 +240,7 @@ class CrossValidator: ret[name][attribute][f"{measure}_list"] ) - return ret + return ret, models def _single_xv(self, model_getter, tv_set_dict): training = dict() @@ -278,4 +280,4 @@ class CrossValidator: validation, self.parameters, *self.args, **self.kwargs ) - return validation_data.assess(training_model) + return training_model, validation_data.assess(training_model) -- cgit v1.2.3