summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/validation.py43
1 files changed, 40 insertions, 3 deletions
diff --git a/lib/validation.py b/lib/validation.py
index e10ba6c..81b1819 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -25,6 +25,35 @@ def _xv_partitions_kfold(length, k=10):
return pairs
+def _xv_param_partitions_kfold(param_values, k=10):
+ indexes_by_param_value = dict()
+ distinct_pv = list()
+ for i, param_value in enumerate(param_values):
+ pv = tuple(param_value)
+ if pv in indexes_by_param_value:
+ indexes_by_param_value[pv].append(i)
+ else:
+ distinct_pv.append(pv)
+ indexes_by_param_value[pv] = [i]
+
+ indexes = np.arange(len(distinct_pv))
+ num_slices = k
+ pairs = list()
+ for i in range(num_slices):
+ training_groups = np.delete(indexes, slice(i, None, num_slices))
+ validation_groups = indexes[i::num_slices]
+ training = list()
+ for group in training_groups:
+ training.extend(indexes_by_param_value[distinct_pv[group]])
+ validation = list()
+ for group in validation_groups:
+ validation.extend(indexes_by_param_value[distinct_pv[group]])
+ if not (len(training) and len(validation)):
+ return None
+ pairs.append((training, validation))
+ return pairs
+
+
def _xv_partition_montecarlo(length):
"""
Return training and validation set for Monte Carlo cross-validation on `length` items.
@@ -73,6 +102,7 @@ class CrossValidator:
self.by_name = by_name
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
+ self.parameter_aware = False
self.args = args
self.kwargs = kwargs
@@ -112,9 +142,16 @@ class CrossValidator:
training_and_validation_sets = list()
for name in self.names:
- sample_count = len(self.by_name[name]["param"])
- subsets_by_name[name] = list()
- subsets_by_name[name] = _xv_partitions_kfold(sample_count, k)
+ param_values = self.by_name[name]["param"]
+ if self.parameter_aware:
+ subsets_by_name[name] = _xv_param_partitions_kfold(param_values, k)
+ if subsets_by_name[name] is None:
+ logger.warning(
+ f"Insufficient amount of parameter combinations for {name}, falling back to parameter-unaware cross-validation"
+ )
+ subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
+ else:
+ subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
for i in range(k):
training_and_validation_sets.append(dict())