summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2020-07-01 10:25:47 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2020-07-01 10:25:47 +0200
commit2833115dff3da0e9b9a84fc5642b3a43034b27af (patch)
treed255c3f925d46383956f278bab50128663334c8e /bin
parent08b1449e27da52e186f951914290b56b18bc64b2 (diff)
Restore k-fold cross validation support
Diffstat (limited to 'bin')
-rwxr-xr-xbin/analyze-archive.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index cfb832f..212fd2e 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -67,12 +67,22 @@ Options:
--cross-validate=<method>:<count>
Perform cross validation when computing model quality.
Only works with --show-quality=table at the moment.
+
If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3
validation, <count> times. Reported model quality is the average of all
validation runs. Data is partitioned without regard for parameter values,
so a specific parameter combination may be present in both training and
validation sets or just one of them.
+ If <method> is "kfold": Perform k-fold cross validation with k=<count>.
+ Divide data into 1-1/k training and 1/k validation, <count> times.
+ In the first set, items 0, k, 2k, ... are used for validation, in the
+ second set, items 1, k+1, 2k+1, ... and so on.
+ Reported model quality is the average of all
+ validation runs. Data is partitioned without regard for parameter values,
+ so a specific parameter combination may be present in both training and
+ validation sets or just one of them.
+
--function-override=<name attribute function>[;<name> <attribute> <function>;...]
Manually specify the function to fit for <name> <attribute>. A function
specified this way bypasses parameter detection: It is always assigned,
@@ -549,6 +559,8 @@ if __name__ == "__main__":
if xv_method == "montecarlo":
static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ elif xv_method == "kfold":
+ static_quality = xv.kfold(lambda m: m.get_static(), xv_count)
else:
static_quality = model.assess(static_model)
@@ -558,6 +570,8 @@ if __name__ == "__main__":
if xv_method == "montecarlo":
lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count)
+ elif xv_method == "kfold":
+ lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count)
else:
lut_quality = model.assess(lut_model)
@@ -651,6 +665,8 @@ if __name__ == "__main__":
if xv_method == "montecarlo":
analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ elif xv_method == "kfold":
+ analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
else:
analytic_quality = model.assess(param_model)