summaryrefslogtreecommitdiff
path: root/lib/dfatool.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-xlib/dfatool.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 30ffbd5..896dc12 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -466,7 +466,7 @@ class ParamFunction:
class AnalyticFunction:
- def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None, function_lambda = None):
+ def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None):
self._parameter_names = parameters
self._num_args = num_args
self._model_str = function_str
@@ -490,15 +490,17 @@ class AnalyticFunction:
rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i))
self._function_str = rawfunction
self._function = eval('lambda reg_param, model_param: ' + rawfunction)
- elif type(function_str) == function:
+ else:
self._function_str = 'raise ValueError'
self._function = function_str
if regression_args:
self._regression_args = regression_args.copy()
self._fit_success = True
- else:
+ elif type(function_str) == str:
self._regression_args = list(np.ones((num_vars)))
+ else:
+ self._regression_args = None
def get_fit_data(self, by_param, state_or_tran, model_attribute):
dimension = len(self._parameter_names) + self._num_args
@@ -553,7 +555,9 @@ class AnalyticFunction:
return False
return True
- def eval(self, param_list):
+ def eval(self, param_list, arg_list = []):
+ if self._regression_args == None:
+ return self._function(param_list, arg_list)
return self._function(self._regression_args, param_list)
class analytic: