diff options
-rwxr-xr-x | bin/analyze-archive.py | 29 | ||||
-rwxr-xr-x | bin/test.py | 2 | ||||
-rwxr-xr-x | lib/dfatool.py | 31 |
3 files changed, 44 insertions, 18 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 007ec25..5f22b18 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -41,9 +41,14 @@ Options: --discard-outliers= not supported at the moment ---cross-validate +--cross-validate=<method>:<count> Perform cross validation when computing model quality. Only works with --show-quality=table at the moment. + If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3 + validation, <count> times. Reported model quality is the average of all + validation runs. Data is partitioned without regard for parameter values, + so a specific parameter combination may be present in both training and + validation sets or just one of them. --function-override=<name attribute function>[;<name> <attribute> <function>;...] Manually specify the function to fit for <name> <attribute>. A function @@ -164,12 +169,14 @@ if __name__ == '__main__': show_quality = [] hwmodel = None energymodel_export_file = None + xv_method = None + xv_count = 10 try: optspec = ( 'plot-unparam= plot-param= show-models= show-quality= ' 'ignored-trace-indexes= discard-outliers= function-override= ' - 'cross-validate ' + 'cross-validate= ' 'with-safe-functions hwmodel= export-energymodel=' ) raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' ')) @@ -197,6 +204,10 @@ if __name__ == '__main__': if 'show-quality' in opts: show_quality = opts['show-quality'].split(',') + if 'cross-validate' in opts: + xv_method, xv_count = opts['cross-validate'].split(':') + xv_count = int(xv_count) + if 'with-safe-functions' in opts: safe_functions_enabled = True @@ -218,7 +229,7 @@ if __name__ == '__main__': function_override = function_override, hwmodel = hwmodel) - if 'cross-validate' in opts: + if xv_method: xv = CrossValidator(PTAModel, by_name, parameters, arg_count) if 'plot-unparam' in opts: @@ -251,8 +262,8 @@ if __name__ == '__main__': model.stats.generic_param_dependence_ratio(trans, 'rel_energy_next'))) print('{:10s}: {:.0f} µs'.format(trans, static_model(trans, 'duration'))) - if 'cross-validate' in opts: - static_quality = xv.montecarlo(lambda m: m.get_static()) + if xv_method == 'montecarlo': + static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -260,8 +271,8 @@ if __name__ == '__main__': print('--- LUT ---') lut_model = model.get_param_lut() - if 'cross-validate' in opts: - lut_quality = xv.montecarlo(lambda m: m.get_param_lut()) + if xv_method == 'montecarlo': + lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) else: lut_quality = model.assess(lut_model) @@ -305,8 +316,8 @@ if __name__ == '__main__': print('{:10s}: {:10s}: {}'.format(trans, attribute, param_info(trans, attribute)['function']._model_str)) print('{:10s} {:10s} {}'.format('', '', param_info(trans, attribute)['function']._regression_args)) - if 'cross-validate' in opts: - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0]) + if xv_method == 'montecarlo': + analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) diff --git a/bin/test.py b/bin/test.py index f53a9ef..c05ed51 100755 --- a/bin/test.py +++ b/bin/test.py @@ -87,7 +87,7 @@ class TestStaticModel(unittest.TestCase): self.assertAlmostEqual(static_model('off', 'duration'), 9130, places=0) self.assertAlmostEqual(static_model('setBrightness', 'duration'), 9130, places=0) - param_lut_model = model.get_param_lut() + param_lut_model = model.get_param_lut(fallback=True) self.assertAlmostEqual(param_lut_model('OFF', 'power', param=[None, None]), 7124, places=0) with self.assertRaises(KeyError): param_lut_model('ON', 'power', param=[None, None]) diff --git a/lib/dfatool.py b/lib/dfatool.py index 82ed35f..c4529ce 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -295,7 +295,7 @@ class CrossValidator: self.parameters = sorted(parameters) self.arg_count = arg_count - def montecarlo(self, model_getter, count = 2): + def montecarlo(self, model_getter, count = 200): """ Perform Monte Carlo cross-validation and return average model quality. @@ -1013,18 +1013,27 @@ class AnalyticModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ + static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - return lut_model[(name, tuple(param))][key] + try: + return lut_model[(name, tuple(param))][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter @@ -1380,21 +1389,27 @@ class PTAModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - if (name, tuple(param)) in lut_model: + try: return lut_model[(name, tuple(param))][key] - return static_model[name][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter |