summaryrefslogtreecommitdiff
path: root/lib/modular_arithmetic.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/modular_arithmetic.py')
-rw-r--r--lib/modular_arithmetic.py57
1 files changed, 38 insertions, 19 deletions
diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py
index 0a69b79..c5ed1aa 100644
--- a/lib/modular_arithmetic.py
+++ b/lib/modular_arithmetic.py
@@ -3,6 +3,7 @@
import operator
import functools
+
@functools.total_ordering
class Mod:
"""A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types.
@@ -14,20 +15,21 @@ class Mod:
:param val: stored integer value
Param mod: modulus
"""
- __slots__ = ['val','mod']
+
+ __slots__ = ["val", "mod"]
def __init__(self, val, mod):
if isinstance(val, Mod):
val = val.val
if not isinstance(val, int):
- raise ValueError('Value must be integer')
- if not isinstance(mod, int) or mod<=0:
- raise ValueError('Modulo must be positive integer')
+ raise ValueError("Value must be integer")
+ if not isinstance(mod, int) or mod <= 0:
+ raise ValueError("Modulo must be positive integer")
self.val = val % mod
self.mod = mod
def __repr__(self):
- return 'Mod({}, {})'.format(self.val, self.mod)
+ return "Mod({}, {})".format(self.val, self.mod)
def __int__(self):
return self.val
@@ -50,7 +52,7 @@ class Mod:
def _check_operand(self, other):
if not isinstance(other, (int, Mod)):
- raise TypeError('Only integer and Mod operands are supported')
+ raise TypeError("Only integer and Mod operands are supported")
def __pow__(self, other):
self._check_operand(other)
@@ -61,32 +63,46 @@ class Mod:
return Mod(self.mod - self.val, self.mod)
def __pos__(self):
- return self # The unary plus operator does nothing.
+ return self # The unary plus operator does nothing.
def __abs__(self):
- return self # The value is always kept non-negative, so the abs function should do nothing.
+ return self # The value is always kept non-negative, so the abs function should do nothing.
+
# Helper functions to build common operands based on a template.
# They need to be implemented as functions for the closures to work properly.
def _make_op(opname):
- op_fun = getattr(operator, opname) # Fetch the operator by name from the operator module
+ op_fun = getattr(
+ operator, opname
+ ) # Fetch the operator by name from the operator module
+
def op(self, other):
self._check_operand(other)
return Mod(op_fun(self.val, int(other)) % self.mod, self.mod)
+
return op
+
def _make_reflected_op(opname):
op_fun = getattr(operator, opname)
+
def op(self, other):
self._check_operand(other)
return Mod(op_fun(int(other), self.val) % self.mod, self.mod)
+
return op
+
# Build the actual operator overload methods based on the template.
-for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]:
+for opname, reflected_opname in [
+ ("__add__", "__radd__"),
+ ("__sub__", "__rsub__"),
+ ("__mul__", "__rmul__"),
+]:
setattr(Mod, opname, _make_op(opname))
setattr(Mod, reflected_opname, _make_reflected_op(opname))
+
class Uint8(Mod):
__slots__ = []
@@ -94,7 +110,8 @@ class Uint8(Mod):
super().__init__(val, 256)
def __repr__(self):
- return 'Uint8({})'.format(self.val)
+ return "Uint8({})".format(self.val)
+
class Uint16(Mod):
__slots__ = []
@@ -103,7 +120,8 @@ class Uint16(Mod):
super().__init__(val, 65536)
def __repr__(self):
- return 'Uint16({})'.format(self.val)
+ return "Uint16({})".format(self.val)
+
class Uint32(Mod):
__slots__ = []
@@ -112,7 +130,8 @@ class Uint32(Mod):
super().__init__(val, 4294967296)
def __repr__(self):
- return 'Uint32({})'.format(self.val)
+ return "Uint32({})".format(self.val)
+
class Uint64(Mod):
__slots__ = []
@@ -121,7 +140,7 @@ class Uint64(Mod):
super().__init__(val, 18446744073709551616)
def __repr__(self):
- return 'Uint64({})'.format(self.val)
+ return "Uint64({})".format(self.val)
def simulate_int_type(int_type: str) -> Mod:
@@ -131,12 +150,12 @@ def simulate_int_type(int_type: str) -> Mod:
:param int_type: uint8_t / uint16_t / uint32_t / uint64_t
:returns: `Mod` subclass, e.g. Uint8
"""
- if int_type == 'uint8_t':
+ if int_type == "uint8_t":
return Uint8
- if int_type == 'uint16_t':
+ if int_type == "uint16_t":
return Uint16
- if int_type == 'uint32_t':
+ if int_type == "uint32_t":
return Uint32
- if int_type == 'uint64_t':
+ if int_type == "uint64_t":
return Uint64
- raise ValueError('unsupported integer type: {}'.format(int_type))
+ raise ValueError("unsupported integer type: {}".format(int_type))