From fa9405d911d7c6ea4cd3f6b19535f7d13a6f65d2 Mon Sep 17 00:00:00 2001
From: Daniel Friesel <daniel.friesel@uos.de>
Date: Tue, 23 Nov 2021 09:24:07 +0100
Subject: add parameter-aware cross validation

---
 lib/validation.py | 43 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 40 insertions(+), 3 deletions(-)

(limited to 'lib')

diff --git a/lib/validation.py b/lib/validation.py
index e10ba6c..81b1819 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -25,6 +25,35 @@ def _xv_partitions_kfold(length, k=10):
     return pairs
 
 
+def _xv_param_partitions_kfold(param_values, k=10):
+    indexes_by_param_value = dict()
+    distinct_pv = list()
+    for i, param_value in enumerate(param_values):
+        pv = tuple(param_value)
+        if pv in indexes_by_param_value:
+            indexes_by_param_value[pv].append(i)
+        else:
+            distinct_pv.append(pv)
+            indexes_by_param_value[pv] = [i]
+
+    indexes = np.arange(len(distinct_pv))
+    num_slices = k
+    pairs = list()
+    for i in range(num_slices):
+        training_groups = np.delete(indexes, slice(i, None, num_slices))
+        validation_groups = indexes[i::num_slices]
+        training = list()
+        for group in training_groups:
+            training.extend(indexes_by_param_value[distinct_pv[group]])
+        validation = list()
+        for group in validation_groups:
+            validation.extend(indexes_by_param_value[distinct_pv[group]])
+        if not (len(training) and len(validation)):
+            return None
+        pairs.append((training, validation))
+    return pairs
+
+
 def _xv_partition_montecarlo(length):
     """
     Return training and validation set for Monte Carlo cross-validation on `length` items.
@@ -73,6 +102,7 @@ class CrossValidator:
         self.by_name = by_name
         self.names = sorted(by_name.keys())
         self.parameters = sorted(parameters)
+        self.parameter_aware = False
         self.args = args
         self.kwargs = kwargs
 
@@ -112,9 +142,16 @@ class CrossValidator:
         training_and_validation_sets = list()
 
         for name in self.names:
-            sample_count = len(self.by_name[name]["param"])
-            subsets_by_name[name] = list()
-            subsets_by_name[name] = _xv_partitions_kfold(sample_count, k)
+            param_values = self.by_name[name]["param"]
+            if self.parameter_aware:
+                subsets_by_name[name] = _xv_param_partitions_kfold(param_values, k)
+                if subsets_by_name[name] is None:
+                    logger.warning(
+                        f"Insufficient amount of parameter combinations for {name}, falling back to parameter-unaware cross-validation"
+                    )
+                    subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
+            else:
+                subsets_by_name[name] = _xv_partitions_kfold(len(param_values), k)
 
         for i in range(k):
             training_and_validation_sets.append(dict())
-- 
cgit v1.2.3