From 2833115dff3da0e9b9a84fc5642b3a43034b27af Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 1 Jul 2020 10:25:47 +0200 Subject: Restore k-fold cross validation support --- bin/analyze-archive.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'bin/analyze-archive.py') diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index cfb832f..212fd2e 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -67,12 +67,22 @@ Options: --cross-validate=<method>:<count> Perform cross validation when computing model quality. Only works with --show-quality=table at the moment. + If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3 validation, <count> times. Reported model quality is the average of all validation runs. Data is partitioned without regard for parameter values, so a specific parameter combination may be present in both training and validation sets or just one of them. + If <method> is "kfold": Perform k-fold cross validation with k=<count>. + Divide data into 1-1/k training and 1/k validation, <count> times. + In the first set, items 0, k, 2k, ... are used for validation, in the + second set, items 1, k+1, 2k+1, ... and so on. + Reported model quality is the average of all + validation runs. Data is partitioned without regard for parameter values, + so a specific parameter combination may be present in both training and + validation sets or just one of them. + --function-override=[; ;...] Manually specify the function to fit for . 
A function specified this way bypasses parameter detection: It is always assigned, @@ -549,6 +559,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + elif xv_method == "kfold": + static_quality = xv.kfold(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -558,6 +570,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) + elif xv_method == "kfold": + lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) else: lut_quality = model.assess(lut_model) @@ -651,6 +665,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + elif xv_method == "kfold": + analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) -- cgit v1.2.3