summaryrefslogtreecommitdiff
path: root/lib/dfatool.py
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-04-16 13:16:00 +0200
committerDaniel Friesel <derf@finalrewind.org>2018-04-17 10:04:41 +0200
commit67ae1c880ca856f0dcec4a64f7d1dd63f4f3147b (patch)
tree663188ac5dcfbcdb1fc562b4436daf0ce40d5854 /lib/dfatool.py
parent3451efd9c493f311a31f4e573f153a64e86f7aae (diff)
Properly toggle safe functions feature from analyze-archive.py
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-xlib/dfatool.py15
1 files changed, 7 insertions, 8 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 76bf51a..ad37f72 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -16,7 +16,6 @@ import tarfile
from multiprocessing import Pool
arg_support_enabled = True
-safe_functions_enabled = True
def running_mean(x, N):
cumsum = np.cumsum(np.insert(x, 0, 0))
@@ -557,7 +556,7 @@ class analytic:
'exponential' : np.exp,
'square' : lambda x : x ** 2,
'inverse' : lambda x : 1 / x,
- 'sqrt' : np.sqrt,
+ 'sqrt' : lambda x: np.sqrt(np.abs(x)),
'num0_8' : _num0_8,
'num0_16' : _num0_16,
'num1' : _num1,
@@ -566,7 +565,7 @@ class analytic:
'safe_sqrt': lambda x: np.sqrt(np.abs(x)),
}
- def functions():
+ def functions(safe_functions_enabled = False):
functions = {
'linear' : ParamFunction(
lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param,
@@ -680,8 +679,8 @@ def _try_fits_parallel(arg):
}
-def _try_fits(by_param, state_or_tran, model_attribute, param_index):
- functions = analytic.functions()
+def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False):
+ functions = analytic.functions(safe_functions_enabled = safe_functions_enabled)
for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
@@ -1006,7 +1005,7 @@ class EnergyModel:
return self._parameter_names[param_index]
return str(param_index)
- def get_fitted(self):
+ def get_fitted(self, safe_functions_enabled = False):
if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
@@ -1023,7 +1022,7 @@ class EnergyModel:
if self.param_dependence_ratio(state_or_tran, model_attribute, parameter_name) > 0.5:
fit_queue.append({
'key' : [state_or_tran, model_attribute, parameter_name],
- 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index]
+ 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled]
})
#fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index)
#print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best']))
@@ -1032,7 +1031,7 @@ class EnergyModel:
if self.arg_dependence_ratio(state_or_tran, model_attribute, arg_index) > 0.5:
fit_queue.append({
'key' : [state_or_tran, model_attribute, arg_index],
- 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index]
+ 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled]
})
#fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index)
#if 'args' in self.by_name[state_or_tran]: