summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py33
-rw-r--r--lib/modular_arithmetic.py45
2 files changed, 46 insertions, 32 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 76be562..87fb2f0 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -14,9 +14,7 @@ arg_support_enabled = True
def powerset(iterable):
"""
- Calculate powerset of `iterable` elements.
-
- Returns an iterable containing one tuple for each powerset element.
+ Return powerset of `iterable` elements.
Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]`
"""
@@ -39,15 +37,14 @@ class ParamFunction:
(-> single float as model input) are used. However, n-dimensional
functions (-> list of float as model input) are also supported.
- :param param_function: regression function. Must have the signature
- (reg_param, model_param) -> float.
+ :param param_function: regression function (reg_param, model_param) -> float.
reg_param is a list of regression variable values,
model_param is the model input value (float).
- Example: lambda rp, mp: rp[0] + rp[1] * mp
+ Example: `lambda rp, mp: rp[0] + rp[1] * mp`
:param validation_function: function used to check whether param_function
is defined for a given model_param. Signature:
model_param -> bool
- Example: lambda mp: mp > 0
+ Example: `lambda mp: mp > 0`
:param num_vars: How many regression variables are used by this function,
i.e., the length of param_function's reg_param argument.
"""
@@ -55,32 +52,33 @@ class ParamFunction:
self._validation_function = validation_function
self._num_variables = num_vars
- def is_valid(self, arg):
+ def is_valid(self, arg: float) -> bool:
"""
Check whether the regression function is defined for the given argument.
- Returns bool.
+ :param arg: argument (e.g. model parameter) to check for
+ :returns: True iff the function is defined for `arg`
"""
return self._validation_function(arg)
- def eval(self, param, args):
+ def eval(self, param: list, arg: float) -> float:
"""
Evaluate regression function.
:param param: regression variable values (list of float)
:param arg: model input (float)
- :return: regression function output (float)
+ :returns: regression function output (float)
"""
- return self._param_function(param, args)
+ return self._param_function(param, arg)
- def error_function(self, P, X, y):
+ def error_function(self, P: list, X: float, y: float) -> float:
"""
Calculate model error.
:param P: regression variables as returned by optimization (list of float)
:param X: model input (float)
:param y: expected model output / ground truth for model input (float)
- :return: Deviation between model output and ground truth (float)
+ :returns: Deviation between model output and ground truth (float)
"""
return self._param_function(P, X) - y
@@ -89,16 +87,17 @@ class NormalizationFunction:
Wrapper for parameter normalization functions used in YAML PTA/DFA models.
"""
- def __init__(self, function_str):
+ def __init__(self, function_str: str):
"""
Create a new normalization function from `function_str`.
- :param function_str: Function string. Signature: (param) -> float
+ :param function_str: Function string. Must use the single argument
+ `param` and return a float.
"""
self._function_str = function_str
self._function = eval('lambda param: ' + function_str)
- def eval(self, param_value):
+ def eval(self, param_value: float) -> float:
"""
Evaluate the normalization function and return its output.
diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py
index baf979a..0a69b79 100644
--- a/lib/modular_arithmetic.py
+++ b/lib/modular_arithmetic.py
@@ -2,11 +2,20 @@
# Licensed under GFDL 1.2 https://www.gnu.org/licenses/old-licenses/fdl-1.2.html
import operator
import functools
-
+
@functools.total_ordering
class Mod:
+ """A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types.
+
+ Does not support negative values, therefore it cannot be used to emulate signed integers.
+
+ Overloads a==b, a<b, a>b, a+b, a-b, a*b, a**b, and -a
+
+ :param val: stored integer value
+ Param mod: modulus
+ """
__slots__ = ['val','mod']
-
+
def __init__(self, val, mod):
if isinstance(val, Mod):
val = val.val
@@ -16,13 +25,13 @@ class Mod:
raise ValueError('Modulo must be positive integer')
self.val = val % mod
self.mod = mod
-
+
def __repr__(self):
return 'Mod({}, {})'.format(self.val, self.mod)
-
+
def __int__(self):
return self.val
-
+
def __eq__(self, other):
if isinstance(other, Mod):
self.val == other.val
@@ -30,7 +39,7 @@ class Mod:
return self.val == other
else:
return NotImplemented
-
+
def __lt__(self, other):
if isinstance(other, Mod):
return self.val < other.val
@@ -38,25 +47,25 @@ class Mod:
return self.val < other
else:
return NotImplemented
-
+
def _check_operand(self, other):
if not isinstance(other, (int, Mod)):
raise TypeError('Only integer and Mod operands are supported')
-
+
def __pow__(self, other):
self._check_operand(other)
# We use the built-in modular exponentiation function, this way we can avoid working with huge numbers.
return __class__(pow(self.val, int(other), self.mod), self.mod)
-
+
def __neg__(self):
return Mod(self.mod - self.val, self.mod)
-
+
def __pos__(self):
return self # The unary plus operator does nothing.
-
+
def __abs__(self):
return self # The value is always kept non-negative, so the abs function should do nothing.
-
+
# Helper functions to build common operands based on a template.
# They need to be implemented as functions for the closures to work properly.
def _make_op(opname):
@@ -65,14 +74,14 @@ def _make_op(opname):
self._check_operand(other)
return Mod(op_fun(self.val, int(other)) % self.mod, self.mod)
return op
-
+
def _make_reflected_op(opname):
op_fun = getattr(operator, opname)
def op(self, other):
self._check_operand(other)
return Mod(op_fun(int(other), self.val) % self.mod, self.mod)
return op
-
+
# Build the actual operator overload methods based on the template.
for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]:
setattr(Mod, opname, _make_op(opname))
@@ -115,7 +124,13 @@ class Uint64(Mod):
return 'Uint64({})'.format(self.val)
-def simulate_int_type(int_type: str):
+def simulate_int_type(int_type: str) -> Mod:
+ """
+ Return `Mod` subclass for given `int_type`
+
+ :param int_type: uint8_t / uint16_t / uint32_t / uint64_t
+ :returns: `Mod` subclass, e.g. Uint8
+ """
if int_type == 'uint8_t':
return Uint8
if int_type == 'uint16_t':