summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 76be562..87fb2f0 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -14,9 +14,7 @@ arg_support_enabled = True
def powerset(iterable):
"""
- Calculate powerset of `iterable` elements.
-
- Returns an iterable containing one tuple for each powerset element.
+ Return powerset of `iterable` elements.
Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]`
"""
@@ -39,15 +37,14 @@ class ParamFunction:
(-> single float as model input) are used. However, n-dimensional
functions (-> list of float as model input) are also supported.
- :param param_function: regression function. Must have the signature
- (reg_param, model_param) -> float.
+ :param param_function: regression function (reg_param, model_param) -> float.
reg_param is a list of regression variable values,
model_param is the model input value (float).
- Example: lambda rp, mp: rp[0] + rp[1] * mp
+ Example: `lambda rp, mp: rp[0] + rp[1] * mp`
:param validation_function: function used to check whether param_function
is defined for a given model_param. Signature:
model_param -> bool
- Example: lambda mp: mp > 0
+ Example: `lambda mp: mp > 0`
:param num_vars: How many regression variables are used by this function,
i.e., the length of param_function's reg_param argument.
"""
@@ -55,32 +52,33 @@ class ParamFunction:
self._validation_function = validation_function
self._num_variables = num_vars
- def is_valid(self, arg):
+ def is_valid(self, arg: float) -> bool:
"""
Check whether the regression function is defined for the given argument.
- Returns bool.
+ :param arg: argument (e.g. model parameter) to check for
+ :returns: True iff the function is defined for `arg`
"""
return self._validation_function(arg)
- def eval(self, param, args):
+ def eval(self, param: list, arg: float) -> float:
"""
Evaluate regression function.
:param param: regression variable values (list of float)
:param arg: model input (float)
- :return: regression function output (float)
+ :returns: regression function output (float)
"""
- return self._param_function(param, args)
+ return self._param_function(param, arg)
- def error_function(self, P, X, y):
+ def error_function(self, P: list, X: float, y: float) -> float:
"""
Calculate model error.
:param P: regression variables as returned by optimization (list of float)
:param X: model input (float)
:param y: expected model output / ground truth for model input (float)
- :return: Deviation between model output and ground truth (float)
+ :returns: Deviation between model output and ground truth (float)
"""
return self._param_function(P, X) - y
@@ -89,16 +87,17 @@ class NormalizationFunction:
Wrapper for parameter normalization functions used in YAML PTA/DFA models.
"""
- def __init__(self, function_str):
+ def __init__(self, function_str: str):
"""
Create a new normalization function from `function_str`.
- :param function_str: Function string. Signature: (param) -> float
+ :param function_str: Function string. Must use the single argument
+ `param` and return a float.
"""
self._function_str = function_str
self._function = eval('lambda param: ' + function_str)
- def eval(self, param_value):
+ def eval(self, param_value: float) -> float:
"""
Evaluate the normalization function and return its output.