diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/validation.py | 43 |
1 files changed, 40 insertions, 3 deletions
diff --git a/lib/validation.py b/lib/validation.py index e10ba6c..81b1819 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -25,6 +25,35 @@ def _xv_partitions_kfold(length, k=10): return pairs +def _xv_param_partitions_kfold(param_values, k=10): + indexes_by_param_value = dict() + distinct_pv = list() + for i, param_value in enumerate(param_values): + pv = tuple(param_value) + if pv in indexes_by_param_value: + indexes_by_param_value[pv].append(i) + else: + distinct_pv.append(pv) + indexes_by_param_value[pv] = [i] + + indexes = np.arange(len(distinct_pv)) + num_slices = k + pairs = list() + for i in range(num_slices): + training_groups = np.delete(indexes, slice(i, None, num_slices)) + validation_groups = indexes[i::num_slices] + training = list() + for group in training_groups: + training.extend(indexes_by_param_value[distinct_pv[group]]) + validation = list() + for group in validation_groups: + validation.extend(indexes_by_param_value[distinct_pv[group]]) + if not (len(training) and len(validation)): + return None + pairs.append((training, validation)) + return pairs + + def _xv_partition_montecarlo(length): """ Return training and validation set for Monte Carlo cross-validation on `length` items. @@ -73,6 +102,7 @@ class CrossValidator: self.by_name = by_name self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) + self.parameter_aware = False self.args = args self.kwargs = kwargs @@ -112,9 +142,16 @@ class CrossValidator: training_and_validation_sets = list() for name in self.names: - sample_count = len(self.by_name[name]["param"]) - subsets_by_name[name] = list() - subsets_by_name[name] = _xv_partitions_kfold(sample_count, k) + param_values = self.by_name[name]["param"] + if self.parameter_aware: + subsets_by_name[name] = _xv_param_partitions_kfold(param_values, k) + if subsets_by_name[name] is None: + logger.warning( + f"Insufficient amount of parameter combinations for {name}, falling back to parameter-unaware cross-validation" + ) + subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k) + else: + subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k) for i in range(k): training_and_validation_sets.append(dict()) |