diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/lib/functions.py b/lib/functions.py index 76be562..87fb2f0 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -14,9 +14,7 @@ arg_support_enabled = True def powerset(iterable): """ - Calculate powerset of `iterable` elements. - - Returns an iterable containing one tuple for each powerset element. + Return powerset of `iterable` elements. Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]` """ @@ -39,15 +37,14 @@ class ParamFunction: (-> single float as model input) are used. However, n-dimensional functions (-> list of float as model input) are also supported. - :param param_function: regression function. Must have the signature - (reg_param, model_param) -> float. + :param param_function: regression function (reg_param, model_param) -> float. reg_param is a list of regression variable values, model_param is the model input value (float). - Example: lambda rp, mp: rp[0] + rp[1] * mp + Example: `lambda rp, mp: rp[0] + rp[1] * mp` :param validation_function: function used to check whether param_function is defined for a given model_param. Signature: model_param -> bool - Example: lambda mp: mp > 0 + Example: `lambda mp: mp > 0` :param num_vars: How many regression variables are used by this function, i.e., the length of param_function's reg_param argument. """ @@ -55,32 +52,33 @@ class ParamFunction: self._validation_function = validation_function self._num_variables = num_vars - def is_valid(self, arg): + def is_valid(self, arg: float) -> bool: """ Check whether the regression function is defined for the given argument. - Returns bool. + :param arg: argument (e.g. model parameter) to check for + :returns: True iff the function is defined for `arg` """ return self._validation_function(arg) - def eval(self, param, args): + def eval(self, param: list, arg: float) -> float: """ Evaluate regression function. :param param: regression variable values (list of float) :param arg: model input (float) - :return: regression function output (float) + :returns: regression function output (float) """ - return self._param_function(param, args) + return self._param_function(param, arg) - def error_function(self, P, X, y): + def error_function(self, P: list, X: float, y: float) -> float: """ Calculate model error. :param P: regression variables as returned by optimization (list of float) :param X: model input (float) :param y: expected model output / ground truth for model input (float) - :return: Deviation between model output and ground truth (float) + :returns: Deviation between model output and ground truth (float) """ return self._param_function(P, X) - y @@ -89,16 +87,17 @@ class NormalizationFunction: Wrapper for parameter normalization functions used in YAML PTA/DFA models. """ - def __init__(self, function_str): + def __init__(self, function_str: str): """ Create a new normalization function from `function_str`. - :param function_str: Function string. Signature: (param) -> float + :param function_str: Function string. Must use the single argument + `param` and return a float. """ self._function_str = function_str self._function = eval('lambda param: ' + function_str) - def eval(self, param_value): + def eval(self, param_value: float) -> float: """ Evaluate the normalization function and return its output. |