diff options
author | Daniel Friesel <derf@finalrewind.org> | 2018-04-16 13:16:00 +0200 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2018-04-17 10:04:41 +0200 |
commit | 67ae1c880ca856f0dcec4a64f7d1dd63f4f3147b (patch) | |
tree | 663188ac5dcfbcdb1fc562b4436daf0ce40d5854 | |
parent | 3451efd9c493f311a31f4e573f153a64e86f7aae (diff) |
Properly toggle safe functions feature from analyze-archive.py
-rwxr-xr-x | bin/analyze-archive.py | 9 | ||||
-rwxr-xr-x | bin/test.py | 4 | ||||
-rwxr-xr-x | lib/dfatool.py | 15 |
3 files changed, 16 insertions, 12 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 7779b4b..df3ecbe 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -60,12 +60,14 @@ if __name__ == '__main__': ignored_trace_indexes = None discard_outliers = None tex_output = False + safe_functions_enabled = False function_override = {} try: optspec = ( 'plot-unparam= plot-param= ' - 'ignored-trace-indexes= discard-outliers= function-override= tex-output' + 'ignored-trace-indexes= discard-outliers= function-override= tex-output ' + 'with-safe-functions' ) raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' ')) @@ -89,6 +91,9 @@ if __name__ == '__main__': if 'tex-output' in opts: tex_output = True + if 'with-safe-functions' in opts: + safe_functions_enabled = True + except getopt.GetoptError as err: print(err) sys.exit(2) @@ -135,7 +140,7 @@ if __name__ == '__main__': lut_quality = model.assess(lut_model) print('--- param model ---') - param_model, param_info = model.get_fitted() + param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled) if not tex_output: for state in model.states(): for attribute in ['power']: diff --git a/bin/test.py b/bin/test.py index 60e8648..433b423 100755 --- a/bin/test.py +++ b/bin/test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from dfatool import EnergyModel, RawData +from dfatool import EnergyModel, RawData, analytic import unittest class TestStaticModel(unittest.TestCase): @@ -229,5 +229,5 @@ class TestStaticModel(unittest.TestCase): self.assertAlmostEqual(param_info('RX', 'power')['function']._regression_args[1], 206, places=0) if __name__ == '__main__': - dfatool.safe_function_enabled = False + analytic.safe_function_enabled = False unittest.main() diff --git a/lib/dfatool.py b/lib/dfatool.py index 76bf51a..ad37f72 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -16,7 +16,6 @@ import tarfile from multiprocessing import Pool arg_support_enabled = True -safe_functions_enabled = True def running_mean(x, N): cumsum = np.cumsum(np.insert(x, 0, 0)) @@ -557,7 +556,7 @@ class analytic: 'exponential' : np.exp, 'square' : lambda x : x ** 2, 'inverse' : lambda x : 1 / x, - 'sqrt' : np.sqrt, + 'sqrt' : lambda x: np.sqrt(np.abs(x)), 'num0_8' : _num0_8, 'num0_16' : _num0_16, 'num1' : _num1, @@ -566,7 +565,7 @@ class analytic: 'safe_sqrt': lambda x: np.sqrt(np.abs(x)), } - def functions(): + def functions(safe_functions_enabled = False): functions = { 'linear' : ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param, @@ -680,8 +679,8 @@ def _try_fits_parallel(arg): } -def _try_fits(by_param, state_or_tran, model_attribute, param_index): - functions = analytic.functions() +def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False): + functions = analytic.functions(safe_functions_enabled = safe_functions_enabled) for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()): @@ -1006,7 +1005,7 @@ class EnergyModel: return self._parameter_names[param_index] return str(param_index) - def get_fitted(self): + def get_fitted(self, safe_functions_enabled = False): if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache: return self.cache['fitted_model_getter'], self.cache['fitted_info_getter'] @@ -1023,7 +1022,7 @@ class EnergyModel: if self.param_dependence_ratio(state_or_tran, model_attribute, parameter_name) > 0.5: fit_queue.append({ 'key' : [state_or_tran, model_attribute, parameter_name], - 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index] + 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled] }) #fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index) #print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best'])) @@ -1032,7 +1031,7 @@ class EnergyModel: if self.arg_dependence_ratio(state_or_tran, model_attribute, arg_index) > 0.5: fit_queue.append({ 'key' : [state_or_tran, model_attribute, arg_index], - 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index] + 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled] }) #fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index) #if 'args' in self.by_name[state_or_tran]: |