diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-11-23 09:24:07 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-11-23 09:24:07 +0100 |
commit | fa9405d911d7c6ea4cd3f6b19535f7d13a6f65d2 (patch) | |
tree | 95236624ad37246873a11f8c2ef2cca1efb8d079 | |
parent | 5b5380430eb3b701c1c43e18524a6b4759f46e27 (diff) |
add parameter-aware cross validation
-rwxr-xr-x | bin/analyze-archive.py | 6 | ||||
-rwxr-xr-x | bin/analyze-kconfig.py | 6 | ||||
-rw-r--r-- | lib/validation.py | 43 |
3 files changed, 52 insertions, 3 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index c66570e..f1b9f71 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -477,6 +477,11 @@ if __name__ == "__main__": "Only works with --show-quality=table at the moment.", ) parser.add_argument( + "--parameter-aware-cross-validation", + action="store_true", + help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.", + ) + parser.add_argument( "--with-safe-functions", action="store_true", help="Include 'safe' functions (safe_log, safe_inv, safe_sqrt) which are also defined for 0 and -1. " @@ -655,6 +660,7 @@ if __name__ == "__main__": if xv_method: xv = CrossValidator(PTAModel, by_name, parameters, arg_count) + xv.parameter_aware = args.parameter_aware_cross_validation if args.info: for state in model.states: diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index 55bf53b..d4b87c6 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -99,6 +99,11 @@ def main(): metavar="METHOD:COUNT", ) parser.add_argument( + "--parameter-aware-cross-validation", + action="store_true", + help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.", + ) + parser.add_argument( "--show-model", choices=["static", "paramdetection", "param", "all", "tex", "html"], action="append", @@ -228,6 +233,7 @@ def main(): force_tree=args.force_tree, max_std=max_std, ) + xv.parameter_aware = args.parameter_aware_cross_validation else: xv_method = None diff --git a/lib/validation.py b/lib/validation.py index e10ba6c..81b1819 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -25,6 +25,35 @@ def _xv_partitions_kfold(length, k=10): return pairs +def _xv_param_partitions_kfold(param_values, k=10): + indexes_by_param_value = dict() + distinct_pv = list() + for i, param_value in enumerate(param_values): + pv = tuple(param_value) + if pv in indexes_by_param_value: + indexes_by_param_value[pv].append(i) + else: + distinct_pv.append(pv) + indexes_by_param_value[pv] = [i] + + indexes = np.arange(len(distinct_pv)) + num_slices = k + pairs = list() + for i in range(num_slices): + training_groups = np.delete(indexes, slice(i, None, num_slices)) + validation_groups = indexes[i::num_slices] + training = list() + for group in training_groups: + training.extend(indexes_by_param_value[distinct_pv[group]]) + validation = list() + for group in validation_groups: + validation.extend(indexes_by_param_value[distinct_pv[group]]) + if not (len(training) and len(validation)): + return None + pairs.append((training, validation)) + return pairs + + def _xv_partition_montecarlo(length): """ Return training and validation set for Monte Carlo cross-validation on `length` items. @@ -73,6 +102,7 @@ class CrossValidator: self.by_name = by_name self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) + self.parameter_aware = False self.args = args self.kwargs = kwargs @@ -112,9 +142,16 @@ class CrossValidator: training_and_validation_sets = list() for name in self.names: - sample_count = len(self.by_name[name]["param"]) - subsets_by_name[name] = list() - subsets_by_name[name] = _xv_partitions_kfold(sample_count, k) + param_values = self.by_name[name]["param"] + if self.parameter_aware: + subsets_by_name[name] = _xv_param_partitions_kfold(param_values, k) + if subsets_by_name[name] is None: + logger.warning( + f"Insufficient amount of parameter combinations for {name}, falling back to parameter-unaware cross-validation" + ) + subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k) + else: + subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k) for i in range(k): training_and_validation_sets.append(dict()) |