summary | refs | log | tree | commit | diff
path: root/lib
diff options
context:
space:
mode:
author: Daniel Friesel <daniel.friesel@uos.de> 2022-09-23 15:22:19 +0200
committer: Daniel Friesel <daniel.friesel@uos.de> 2022-09-23 15:22:19 +0200
commit: 656e4a22c55cd785a8f6fe079adfb7d249f42e1e (patch)
tree: 2cd606578c02045cc18efd1bc7d3a39ee5edf56c /lib
parent: e224ced44f95880d86f4913396d7c621fe2f2db1 (diff)
do not build dtree in static and LUT cross-validation runs
Diffstat (limited to 'lib')
-rw-r--r-- lib/validation.py 25
1 files changed, 17 insertions, 8 deletions
diff --git a/lib/validation.py b/lib/validation.py
index 0e735a0..95815ac 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -107,7 +107,7 @@ class CrossValidator:
self.args = args
self.kwargs = kwargs
- def kfold(self, model_getter, k=10):
+ def kfold(self, model_getter, k=10, static=False):
"""
Perform k-fold cross-validation and return average model quality.
@@ -159,9 +159,11 @@ class CrossValidator:
for name in self.names:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
- return self._generic_xv(model_getter, training_and_validation_sets)
+ return self._generic_xv(
+ model_getter, training_and_validation_sets, static=static
+ )
- def montecarlo(self, model_getter, count=200):
+ def montecarlo(self, model_getter, count=200, static=False):
"""
Perform Monte Carlo cross-validation and return average model quality.
@@ -208,9 +210,11 @@ class CrossValidator:
for name in self.names:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
- return self._generic_xv(model_getter, training_and_validation_sets)
+ return self._generic_xv(
+ model_getter, training_and_validation_sets, static=static
+ )
- def _generic_xv(self, model_getter, training_and_validation_sets):
+ def _generic_xv(self, model_getter, training_and_validation_sets, static=False):
ret = dict()
models = list()
@@ -225,7 +229,9 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
- model, res = self._single_xv(model_getter, training_and_validation_by_name)
+ model, res = self._single_xv(
+ model_getter, training_and_validation_by_name, static=static
+ )
models.append(model)
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
@@ -249,7 +255,7 @@ class CrossValidator:
return ret, models
- def _single_xv(self, model_getter, tv_set_dict):
+ def _single_xv(self, model_getter, tv_set_dict, static=False):
training = dict()
validation = dict()
for name in self.names:
@@ -279,8 +285,11 @@ class CrossValidator:
for idx in validation_subset:
validation[name]["param"].append(self.by_name[name]["param"][idx])
+ kwargs = self.kwargs.copy()
+ if static:
+ kwargs["force_tree"] = False
training_data = self.model_class(
- training, self.parameters, *self.args, **self.kwargs
+ training, self.parameters, *self.args, **kwargs
)
training_model = model_getter(training_data)
kwargs = self.kwargs.copy()