summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Daniel Friesel <daniel.friesel@uos.de>, 2021-11-23 09:24:07 +0100
committer: Daniel Friesel <daniel.friesel@uos.de>, 2021-11-23 09:24:07 +0100
commit: fa9405d911d7c6ea4cd3f6b19535f7d13a6f65d2 (patch)
tree: 95236624ad37246873a11f8c2ef2cca1efb8d079
parent: 5b5380430eb3b701c1c43e18524a6b4759f46e27 (diff)
add parameter-aware cross validation
-rwxr-xr-x  bin/analyze-archive.py  6
-rwxr-xr-x  bin/analyze-kconfig.py  6
-rw-r--r--  lib/validation.py  43
3 files changed, 52 insertions, 3 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index c66570e..f1b9f71 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -477,6 +477,11 @@ if __name__ == "__main__":
"Only works with --show-quality=table at the moment.",
)
parser.add_argument(
+ "--parameter-aware-cross-validation",
+ action="store_true",
+ help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.",
+ )
+ parser.add_argument(
"--with-safe-functions",
action="store_true",
help="Include 'safe' functions (safe_log, safe_inv, safe_sqrt) which are also defined for 0 and -1. "
@@ -655,6 +660,7 @@ if __name__ == "__main__":
if xv_method:
xv = CrossValidator(PTAModel, by_name, parameters, arg_count)
+ xv.parameter_aware = args.parameter_aware_cross_validation
if args.info:
for state in model.states:
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index 55bf53b..d4b87c6 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -99,6 +99,11 @@ def main():
metavar="METHOD:COUNT",
)
parser.add_argument(
+ "--parameter-aware-cross-validation",
+ action="store_true",
+ help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.",
+ )
+ parser.add_argument(
"--show-model",
choices=["static", "paramdetection", "param", "all", "tex", "html"],
action="append",
@@ -228,6 +233,7 @@ def main():
force_tree=args.force_tree,
max_std=max_std,
)
+ xv.parameter_aware = args.parameter_aware_cross_validation
else:
xv_method = None
diff --git a/lib/validation.py b/lib/validation.py
index e10ba6c..81b1819 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -25,6 +25,35 @@ def _xv_partitions_kfold(length, k=10):
return pairs
+def _xv_param_partitions_kfold(param_values, k=10):
+ indexes_by_param_value = dict()
+ distinct_pv = list()
+ for i, param_value in enumerate(param_values):
+ pv = tuple(param_value)
+ if pv in indexes_by_param_value:
+ indexes_by_param_value[pv].append(i)
+ else:
+ distinct_pv.append(pv)
+ indexes_by_param_value[pv] = [i]
+
+ indexes = np.arange(len(distinct_pv))
+ num_slices = k
+ pairs = list()
+ for i in range(num_slices):
+ training_groups = np.delete(indexes, slice(i, None, num_slices))
+ validation_groups = indexes[i::num_slices]
+ training = list()
+ for group in training_groups:
+ training.extend(indexes_by_param_value[distinct_pv[group]])
+ validation = list()
+ for group in validation_groups:
+ validation.extend(indexes_by_param_value[distinct_pv[group]])
+ if not (len(training) and len(validation)):
+ return None
+ pairs.append((training, validation))
+ return pairs
+
+
def _xv_partition_montecarlo(length):
"""
Return training and validation set for Monte Carlo cross-validation on `length` items.
@@ -73,6 +102,7 @@ class CrossValidator:
self.by_name = by_name
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
+ self.parameter_aware = False
self.args = args
self.kwargs = kwargs
@@ -112,9 +142,16 @@ class CrossValidator:
training_and_validation_sets = list()
for name in self.names:
- sample_count = len(self.by_name[name]["param"])
- subsets_by_name[name] = list()
- subsets_by_name[name] = _xv_partitions_kfold(sample_count, k)
+ param_values = self.by_name[name]["param"]
+ if self.parameter_aware:
+ subsets_by_name[name] = _xv_param_partitions_kfold(param_values, k)
+ if subsets_by_name[name] is None:
+ logger.warning(
+ f"Insufficient amount of parameter combinations for {name}, falling back to parameter-unaware cross-validation"
+ )
+ subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
+ else:
+ subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
for i in range(k):
training_and_validation_sets.append(dict())