diff options
-rw-r--r-- | lib/functions.py | 33 | ||||
-rw-r--r-- | lib/modular_arithmetic.py | 45 |
2 files changed, 46 insertions, 32 deletions
diff --git a/lib/functions.py b/lib/functions.py index 76be562..87fb2f0 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -14,9 +14,7 @@ arg_support_enabled = True def powerset(iterable): """ - Calculate powerset of `iterable` elements. - - Returns an iterable containing one tuple for each powerset element. + Return powerset of `iterable` elements. Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]` """ @@ -39,15 +37,14 @@ class ParamFunction: (-> single float as model input) are used. However, n-dimensional functions (-> list of float as model input) are also supported. - :param param_function: regression function. Must have the signature - (reg_param, model_param) -> float. + :param param_function: regression function (reg_param, model_param) -> float. reg_param is a list of regression variable values, model_param is the model input value (float). - Example: lambda rp, mp: rp[0] + rp[1] * mp + Example: `lambda rp, mp: rp[0] + rp[1] * mp` :param validation_function: function used to check whether param_function is defined for a given model_param. Signature: model_param -> bool - Example: lambda mp: mp > 0 + Example: `lambda mp: mp > 0` :param num_vars: How many regression variables are used by this function, i.e., the length of param_function's reg_param argument. """ @@ -55,32 +52,33 @@ class ParamFunction: self._validation_function = validation_function self._num_variables = num_vars - def is_valid(self, arg): + def is_valid(self, arg: float) -> bool: """ Check whether the regression function is defined for the given argument. - Returns bool. + :param arg: argument (e.g. model parameter) to check for + :returns: True iff the function is defined for `arg` """ return self._validation_function(arg) - def eval(self, param, args): + def eval(self, param: list, arg: float) -> float: """ Evaluate regression function. :param param: regression variable values (list of float) :param arg: model input (float) - :return: regression function output (float) + :returns: regression function output (float) """ - return self._param_function(param, args) + return self._param_function(param, arg) - def error_function(self, P, X, y): + def error_function(self, P: list, X: float, y: float) -> float: """ Calculate model error. :param P: regression variables as returned by optimization (list of float) :param X: model input (float) :param y: expected model output / ground truth for model input (float) - :return: Deviation between model output and ground truth (float) + :returns: Deviation between model output and ground truth (float) """ return self._param_function(P, X) - y @@ -89,16 +87,17 @@ class NormalizationFunction: Wrapper for parameter normalization functions used in YAML PTA/DFA models. """ - def __init__(self, function_str): + def __init__(self, function_str: str): """ Create a new normalization function from `function_str`. - :param function_str: Function string. Signature: (param) -> float + :param function_str: Function string. Must use the single argument + `param` and return a float. """ self._function_str = function_str self._function = eval('lambda param: ' + function_str) - def eval(self, param_value): + def eval(self, param_value: float) -> float: """ Evaluate the normalization function and return its output. diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py index baf979a..0a69b79 100644 --- a/lib/modular_arithmetic.py +++ b/lib/modular_arithmetic.py @@ -2,11 +2,20 @@ # Licensed under GFDL 1.2 https://www.gnu.org/licenses/old-licenses/fdl-1.2.html import operator import functools - + @functools.total_ordering class Mod: + """A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types. + + Does not support negative values, therefore it cannot be used to emulate signed integers. + + Overloads a==b, a<b, a>b, a+b, a-b, a*b, a**b, and -a + + :param val: stored integer value + Param mod: modulus + """ __slots__ = ['val','mod'] - + def __init__(self, val, mod): if isinstance(val, Mod): val = val.val @@ -16,13 +25,13 @@ class Mod: raise ValueError('Modulo must be positive integer') self.val = val % mod self.mod = mod - + def __repr__(self): return 'Mod({}, {})'.format(self.val, self.mod) - + def __int__(self): return self.val - + def __eq__(self, other): if isinstance(other, Mod): self.val == other.val @@ -30,7 +39,7 @@ class Mod: return self.val == other else: return NotImplemented - + def __lt__(self, other): if isinstance(other, Mod): return self.val < other.val @@ -38,25 +47,25 @@ class Mod: return self.val < other else: return NotImplemented - + def _check_operand(self, other): if not isinstance(other, (int, Mod)): raise TypeError('Only integer and Mod operands are supported') - + def __pow__(self, other): self._check_operand(other) # We use the built-in modular exponentiation function, this way we can avoid working with huge numbers. return __class__(pow(self.val, int(other), self.mod), self.mod) - + def __neg__(self): return Mod(self.mod - self.val, self.mod) - + def __pos__(self): return self # The unary plus operator does nothing. - + def __abs__(self): return self # The value is always kept non-negative, so the abs function should do nothing. - + # Helper functions to build common operands based on a template. # They need to be implemented as functions for the closures to work properly. def _make_op(opname): @@ -65,14 +74,14 @@ def _make_op(opname): self._check_operand(other) return Mod(op_fun(self.val, int(other)) % self.mod, self.mod) return op - + def _make_reflected_op(opname): op_fun = getattr(operator, opname) def op(self, other): self._check_operand(other) return Mod(op_fun(int(other), self.val) % self.mod, self.mod) return op - + # Build the actual operator overload methods based on the template. for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]: setattr(Mod, opname, _make_op(opname)) @@ -115,7 +124,13 @@ class Uint64(Mod): return 'Uint64({})'.format(self.val) -def simulate_int_type(int_type: str): +def simulate_int_type(int_type: str) -> Mod: + """ + Return `Mod` subclass for given `int_type` + + :param int_type: uint8_t / uint16_t / uint32_t / uint64_t + :returns: `Mod` subclass, e.g. Uint8 + """ if int_type == 'uint8_t': return Uint8 if int_type == 'uint16_t': |