diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2020-07-01 10:25:47 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2020-07-01 10:25:47 +0200 |
commit | 2833115dff3da0e9b9a84fc5642b3a43034b27af (patch) | |
tree | d255c3f925d46383956f278bab50128663334c8e /bin | |
parent | 08b1449e27da52e186f951914290b56b18bc64b2 (diff) |
Restore k-fold cross validation support
Diffstat (limited to 'bin')
-rwxr-xr-x | bin/analyze-archive.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index cfb832f..212fd2e 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -67,12 +67,22 @@ Options: --cross-validate=<method>:<count> Perform cross validation when computing model quality. Only works with --show-quality=table at the moment. + If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3 validation, <count> times. Reported model quality is the average of all validation runs. Data is partitioned without regard for parameter values, so a specific parameter combination may be present in both training and validation sets or just one of them. + If <method> is "kfold": Perform k-fold cross validation with k=<count>. + Divide data into 1-1/k training and 1/k validation, <count> times. + In the first set, items 0, k, 2k, ... ard used for validation, in the + second set, items 1, k+1, 2k+1, ... and so on. + validation, <count> times. Reported model quality is the average of all + validation runs. Data is partitioned without regard for parameter values, + so a specific parameter combination may be present in both training and + validation sets or just one of them. + --function-override=<name attribute function>[;<name> <attribute> <function>;...] Manually specify the function to fit for <name> <attribute>. A function specified this way bypasses parameter detection: It is always assigned, @@ -549,6 +559,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + elif xv_method == "kfold": + static_quality = xv.kfold(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -558,6 +570,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) + elif xv_method == "kfold": + lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) else: lut_quality = model.assess(lut_model) @@ -651,6 +665,8 @@ if __name__ == "__main__": if xv_method == "montecarlo": analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + elif xv_method == "kfold": + analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) |