diff options
Diffstat (limited to 'bin')
-rwxr-xr-x | bin/analyze-archive.py | 29 | ||||
-rwxr-xr-x | bin/test.py | 2 |
2 files changed, 21 insertions, 10 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 007ec25..5f22b18 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -41,9 +41,14 @@ Options: --discard-outliers= not supported at the moment ---cross-validate +--cross-validate=<method>:<count> Perform cross validation when computing model quality. Only works with --show-quality=table at the moment. + If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3 + validation, <count> times. Reported model quality is the average of all + validation runs. Data is partitioned without regard for parameter values, + so a specific parameter combination may be present in both training and + validation sets or just one of them. --function-override=<name attribute function>[;<name> <attribute> <function>;...] Manually specify the function to fit for <name> <attribute>. A function @@ -164,12 +169,14 @@ if __name__ == '__main__': show_quality = [] hwmodel = None energymodel_export_file = None + xv_method = None + xv_count = 10 try: optspec = ( 'plot-unparam= plot-param= show-models= show-quality= ' 'ignored-trace-indexes= discard-outliers= function-override= ' - 'cross-validate ' + 'cross-validate= ' 'with-safe-functions hwmodel= export-energymodel=' ) raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' ')) @@ -197,6 +204,10 @@ if __name__ == '__main__': if 'show-quality' in opts: show_quality = opts['show-quality'].split(',') + if 'cross-validate' in opts: + xv_method, xv_count = opts['cross-validate'].split(':') + xv_count = int(xv_count) + if 'with-safe-functions' in opts: safe_functions_enabled = True @@ -218,7 +229,7 @@ if __name__ == '__main__': function_override = function_override, hwmodel = hwmodel) - if 'cross-validate' in opts: + if xv_method: xv = CrossValidator(PTAModel, by_name, parameters, arg_count) if 'plot-unparam' in opts: @@ -251,8 +262,8 @@ if __name__ == '__main__': model.stats.generic_param_dependence_ratio(trans, 'rel_energy_next'))) print('{:10s}: {:.0f} µs'.format(trans, static_model(trans, 'duration'))) - if 'cross-validate' in opts: - static_quality = xv.montecarlo(lambda m: m.get_static()) + if xv_method == 'montecarlo': + static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -260,8 +271,8 @@ if __name__ == '__main__': print('--- LUT ---') lut_model = model.get_param_lut() - if 'cross-validate' in opts: - lut_quality = xv.montecarlo(lambda m: m.get_param_lut()) + if xv_method == 'montecarlo': + lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) else: lut_quality = model.assess(lut_model) @@ -305,8 +316,8 @@ if __name__ == '__main__': print('{:10s}: {:10s}: {}'.format(trans, attribute, param_info(trans, attribute)['function']._model_str)) print('{:10s} {:10s} {}'.format('', '', param_info(trans, attribute)['function']._regression_args)) - if 'cross-validate' in opts: - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0]) + if xv_method == 'montecarlo': + analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) diff --git a/bin/test.py b/bin/test.py index f53a9ef..c05ed51 100755 --- a/bin/test.py +++ b/bin/test.py @@ -87,7 +87,7 @@ class TestStaticModel(unittest.TestCase): self.assertAlmostEqual(static_model('off', 'duration'), 9130, places=0) self.assertAlmostEqual(static_model('setBrightness', 'duration'), 9130, places=0) - param_lut_model = model.get_param_lut() + param_lut_model = model.get_param_lut(fallback=True) self.assertAlmostEqual(param_lut_model('OFF', 'power', param=[None, None]), 7124, places=0) with self.assertRaises(KeyError): param_lut_model('ON', 'power', param=[None, None]) |