summaryrefslogtreecommitdiff
path: root/lib/dfatool.py
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-04-10 09:41:23 +0200
committerDaniel Friesel <derf@finalrewind.org>2018-04-10 09:41:23 +0200
commit35c25777d7898190bc81a990d6fa40af28e152e9 (patch)
tree722e006982314ed5dec3b2ff849e6eab3fe4e861 /lib/dfatool.py
parent4b3ff326b8f97b5f849b8619d984d2e2d50ab9b9 (diff)
add support for safe division/log/sqrt
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-xlib/dfatool.py39
1 files changed, 28 insertions, 11 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 17e7014..dd8e288 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -16,6 +16,7 @@ import tarfile
from multiprocessing import Pool
arg_support_enabled = True
+safe_functions_enabled = True
def running_mean(x, N):
cumsum = np.cumsum(np.insert(x, 0, 0))
@@ -545,6 +546,9 @@ class analytic:
_num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1"))
_num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1"))
_num1 = np.vectorize(lambda x: bin(int(x)).count("1"))
+ _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.)
+ _safe_frac = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.)
+ _safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x)))
_function_map = {
'linear' : lambda x: x,
@@ -557,6 +561,9 @@ class analytic:
'num0_8' : _num0_8,
'num0_16' : _num0_16,
'num1' : _num1,
+ 'safe_log' : lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.,
+ 'safe_frac' : lambda x: 1 / x if np.abs(x) > 0.001 else 1.,
+ 'safe_sqrt': lambda x: np.sqrt(np.abs(x)),
}
def functions():
@@ -614,6 +621,23 @@ class analytic:
),
}
+ if safe_functions_enabled:
+ functions['safe_log'] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_log(model_param),
+ lambda model_param: True,
+ 2
+ )
+ functions['safe_frac'] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_frac(model_param),
+ lambda model_param: True,
+ 2
+ )
+ functions['safe_sqrt'] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_sqrt(model_param),
+ lambda model_param: True,
+ 2
+ )
+
return functions
def _fmap(reference_type, reference_name, function_type):
@@ -649,17 +673,6 @@ class analytic:
buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best']))
return AnalyticFunction(buf, arg_idx, parameter_names, num_args)
- #def function_powerset(function_descriptions):
- # function_buffer = lambda param, arg: 0
- # param_idx = 0
- # for combination in powerset(function_descriptions):
- # new_function = lambda param, arg: param[param_idx]
- # param_idx += 1
- # for function_name in combination:
- # new_function = lambda param, arg: new_function(param, arg) * analytic._function_map[function_name](arg)
- # new_function = lambda param, arg: param[param_idx] *
- # function_buffer = lambda param, arg: function_buffer(param, arg) +
-
def _try_fits_parallel(arg):
return {
'key' : arg['key'],
@@ -839,6 +852,10 @@ class EnergyModel:
self.stats[state_or_trans] = {}
for key in self.by_name[state_or_trans]['attributes']:
if key in self.by_name[state_or_trans]:
+ #try:
+ # print(state_or_trans, key, np.corrcoef(self.by_name[state_or_trans][key], np.array(self.by_name[state_or_trans]['param']).T))
+ #except TypeError as e:
+ # print(state_or_trans, key, e)
self.stats[state_or_trans][key] = _compute_param_statistics(self.by_name, self.by_param, self._parameter_names, self._num_args, state_or_trans, key)
#queue.append({
# 'state_or_trans' : state_or_trans,