diff options
Diffstat (limited to 'lib/validation.py')
-rw-r--r-- | lib/validation.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/lib/validation.py b/lib/validation.py index 958a9e0..bf6764d 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -109,7 +109,7 @@ class CrossValidator: self.args = args self.kwargs = kwargs - def kfold(self, model_getter, k=10, static=False): + def kfold(self, model_getter, k=10, static=False, with_sum=False): """ Perform k-fold cross-validation and return average model quality. @@ -161,10 +161,10 @@ class CrossValidator: training_and_validation_sets[i][name] = subsets_by_name[name][i] return self._generic_xv( - model_getter, training_and_validation_sets, static=static + model_getter, training_and_validation_sets, static=static, with_sum=with_sum ) - def montecarlo(self, model_getter, count=200, static=False): + def montecarlo(self, model_getter, count=200, static=False, with_sum=False): """ Perform Monte Carlo cross-validation and return average model quality. @@ -211,10 +211,12 @@ class CrossValidator: training_and_validation_sets[i][name] = subsets_by_name[name][i] return self._generic_xv( - model_getter, training_and_validation_sets, static=static + model_getter, training_and_validation_sets, static=static, with_sum=with_sum ) - def _generic_xv(self, model_getter, training_and_validation_sets, static=False): + def _generic_xv( + self, model_getter, training_and_validation_sets, static=False, with_sum=False + ): ret = dict() models = list() @@ -268,6 +270,16 @@ class CrossValidator: ) ) + if with_sum: + for name in self.names: + attr_0 = self.by_name[name]["attributes"][0] + gt_sum = np.zeros(len(ret[name][attr_0]["groundTruth"])) + mo_sum = np.zeros(len(ret[name][attr_0]["modelOutput"])) + for attribute in self.by_name[name]["attributes"]: + gt_sum += np.array(ret[name][attribute]["groundTruth"]) + mo_sum += np.array(ret[name][attribute]["modelOutput"]) + ret[name]["TOTAL"] = regression_measures(mo_sum, gt_sum) + return ret, models def _single_xv(self, model_getter, tv_set_dict, static=False): |