diff options
Diffstat (limited to 'lib/modular_arithmetic.py')
-rw-r--r-- | lib/modular_arithmetic.py | 57 |
1 files changed, 38 insertions, 19 deletions
diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py index 0a69b79..c5ed1aa 100644 --- a/lib/modular_arithmetic.py +++ b/lib/modular_arithmetic.py @@ -3,6 +3,7 @@ import operator import functools + @functools.total_ordering class Mod: """A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types. @@ -14,20 +15,21 @@ class Mod: :param val: stored integer value Param mod: modulus """ - __slots__ = ['val','mod'] + + __slots__ = ["val", "mod"] def __init__(self, val, mod): if isinstance(val, Mod): val = val.val if not isinstance(val, int): - raise ValueError('Value must be integer') - if not isinstance(mod, int) or mod<=0: - raise ValueError('Modulo must be positive integer') + raise ValueError("Value must be integer") + if not isinstance(mod, int) or mod <= 0: + raise ValueError("Modulo must be positive integer") self.val = val % mod self.mod = mod def __repr__(self): - return 'Mod({}, {})'.format(self.val, self.mod) + return "Mod({}, {})".format(self.val, self.mod) def __int__(self): return self.val @@ -50,7 +52,7 @@ class Mod: def _check_operand(self, other): if not isinstance(other, (int, Mod)): - raise TypeError('Only integer and Mod operands are supported') + raise TypeError("Only integer and Mod operands are supported") def __pow__(self, other): self._check_operand(other) @@ -61,32 +63,46 @@ class Mod: return Mod(self.mod - self.val, self.mod) def __pos__(self): - return self # The unary plus operator does nothing. + return self # The unary plus operator does nothing. def __abs__(self): - return self # The value is always kept non-negative, so the abs function should do nothing. + return self # The value is always kept non-negative, so the abs function should do nothing. + # Helper functions to build common operands based on a template. # They need to be implemented as functions for the closures to work properly. def _make_op(opname): - op_fun = getattr(operator, opname) # Fetch the operator by name from the operator module + op_fun = getattr( + operator, opname + ) # Fetch the operator by name from the operator module + def op(self, other): self._check_operand(other) return Mod(op_fun(self.val, int(other)) % self.mod, self.mod) + return op + def _make_reflected_op(opname): op_fun = getattr(operator, opname) + def op(self, other): self._check_operand(other) return Mod(op_fun(int(other), self.val) % self.mod, self.mod) + return op + # Build the actual operator overload methods based on the template. -for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]: +for opname, reflected_opname in [ + ("__add__", "__radd__"), + ("__sub__", "__rsub__"), + ("__mul__", "__rmul__"), +]: setattr(Mod, opname, _make_op(opname)) setattr(Mod, reflected_opname, _make_reflected_op(opname)) + class Uint8(Mod): __slots__ = [] @@ -94,7 +110,8 @@ class Uint8(Mod): super().__init__(val, 256) def __repr__(self): - return 'Uint8({})'.format(self.val) + return "Uint8({})".format(self.val) + class Uint16(Mod): __slots__ = [] @@ -103,7 +120,8 @@ class Uint16(Mod): super().__init__(val, 65536) def __repr__(self): - return 'Uint16({})'.format(self.val) + return "Uint16({})".format(self.val) + class Uint32(Mod): __slots__ = [] @@ -112,7 +130,8 @@ class Uint32(Mod): super().__init__(val, 4294967296) def __repr__(self): - return 'Uint32({})'.format(self.val) + return "Uint32({})".format(self.val) + class Uint64(Mod): __slots__ = [] @@ -121,7 +140,7 @@ class Uint64(Mod): super().__init__(val, 18446744073709551616) def __repr__(self): - return 'Uint64({})'.format(self.val) + return "Uint64({})".format(self.val) def simulate_int_type(int_type: str) -> Mod: @@ -131,12 +150,12 @@ def simulate_int_type(int_type: str) -> Mod: :param int_type: uint8_t / uint16_t / uint32_t / uint64_t :returns: `Mod` subclass, e.g. Uint8 """ - if int_type == 'uint8_t': + if int_type == "uint8_t": return Uint8 - if int_type == 'uint16_t': + if int_type == "uint16_t": return Uint16 - if int_type == 'uint32_t': + if int_type == "uint32_t": return Uint32 - if int_type == 'uint64_t': + if int_type == "uint64_t": return Uint64 - raise ValueError('unsupported integer type: {}'.format(int_type)) + raise ValueError("unsupported integer type: {}".format(int_type)) |