diff options
author | Daniel Friesel <derf@finalrewind.org> | 2019-02-13 16:33:00 +0100 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2019-02-13 16:33:00 +0100 |
commit | 906a01beeb24fafa691f585018a0ec809a88de08 (patch) | |
tree | 006a3c421c8360d6b3be71710ab7549b99f6bce5 /bin/analyze-archive.py | |
parent | df1beb6debcd7c527d49a1a722b023b29f38b859 (diff) |
add generic monte carlo cross validation
Diffstat (limited to 'bin/analyze-archive.py')
-rwxr-xr-x | bin/analyze-archive.py | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 21f9ceb..007ec25 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -41,6 +41,10 @@ Options: --discard-outliers= not supported at the moment +--cross-validate + Perform cross validation when computing model quality. + Only works with --show-quality=table at the moment. + --function-override=<name attribute function>[;<name> <attribute> <function>;...] Manually specify the function to fit for <name> <attribute>. A function specified this way bypasses parameter detection: It is always assigned, @@ -65,6 +69,7 @@ import re import sys from dfatool import PTAModel, RawData, pta_trace_to_aggregate from dfatool import soft_cast_int, is_numeric, gplearn_to_function +from dfatool import CrossValidator opts = {} @@ -86,14 +91,14 @@ def format_quality_measures(result): return '{:6} {:9.0f}'.format('', result['mae']) def model_quality_table(result_lists, info_list): - for state_or_tran in result_lists[0]['by_dfa_component'].keys(): - for key in result_lists[0]['by_dfa_component'][state_or_tran].keys(): + for state_or_tran in result_lists[0]['by_name'].keys(): + for key in result_lists[0]['by_name'][state_or_tran].keys(): buf = '{:20s} {:15s}'.format(state_or_tran, key) for i, results in enumerate(result_lists): info = info_list[i] buf += ' ||| ' if info == None or info(state_or_tran, key): - result = results['by_dfa_component'][state_or_tran][key] + result = results['by_name'][state_or_tran][key] buf += format_quality_measures(result) else: buf += '{:6}----{:9}'.format('', '') @@ -164,6 +169,7 @@ if __name__ == '__main__': optspec = ( 'plot-unparam= plot-param= show-models= show-quality= ' 'ignored-trace-indexes= discard-outliers= function-override= ' + 'cross-validate ' 'with-safe-functions hwmodel= export-energymodel=' ) raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' ')) @@ -212,6 +218,8 @@ if __name__ == '__main__': function_override = function_override, hwmodel = hwmodel) + if 'cross-validate' in opts: + xv = CrossValidator(PTAModel, by_name, parameters, arg_count) if 'plot-unparam' in opts: for kv in opts['plot-unparam'].split(';'): @@ -242,12 +250,20 @@ if __name__ == '__main__': model.stats.generic_param_dependence_ratio(trans, 'rel_energy_prev'), model.stats.generic_param_dependence_ratio(trans, 'rel_energy_next'))) print('{:10s}: {:.0f} µs'.format(trans, static_model(trans, 'duration'))) - static_quality = model.assess(static_model) + + if 'cross-validate' in opts: + static_quality = xv.montecarlo(lambda m: m.get_static()) + else: + static_quality = model.assess(static_model) if len(show_models): print('--- LUT ---') lut_model = model.get_param_lut() - lut_quality = model.assess(lut_model) + + if 'cross-validate' in opts: + lut_quality = xv.montecarlo(lambda m: m.get_param_lut()) + else: + lut_quality = model.assess(lut_model) if len(show_models): print('--- param model ---') @@ -288,7 +304,11 @@ if __name__ == '__main__': if param_info(trans, attribute): print('{:10s}: {:10s}: {}'.format(trans, attribute, param_info(trans, attribute)['function']._model_str)) print('{:10s} {:10s} {}'.format('', '', param_info(trans, attribute)['function']._regression_args)) - analytic_quality = model.assess(param_model) + + if 'cross-validate' in opts: + analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0]) + else: + analytic_quality = model.assess(param_model) if 'tex' in show_models or 'tex' in show_quality: print_text_model_data(model, static_model, static_quality, lut_model, lut_quality, param_model, param_info, analytic_quality) |