-rwxr-xr-x  bin/analyze-archive.py  29
-rwxr-xr-x  bin/test.py              2
-rwxr-xr-x  lib/dfatool.py          31
3 files changed, 44 insertions(+), 18 deletions(-)
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 007ec25..5f22b18 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -41,9 +41,14 @@ Options:
 --discard-outliers=
     not supported at the moment
 
---cross-validate
+--cross-validate=<method>:<count>
     Perform cross validation when computing model quality.
     Only works with --show-quality=table at the moment.
+    If <method> is "montecarlo": Randomly divide data into 2/3 training and 1/3
+    validation, <count> times. Reported model quality is the average of all
+    validation runs. Data is partitioned without regard for parameter values,
+    so a specific parameter combination may be present in both training and
+    validation sets or just one of them.
 
 --function-override=<name attribute function>[;<name> <attribute> <function>;...]
     Manually specify the function to fit for <name> <attribute>. A function
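
For reference, the 2/3 vs. 1/3 Monte Carlo split described in the help text
above can be sketched as follows (a minimal illustration; montecarlo_split is
hypothetical and not dfatool's actual implementation):

    import numpy as np

    def montecarlo_split(num_measurements, count=10):
        # Yield `count` random (training, validation) index partitions:
        # two thirds of the indexes go into training, the remaining third
        # into validation, without regard for parameter values.
        for _ in range(count):
            indexes = np.random.permutation(num_measurements)
            border = num_measurements * 2 // 3
            yield indexes[:border], indexes[border:]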
@@ -164,12 +169,14 @@ if __name__ == '__main__':
     show_quality = []
     hwmodel = None
     energymodel_export_file = None
+    xv_method = None
+    xv_count = 10
 
     try:
         optspec = (
             'plot-unparam= plot-param= show-models= show-quality= '
             'ignored-trace-indexes= discard-outliers= function-override= '
-            'cross-validate '
+            'cross-validate= '
             'with-safe-functions hwmodel= export-energymodel='
         )
         raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' '))
@@ -197,6 +204,10 @@ if __name__ == '__main__':
 
         if 'show-quality' in opts:
             show_quality = opts['show-quality'].split(',')
 
+        if 'cross-validate' in opts:
+            xv_method, xv_count = opts['cross-validate'].split(':')
+            xv_count = int(xv_count)
+
         if 'with-safe-functions' in opts:
             safe_functions_enabled = True
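
With this option parsing in place, the argument takes the form
"<method>:<count>". An example invocation (the archive name is a placeholder):

    ./bin/analyze-archive.py --show-quality=table --cross-validate=montecarlo:10 measurements.tar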
@@ -218,7 +229,7 @@ if __name__ == '__main__':
         function_override = function_override,
         hwmodel = hwmodel)
 
-    if 'cross-validate' in opts:
+    if xv_method:
         xv = CrossValidator(PTAModel, by_name, parameters, arg_count)
 
     if 'plot-unparam' in opts:
@@ -251,8 +262,8 @@
                 model.stats.generic_param_dependence_ratio(trans, 'rel_energy_next')))
             print('{:10s}: {:.0f} µs'.format(trans, static_model(trans, 'duration')))
 
-    if 'cross-validate' in opts:
-        static_quality = xv.montecarlo(lambda m: m.get_static())
+    if xv_method == 'montecarlo':
+        static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
     else:
         static_quality = model.assess(static_model)
@@ -260,8 +271,8 @@
     print('--- LUT ---')
     lut_model = model.get_param_lut()
 
-    if 'cross-validate' in opts:
-        lut_quality = xv.montecarlo(lambda m: m.get_param_lut())
+    if xv_method == 'montecarlo':
+        lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count)
     else:
         lut_quality = model.assess(lut_model)
@@ -305,8 +316,8 @@
                 print('{:10s}: {:10s}: {}'.format(trans, attribute, param_info(trans, attribute)['function']._model_str))
                 print('{:10s} {:10s} {}'.format('', '', param_info(trans, attribute)['function']._regression_args))
 
-    if 'cross-validate' in opts:
-        analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0])
+    if xv_method == 'montecarlo':
+        analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
     else:
         analytic_quality = model.assess(param_model)
diff --git a/bin/test.py b/bin/test.py
index f53a9ef..c05ed51 100755
--- a/bin/test.py
+++ b/bin/test.py
@@ -87,7 +87,7 @@ class TestStaticModel(unittest.TestCase):
         self.assertAlmostEqual(static_model('off', 'duration'), 9130, places=0)
         self.assertAlmostEqual(static_model('setBrightness', 'duration'), 9130, places=0)
 
-        param_lut_model = model.get_param_lut()
+        param_lut_model = model.get_param_lut(fallback=True)
         self.assertAlmostEqual(param_lut_model('OFF', 'power', param=[None, None]), 7124, places=0)
         with self.assertRaises(KeyError):
             param_lut_model('ON', 'power', param=[None, None])
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 82ed35f..c4529ce 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -295,7 +295,7 @@ class CrossValidator:
         self.parameters = sorted(parameters)
         self.arg_count = arg_count
 
-    def montecarlo(self, model_getter, count = 2):
+    def montecarlo(self, model_getter, count = 200):
         """
         Perform Monte Carlo cross-validation and return average model quality.
 
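
The count argument is now supplied by the caller rather than relying on the
default; as seen in analyze-archive.py above, the value parsed from
--cross-validate=<method>:<count> is threaded through:

    xv = CrossValidator(PTAModel, by_name, parameters, arg_count)
    static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)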
@@ -1013,18 +1013,27 @@ class AnalyticModel:
         return static_mean_getter
 
-    def get_param_lut(self):
+    def get_param_lut(self, fallback = False):
         """
         Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
 
         The function can only give model values for parameter combinations
-        present in by_param. It raises KeyError for other values.
+        present in by_param. By default, it raises KeyError for other values.
+
+        arguments:
+        fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
         """
+        static_model = self._get_model_from_dict(self.by_name, np.median)
         lut_model = self._get_model_from_dict(self.by_param, np.median)
 
         def lut_median_getter(name, key, param, arg = [], **kwargs):
             param.extend(map(soft_cast_int, arg))
-            return lut_model[(name, tuple(param))][key]
+            try:
+                return lut_model[(name, tuple(param))][key]
+            except KeyError:
+                if fallback:
+                    return static_model[name][key]
+                raise
 
         return lut_median_getter
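
With fallback=True, an unknown parameter combination now yields the static
model value instead of a KeyError. A quick sketch of the new behaviour,
matching the updated test in bin/test.py above:

    param_lut_model = model.get_param_lut(fallback=True)
    # Unseen parameter values fall back to the static model instead of raising:
    param_lut_model('OFF', 'power', param=[None, None])  # ~7124 in the test data
    # Unknown names still raise KeyError, as asserted in the test above:
    param_lut_model('ON', 'power', param=[None, None])

PTAModel.get_param_lut receives the identical change below.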
@@ -1380,21 +1389,27 @@ class PTAModel:
         return static_mean_getter
 
-    def get_param_lut(self):
+    def get_param_lut(self, fallback = False):
         """
         Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
 
         The function can only give model values for parameter combinations
-        present in by_param. It raises KeyError for other values.
+        present in by_param. By default, it raises KeyError for other values.
+
+        arguments:
+        fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
         """
         static_model = self._get_model_from_dict(self.by_name, np.median)
         lut_model = self._get_model_from_dict(self.by_param, np.median)
 
         def lut_median_getter(name, key, param, arg = [], **kwargs):
             param.extend(map(soft_cast_int, arg))
-            if (name, tuple(param)) in lut_model:
+            try:
                 return lut_model[(name, tuple(param))][key]
-            return static_model[name][key]
+            except KeyError:
+                if fallback:
+                    return static_model[name][key]
+                raise
 
         return lut_median_getter