diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-05 15:23:51 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-05 15:23:51 +0100 |
commit | 2c5bcd77f2c952cc5269ca3e4b6e0a7323ebd085 (patch) | |
tree | 93da4dc33c77855445e6aa21f45c12b1803861fa /lib/validation.py | |
parent | d9aee2a314ae6d3fc0216893a4ccfd8bb66ffa9c (diff) |
cross validation: return intermediate models used for XV
These are interesting for statistics, e.g. to determine the average dtree size
Diffstat (limited to 'lib/validation.py')
-rw-r--r-- | lib/validation.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/lib/validation.py b/lib/validation.py index 47395dc..5c65fe3 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -211,6 +211,7 @@ class CrossValidator: def _generic_xv(self, model_getter, training_and_validation_sets): ret = dict() + models = list() for name in self.names: ret[name] = dict() @@ -223,7 +224,8 @@ class CrossValidator: } for training_and_validation_by_name in training_and_validation_sets: - res = self._single_xv(model_getter, training_and_validation_by_name) + model, res = self._single_xv(model_getter, training_and_validation_by_name) + models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: for measure in ("mae", "rmsd", "mape", "smape"): @@ -238,7 +240,7 @@ class CrossValidator: ret[name][attribute][f"{measure}_list"] ) - return ret + return ret, models def _single_xv(self, model_getter, tv_set_dict): training = dict() @@ -278,4 +280,4 @@ class CrossValidator: validation, self.parameters, *self.args, **self.kwargs ) - return validation_data.assess(training_model) + return training_model, validation_data.assess(training_model) |