diff options
Diffstat (limited to 'lib')
| -rwxr-xr-x | lib/dfatool.py | 46 | 
1 files changed, 27 insertions, 19 deletions
| diff --git a/lib/dfatool.py b/lib/dfatool.py index 6e0e829..30ffbd5 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -466,7 +466,7 @@ class ParamFunction:  class AnalyticFunction: -    def __init__(self, function_str, num_vars, parameters, num_args, verbose = True): +    def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None, function_lambda = None):          self._parameter_names = parameters          self._num_args = num_args          self._model_str = function_str @@ -475,19 +475,30 @@ class AnalyticFunction:          self.fit_success = False          self.verbose = verbose -        for i in range(len(parameters)): -            if rawfunction.find('parameter({})'.format(parameters[i])) >= 0: -                self._dependson[i] = True -                rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i)) -        for i in range(0, num_args): -            if rawfunction.find('function_arg({:d})'.format(i)) >= 0: -                self._dependson[len(parameters) + i] = True -                rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i)) -        for i in range(num_vars): -            rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i)) -        self._function_str = rawfunction -        self._function = eval('lambda reg_param, model_param: ' + rawfunction); -        self._regression_args = list(np.ones((num_vars))) +        if type(function_str) == str: +            num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)') +            num_vars = max(map(int, num_vars_re.findall(function_str))) + 1 +            for i in range(len(parameters)): +                if rawfunction.find('parameter({})'.format(parameters[i])) >= 0: +                    self._dependson[i] = True +                    rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i)) +            for i in range(0, num_args): +                if rawfunction.find('function_arg({:d})'.format(i)) >= 0: +                    self._dependson[len(parameters) + i] = True +                    rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i)) +            for i in range(num_vars): +                rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i)) +            self._function_str = rawfunction +            self._function = eval('lambda reg_param, model_param: ' + rawfunction) +        elif type(function_str) == function: +            self._function_str = 'raise ValueError' +            self._function = function_str + +        if regression_args: +            self._regression_args = regression_args.copy() +            self._fit_success = True +        else: +            self._regression_args = list(np.ones((num_vars)))      def get_fit_data(self, by_param, state_or_tran, model_attribute):          dimension = len(self._parameter_names) + self._num_args @@ -674,7 +685,7 @@ class analytic:                      buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best']))                  else:                      buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best'])) -        return AnalyticFunction(buf, arg_idx, parameter_names, num_args) +        return AnalyticFunction(buf, parameter_names, num_args)  def _try_fits_parallel(arg):      return { @@ -1109,10 +1120,7 @@ class EnergyModel:                  if (state_or_tran, model_attribute) in self.function_override:                      function_str = self.function_override[(state_or_tran, model_attribute)] -                    var_re = re.compile(r'regression_arg\(([0-9]*)\)') -                    var_count = max(map(int, var_re.findall(function_str))) + 1 -                    x = AnalyticFunction(function_str, -                        var_count, self._parameter_names, num_args) +                    x = AnalyticFunction(function_str, self._parameter_names, num_args)                      x.fit(self.by_param, state_or_tran, model_attribute)                      if x.fit_success:                          param_model[state_or_tran][model_attribute] = { | 
