summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-04-24 09:26:03 +0200
committerDaniel Friesel <derf@finalrewind.org>2018-04-24 09:28:07 +0200
commit2441a23874d58a0cb78a8051e404225c2a3990c1 (patch)
tree81fd2c0d2d061dbd72c1dcb54e2c9b05b900479e /lib
parent13810886493d2c7d8a3df9b9bf84392d1054e1c2 (diff)
AnalyticFunction: Compute num_vars from function_str
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py46
1 files changed, 27 insertions, 19 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 6e0e829..30ffbd5 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -466,7 +466,7 @@ class ParamFunction:
class AnalyticFunction:
- def __init__(self, function_str, num_vars, parameters, num_args, verbose = True):
+ def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None, function_lambda = None):
self._parameter_names = parameters
self._num_args = num_args
self._model_str = function_str
@@ -475,19 +475,30 @@ class AnalyticFunction:
self.fit_success = False
self.verbose = verbose
- for i in range(len(parameters)):
- if rawfunction.find('parameter({})'.format(parameters[i])) >= 0:
- self._dependson[i] = True
- rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i))
- for i in range(0, num_args):
- if rawfunction.find('function_arg({:d})'.format(i)) >= 0:
- self._dependson[len(parameters) + i] = True
- rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i))
- for i in range(num_vars):
- rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i))
- self._function_str = rawfunction
- self._function = eval('lambda reg_param, model_param: ' + rawfunction);
- self._regression_args = list(np.ones((num_vars)))
+ if type(function_str) == str:
+ num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)')
+ num_vars = max(map(int, num_vars_re.findall(function_str))) + 1
+ for i in range(len(parameters)):
+ if rawfunction.find('parameter({})'.format(parameters[i])) >= 0:
+ self._dependson[i] = True
+ rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i))
+ for i in range(0, num_args):
+ if rawfunction.find('function_arg({:d})'.format(i)) >= 0:
+ self._dependson[len(parameters) + i] = True
+ rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i))
+ for i in range(num_vars):
+ rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i))
+ self._function_str = rawfunction
+ self._function = eval('lambda reg_param, model_param: ' + rawfunction)
+ elif type(function_str) == function:
+ self._function_str = 'raise ValueError'
+ self._function = function_str
+
+ if regression_args:
+ self._regression_args = regression_args.copy()
+ self._fit_success = True
+ else:
+ self._regression_args = list(np.ones((num_vars)))
def get_fit_data(self, by_param, state_or_tran, model_attribute):
dimension = len(self._parameter_names) + self._num_args
@@ -674,7 +685,7 @@ class analytic:
buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best']))
else:
buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best']))
- return AnalyticFunction(buf, arg_idx, parameter_names, num_args)
+ return AnalyticFunction(buf, parameter_names, num_args)
def _try_fits_parallel(arg):
return {
@@ -1109,10 +1120,7 @@ class EnergyModel:
if (state_or_tran, model_attribute) in self.function_override:
function_str = self.function_override[(state_or_tran, model_attribute)]
- var_re = re.compile(r'regression_arg\(([0-9]*)\)')
- var_count = max(map(int, var_re.findall(function_str))) + 1
- x = AnalyticFunction(function_str,
- var_count, self._parameter_names, num_args)
+ x = AnalyticFunction(function_str, self._parameter_names, num_args)
x.fit(self.by_param, state_or_tran, model_attribute)
if x.fit_success:
param_model[state_or_tran][model_attribute] = {