summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-04-16 13:16:00 +0200
committerDaniel Friesel <derf@finalrewind.org>2018-04-17 10:04:41 +0200
commit67ae1c880ca856f0dcec4a64f7d1dd63f4f3147b (patch)
tree663188ac5dcfbcdb1fc562b4436daf0ce40d5854
parent3451efd9c493f311a31f4e573f153a64e86f7aae (diff)
Properly toggle safe functions feature from analyze-archive.py
-rwxr-xr-xbin/analyze-archive.py9
-rwxr-xr-xbin/test.py4
-rwxr-xr-xlib/dfatool.py15
3 files changed, 16 insertions, 12 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 7779b4b..df3ecbe 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -60,12 +60,14 @@ if __name__ == '__main__':
ignored_trace_indexes = None
discard_outliers = None
tex_output = False
+ safe_functions_enabled = False
function_override = {}
try:
optspec = (
'plot-unparam= plot-param= '
- 'ignored-trace-indexes= discard-outliers= function-override= tex-output'
+ 'ignored-trace-indexes= discard-outliers= function-override= tex-output '
+ 'with-safe-functions'
)
raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(' '))
@@ -89,6 +91,9 @@ if __name__ == '__main__':
if 'tex-output' in opts:
tex_output = True
+ if 'with-safe-functions' in opts:
+ safe_functions_enabled = True
+
except getopt.GetoptError as err:
print(err)
sys.exit(2)
@@ -135,7 +140,7 @@ if __name__ == '__main__':
lut_quality = model.assess(lut_model)
print('--- param model ---')
- param_model, param_info = model.get_fitted()
+ param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled)
if not tex_output:
for state in model.states():
for attribute in ['power']:
diff --git a/bin/test.py b/bin/test.py
index 60e8648..433b423 100755
--- a/bin/test.py
+++ b/bin/test.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-from dfatool import EnergyModel, RawData
+from dfatool import EnergyModel, RawData, analytic
import unittest
class TestStaticModel(unittest.TestCase):
@@ -229,5 +229,5 @@ class TestStaticModel(unittest.TestCase):
self.assertAlmostEqual(param_info('RX', 'power')['function']._regression_args[1], 206, places=0)
if __name__ == '__main__':
- dfatool.safe_function_enabled = False
+ analytic.safe_function_enabled = False
unittest.main()
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 76bf51a..ad37f72 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -16,7 +16,6 @@ import tarfile
from multiprocessing import Pool
arg_support_enabled = True
-safe_functions_enabled = True
def running_mean(x, N):
cumsum = np.cumsum(np.insert(x, 0, 0))
@@ -557,7 +556,7 @@ class analytic:
'exponential' : np.exp,
'square' : lambda x : x ** 2,
'inverse' : lambda x : 1 / x,
- 'sqrt' : np.sqrt,
+ 'sqrt' : lambda x: np.sqrt(np.abs(x)),
'num0_8' : _num0_8,
'num0_16' : _num0_16,
'num1' : _num1,
@@ -566,7 +565,7 @@ class analytic:
'safe_sqrt': lambda x: np.sqrt(np.abs(x)),
}
- def functions():
+ def functions(safe_functions_enabled = False):
functions = {
'linear' : ParamFunction(
lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param,
@@ -680,8 +679,8 @@ def _try_fits_parallel(arg):
}
-def _try_fits(by_param, state_or_tran, model_attribute, param_index):
- functions = analytic.functions()
+def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False):
+ functions = analytic.functions(safe_functions_enabled = safe_functions_enabled)
for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
@@ -1006,7 +1005,7 @@ class EnergyModel:
return self._parameter_names[param_index]
return str(param_index)
- def get_fitted(self):
+ def get_fitted(self, safe_functions_enabled = False):
if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
@@ -1023,7 +1022,7 @@ class EnergyModel:
if self.param_dependence_ratio(state_or_tran, model_attribute, parameter_name) > 0.5:
fit_queue.append({
'key' : [state_or_tran, model_attribute, parameter_name],
- 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index]
+ 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled]
})
#fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index)
#print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best']))
@@ -1032,7 +1031,7 @@ class EnergyModel:
if self.arg_dependence_ratio(state_or_tran, model_attribute, arg_index) > 0.5:
fit_queue.append({
'key' : [state_or_tran, model_attribute, arg_index],
- 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index]
+ 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled]
})
#fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index)
#if 'args' in self.by_name[state_or_tran]: