diff options
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 76bf51a..ad37f72 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -16,7 +16,6 @@ import tarfile from multiprocessing import Pool arg_support_enabled = True -safe_functions_enabled = True def running_mean(x, N): cumsum = np.cumsum(np.insert(x, 0, 0)) @@ -557,7 +556,7 @@ class analytic: 'exponential' : np.exp, 'square' : lambda x : x ** 2, 'inverse' : lambda x : 1 / x, - 'sqrt' : np.sqrt, + 'sqrt' : lambda x: np.sqrt(np.abs(x)), 'num0_8' : _num0_8, 'num0_16' : _num0_16, 'num1' : _num1, @@ -566,7 +565,7 @@ class analytic: 'safe_sqrt': lambda x: np.sqrt(np.abs(x)), } - def functions(): + def functions(safe_functions_enabled = False): functions = { 'linear' : ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param, @@ -680,8 +679,8 @@ def _try_fits_parallel(arg): } -def _try_fits(by_param, state_or_tran, model_attribute, param_index): - functions = analytic.functions() +def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False): + functions = analytic.functions(safe_functions_enabled = safe_functions_enabled) for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()): @@ -1006,7 +1005,7 @@ class EnergyModel: return self._parameter_names[param_index] return str(param_index) - def get_fitted(self): + def get_fitted(self, safe_functions_enabled = False): if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache: return self.cache['fitted_model_getter'], self.cache['fitted_info_getter'] @@ -1023,7 +1022,7 @@ class EnergyModel: if self.param_dependence_ratio(state_or_tran, model_attribute, parameter_name) > 0.5: fit_queue.append({ 'key' : [state_or_tran, model_attribute, parameter_name], - 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index] + 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled] }) #fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index) #print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best'])) @@ -1032,7 +1031,7 @@ class EnergyModel: if self.arg_dependence_ratio(state_or_tran, model_attribute, arg_index) > 0.5: fit_queue.append({ 'key' : [state_or_tran, model_attribute, arg_index], - 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index] + 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled] }) #fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index) #if 'args' in self.by_name[state_or_tran]: |