diff options
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-x | lib/dfatool.py | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 6e0e829..30ffbd5 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -466,7 +466,7 @@ class ParamFunction: class AnalyticFunction: - def __init__(self, function_str, num_vars, parameters, num_args, verbose = True): + def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None, function_lambda = None): self._parameter_names = parameters self._num_args = num_args self._model_str = function_str @@ -475,19 +475,30 @@ class AnalyticFunction: self.fit_success = False self.verbose = verbose - for i in range(len(parameters)): - if rawfunction.find('parameter({})'.format(parameters[i])) >= 0: - self._dependson[i] = True - rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i)) - for i in range(0, num_args): - if rawfunction.find('function_arg({:d})'.format(i)) >= 0: - self._dependson[len(parameters) + i] = True - rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i)) - for i in range(num_vars): - rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i)) - self._function_str = rawfunction - self._function = eval('lambda reg_param, model_param: ' + rawfunction); - self._regression_args = list(np.ones((num_vars))) + if type(function_str) == str: + num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)') + num_vars = max(map(int, num_vars_re.findall(function_str))) + 1 + for i in range(len(parameters)): + if rawfunction.find('parameter({})'.format(parameters[i])) >= 0: + self._dependson[i] = True + rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i)) + for i in range(0, num_args): + if rawfunction.find('function_arg({:d})'.format(i)) >= 0: + self._dependson[len(parameters) + i] = True + rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i)) + for i in range(num_vars): + rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i)) + self._function_str = rawfunction + self._function = eval('lambda reg_param, model_param: ' + rawfunction) + elif type(function_str) == function: + self._function_str = 'raise ValueError' + self._function = function_str + + if regression_args: + self._regression_args = regression_args.copy() + self._fit_success = True + else: + self._regression_args = list(np.ones((num_vars))) def get_fit_data(self, by_param, state_or_tran, model_attribute): dimension = len(self._parameter_names) + self._num_args @@ -674,7 +685,7 @@ class analytic: buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best'])) else: buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best'])) - return AnalyticFunction(buf, arg_idx, parameter_names, num_args) + return AnalyticFunction(buf, parameter_names, num_args) def _try_fits_parallel(arg): return { @@ -1109,10 +1120,7 @@ class EnergyModel: if (state_or_tran, model_attribute) in self.function_override: function_str = self.function_override[(state_or_tran, model_attribute)] - var_re = re.compile(r'regression_arg\(([0-9]*)\)') - var_count = max(map(int, var_re.findall(function_str))) + 1 - x = AnalyticFunction(function_str, - var_count, self._parameter_names, num_args) + x = AnalyticFunction(function_str, self._parameter_names, num_args) x.fit(self.by_param, state_or_tran, model_attribute) if x.fit_success: param_model[state_or_tran][model_attribute] = { |