summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/validation.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/lib/validation.py b/lib/validation.py
index 47395dc..5c65fe3 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -211,6 +211,7 @@ class CrossValidator:
def _generic_xv(self, model_getter, training_and_validation_sets):
ret = dict()
+ models = list()
for name in self.names:
ret[name] = dict()
@@ -223,7 +224,8 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
- res = self._single_xv(model_getter, training_and_validation_by_name)
+ model, res = self._single_xv(model_getter, training_and_validation_by_name)
+ models.append(model)
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
for measure in ("mae", "rmsd", "mape", "smape"):
@@ -238,7 +240,7 @@ class CrossValidator:
ret[name][attribute][f"{measure}_list"]
)
- return ret
+ return ret, models
def _single_xv(self, model_getter, tv_set_dict):
training = dict()
@@ -278,4 +280,4 @@ class CrossValidator:
validation, self.parameters, *self.args, **self.kwargs
)
- return validation_data.assess(training_model)
+ return training_model, validation_data.assess(training_model)