diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-09-23 15:22:19 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-09-23 15:22:19 +0200 |
commit | 656e4a22c55cd785a8f6fe079adfb7d249f42e1e (patch) | |
tree | 2cd606578c02045cc18efd1bc7d3a39ee5edf56c /lib | |
parent | e224ced44f95880d86f4913396d7c621fe2f2db1 (diff) |
do not build dtree in static and LUT cross-validation runs
Diffstat (limited to 'lib')
-rw-r--r-- | lib/validation.py | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/lib/validation.py b/lib/validation.py index 0e735a0..95815ac 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -107,7 +107,7 @@ class CrossValidator: self.args = args self.kwargs = kwargs - def kfold(self, model_getter, k=10): + def kfold(self, model_getter, k=10, static=False): """ Perform k-fold cross-validation and return average model quality. @@ -159,9 +159,11 @@ class CrossValidator: for name in self.names: training_and_validation_sets[i][name] = subsets_by_name[name][i] - return self._generic_xv(model_getter, training_and_validation_sets) + return self._generic_xv( + model_getter, training_and_validation_sets, static=static + ) - def montecarlo(self, model_getter, count=200): + def montecarlo(self, model_getter, count=200, static=False): """ Perform Monte Carlo cross-validation and return average model quality. @@ -208,9 +210,11 @@ class CrossValidator: for name in self.names: training_and_validation_sets[i][name] = subsets_by_name[name][i] - return self._generic_xv(model_getter, training_and_validation_sets) + return self._generic_xv( + model_getter, training_and_validation_sets, static=static + ) - def _generic_xv(self, model_getter, training_and_validation_sets): + def _generic_xv(self, model_getter, training_and_validation_sets, static=False): ret = dict() models = list() @@ -225,7 +229,9 @@ class CrossValidator: } for training_and_validation_by_name in training_and_validation_sets: - model, res = self._single_xv(model_getter, training_and_validation_by_name) + model, res = self._single_xv( + model_getter, training_and_validation_by_name, static=static + ) models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: @@ -249,7 +255,7 @@ class CrossValidator: return ret, models - def _single_xv(self, model_getter, tv_set_dict): + def _single_xv(self, model_getter, tv_set_dict, static=False): training = dict() validation = dict() for name in self.names: @@ -279,8 +285,11 @@ class CrossValidator: for idx in validation_subset: validation[name]["param"].append(self.by_name[name]["param"][idx]) + kwargs = self.kwargs.copy() + if static: + kwargs["force_tree"] = False training_data = self.model_class( - training, self.parameters, *self.args, **self.kwargs + training, self.parameters, *self.args, **kwargs ) training_model = model_getter(training_data) kwargs = self.kwargs.copy() |