summaryrefslogtreecommitdiff
path: root/bin/analyze-archive.py
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2019-02-13 16:33:00 +0100
committerDaniel Friesel <derf@finalrewind.org>2019-02-13 16:33:00 +0100
commit906a01beeb24fafa691f585018a0ec809a88de08 (patch)
tree006a3c421c8360d6b3be71710ab7549b99f6bce5 /bin/analyze-archive.py
parentdf1beb6debcd7c527d49a1a722b023b29f38b859 (diff)
add generic monte carlo cross validation
Diffstat (limited to 'bin/analyze-archive.py')
-rwxr-xr-xbin/analyze-archive.py32
1 files changed, 26 insertions, 6 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 21f9ceb..007ec25 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -41,6 +41,10 @@ Options:
--discard-outliers=
not supported at the moment
+--cross-validate
+ Perform cross validation when computing model quality.
+ Only works with --show-quality=table at the moment.
+
--function-override=<name attribute function>[;<name> <attribute> <function>;...]
Manually specify the function to fit for <name> <attribute>. A function
specified this way bypasses parameter detection: It is always assigned,
@@ -65,6 +69,7 @@ import re
import sys
from dfatool import PTAModel, RawData, pta_trace_to_aggregate
from dfatool import soft_cast_int, is_numeric, gplearn_to_function
+from dfatool import CrossValidator
opts = {}
@@ -86,14 +91,14 @@ def format_quality_measures(result):
return '{:6} {:9.0f}'.format('', result['mae'])
def model_quality_table(result_lists, info_list):
- for state_or_tran in result_lists[0]['by_dfa_component'].keys():
- for key in result_lists[0]['by_dfa_component'][state_or_tran].keys():
+ for state_or_tran in result_lists[0]['by_name'].keys():
+ for key in result_lists[0]['by_name'][state_or_tran].keys():
buf = '{:20s} {:15s}'.format(state_or_tran, key)
for i, results in enumerate(result_lists):
info = info_list[i]
buf += ' ||| '
if info == None or info(state_or_tran, key):
- result = results['by_dfa_component'][state_or_tran][key]
+ result = results['by_name'][state_or_tran][key]
buf += format_quality_measures(result)
else:
buf += '{:6}----{:9}'.format('', '')
@@ -164,6 +169,7 @@ if __name__ == '__main__':
optspec = (
'plot-unparam= plot-param= show-models= show-quality= '
'ignored-trace-indexes= discard-outliers= function-override= '
+ 'cross-validate '
'with-safe-functions hwmodel= export-energymodel='
)
raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' '))
@@ -212,6 +218,8 @@ if __name__ == '__main__':
function_override = function_override,
hwmodel = hwmodel)
+ if 'cross-validate' in opts:
+ xv = CrossValidator(PTAModel, by_name, parameters, arg_count)
if 'plot-unparam' in opts:
for kv in opts['plot-unparam'].split(';'):
@@ -242,12 +250,20 @@ if __name__ == '__main__':
model.stats.generic_param_dependence_ratio(trans, 'rel_energy_prev'),
model.stats.generic_param_dependence_ratio(trans, 'rel_energy_next')))
print('{:10s}: {:.0f} µs'.format(trans, static_model(trans, 'duration')))
- static_quality = model.assess(static_model)
+
+ if 'cross-validate' in opts:
+ static_quality = xv.montecarlo(lambda m: m.get_static())
+ else:
+ static_quality = model.assess(static_model)
if len(show_models):
print('--- LUT ---')
lut_model = model.get_param_lut()
- lut_quality = model.assess(lut_model)
+
+ if 'cross-validate' in opts:
+ lut_quality = xv.montecarlo(lambda m: m.get_param_lut())
+ else:
+ lut_quality = model.assess(lut_model)
if len(show_models):
print('--- param model ---')
@@ -288,7 +304,11 @@ if __name__ == '__main__':
if param_info(trans, attribute):
print('{:10s}: {:10s}: {}'.format(trans, attribute, param_info(trans, attribute)['function']._model_str))
print('{:10s} {:10s} {}'.format('', '', param_info(trans, attribute)['function']._regression_args))
- analytic_quality = model.assess(param_model)
+
+ if 'cross-validate' in opts:
+ analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0])
+ else:
+ analytic_quality = model.assess(param_model)
if 'tex' in show_models or 'tex' in show_quality:
print_text_model_data(model, static_model, static_quality, lut_model, lut_quality, param_model, param_info, analytic_quality)