diff options
author | Daniel Friesel <derf@finalrewind.org> | 2018-04-10 09:41:23 +0200 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2018-04-10 09:41:23 +0200 |
commit | 35c25777d7898190bc81a990d6fa40af28e152e9 (patch) | |
tree | 722e006982314ed5dec3b2ff849e6eab3fe4e861 /lib | |
parent | 4b3ff326b8f97b5f849b8619d984d2e2d50ab9b9 (diff) |
add support for safe division/log/sqrt
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 17e7014..dd8e288 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -16,6 +16,7 @@ import tarfile from multiprocessing import Pool arg_support_enabled = True +safe_functions_enabled = True def running_mean(x, N): cumsum = np.cumsum(np.insert(x, 0, 0)) @@ -545,6 +546,9 @@ class analytic: _num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1")) _num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1")) _num1 = np.vectorize(lambda x: bin(int(x)).count("1")) + _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.) + _safe_frac = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.) + _safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x))) _function_map = { 'linear' : lambda x: x, @@ -557,6 +561,9 @@ class analytic: 'num0_8' : _num0_8, 'num0_16' : _num0_16, 'num1' : _num1, + 'safe_log' : lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1., + 'safe_frac' : lambda x: 1 / x if np.abs(x) > 0.001 else 1., + 'safe_sqrt': lambda x: np.sqrt(np.abs(x)), } def functions(): @@ -614,6 +621,23 @@ class analytic: ), } + if safe_functions_enabled: + functions['safe_log'] = ParamFunction( + lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_log(model_param), + lambda model_param: True, + 2 + ) + functions['safe_frac'] = ParamFunction( + lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_frac(model_param), + lambda model_param: True, + 2 + ) + functions['safe_sqrt'] = ParamFunction( + lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_sqrt(model_param), + lambda model_param: True, + 2 + ) + return functions def _fmap(reference_type, reference_name, function_type): @@ -649,17 +673,6 @@ class analytic: buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best'])) return AnalyticFunction(buf, arg_idx, parameter_names, num_args) - #def function_powerset(function_descriptions): - # function_buffer = lambda param, arg: 0 - # param_idx = 0 - # for combination in powerset(function_descriptions): - # new_function = lambda param, arg: param[param_idx] - # param_idx += 1 - # for function_name in combination: - # new_function = lambda param, arg: new_function(param, arg) * analytic._function_map[function_name](arg) - # new_function = lambda param, arg: param[param_idx] * - # function_buffer = lambda param, arg: function_buffer(param, arg) + - def _try_fits_parallel(arg): return { 'key' : arg['key'], @@ -839,6 +852,10 @@ class EnergyModel: self.stats[state_or_trans] = {} for key in self.by_name[state_or_trans]['attributes']: if key in self.by_name[state_or_trans]: + #try: + # print(state_or_trans, key, np.corrcoef(self.by_name[state_or_trans][key], np.array(self.by_name[state_or_trans]['param']).T)) + #except TypeError as e: + # print(state_or_trans, key, e) self.stats[state_or_trans][key] = _compute_param_statistics(self.by_name, self.by_param, self._parameter_names, self._num_args, state_or_trans, key) #queue.append({ # 'state_or_trans' : state_or_trans, |