diff options
Diffstat (limited to 'lib/modular_arithmetic.py')
-rw-r--r-- | lib/modular_arithmetic.py | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py new file mode 100644 index 0000000..03cb7f8 --- /dev/null +++ b/lib/modular_arithmetic.py @@ -0,0 +1,125 @@ +# Based on https://rosettacode.org/wiki/Modular_arithmetic#Python +# Licensed under GFDL 1.2 https://www.gnu.org/licenses/old-licenses/fdl-1.2.html +import operator +import functools + +@functools.total_ordering +class Mod: + __slots__ = ['val','mod'] + + def __init__(self, val, mod): + if not isinstance(val, int): + raise ValueError('Value must be integer') + if not isinstance(mod, int) or mod<=0: + raise ValueError('Modulo must be positive integer') + self.val = val % mod + self.mod = mod + + def __repr__(self): + return 'Mod({}, {})'.format(self.val, self.mod) + + def __int__(self): + return self.val + + def __eq__(self, other): + if isinstance(other, Mod): + self.val == other.val + elif isinstance(other, int): + return self.val == other + else: + return NotImplemented + + def __lt__(self, other): + if isinstance(other, Mod): + return self.val < other.val + elif isinstance(other, int): + return self.val < other + else: + return NotImplemented + + def _check_operand(self, other): + if not isinstance(other, (int, Mod)): + raise TypeError('Only integer and Mod operands are supported') + + def __pow__(self, other): + self._check_operand(other) + # We use the built-in modular exponentiation function, this way we can avoid working with huge numbers. + return __class__(pow(self.val, int(other), self.mod), self.mod) + + def __neg__(self): + return Mod(self.mod - self.val, self.mod) + + def __pos__(self): + return self # The unary plus operator does nothing. + + def __abs__(self): + return self # The value is always kept non-negative, so the abs function should do nothing. + +# Helper functions to build common operands based on a template. +# They need to be implemented as functions for the closures to work properly. +def _make_op(opname): + op_fun = getattr(operator, opname) # Fetch the operator by name from the operator module + def op(self, other): + self._check_operand(other) + return Mod(op_fun(self.val, int(other)) % self.mod, self.mod) + return op + +def _make_reflected_op(opname): + op_fun = getattr(operator, opname) + def op(self, other): + self._check_operand(other) + return Mod(op_fun(int(other), self.val) % self.mod, self.mod) + return op + +# Build the actual operator overload methods based on the template. +for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]: + setattr(Mod, opname, _make_op(opname)) + setattr(Mod, reflected_opname, _make_reflected_op(opname)) + +class Uint8(Mod): + __slots__ = [] + + def __init__(self, val): + super().__init__(val, 256) + + def __repr__(self): + return 'Uint8({})'.format(self.val) + +class Uint16(Mod): + __slots__ = [] + + def __init__(self, val): + super().__init__(val, 65536) + + def __repr__(self): + return 'Uint16({})'.format(self.val) + +class Uint32(Mod): + __slots__ = [] + + def __init__(self, val): + super().__init__(val, 4294967296) + + def __repr__(self): + return 'Uint32({})'.format(self.val) + +class Uint64(Mod): + __slots__ = [] + + def __init__(self, val): + super().__init__(val, 18446744073709551616) + + def __repr__(self): + return 'Uint64({})'.format(self.val) + + +def simulate_int_type(int_type: str): + if int_type == 'uint8_t': + return Uint8 + if int_type == 'uint16_t': + return Uint16 + if int_type == 'uint32_t': + return Uint32 + if int_type == 'uint64_t': + return Uint64 + raise ValueError('unsupported integer type: {}'.format(int_type)) |