summaryrefslogtreecommitdiff
path: root/lib/modular_arithmetic.py
blob: baf979aa5e18837816685777decf4dfe7b26ee43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Based on https://rosettacode.org/wiki/Modular_arithmetic#Python
# Licensed under GFDL 1.2 https://www.gnu.org/licenses/old-licenses/fdl-1.2.html
import operator
import functools
 
@functools.total_ordering
class Mod:
    """An integer value kept reduced modulo a positive integer modulus.

    ``val`` is always stored as ``val % mod``, so it is non-negative and
    smaller than ``mod``.  Comparisons and ``**`` accept either a plain
    ``int`` or another ``Mod`` as the other operand.  The moduli of two
    ``Mod`` operands are NOT compared against each other: only their values
    take part in the operation.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """
    __slots__ = ['val', 'mod']

    def __init__(self, val, mod):
        """Create the residue ``val % mod``.

        Args:
            val: An ``int``, or another ``Mod`` whose value is re-reduced
                modulo ``mod``.
            mod: The modulus; must be a positive ``int``.

        Raises:
            ValueError: If ``val`` is not an integer or ``mod`` is not a
                positive integer.
        """
        if isinstance(val, Mod):
            val = val.val
        if not isinstance(val, int):
            raise ValueError('Value must be integer')
        if not isinstance(mod, int) or mod <= 0:
            raise ValueError('Modulo must be positive integer')
        self.val = val % mod
        self.mod = mod

    def __repr__(self):
        return 'Mod({}, {})'.format(self.val, self.mod)

    def __int__(self):
        """Return the reduced value; also used by the operators to read operands."""
        return self.val

    def __eq__(self, other):
        """Compare the stored value with a ``Mod``'s value or with an ``int``.

        An ``int`` is compared as-is (it is not reduced first), and the
        moduli of two ``Mod`` operands are not compared.  Other types yield
        ``NotImplemented`` so Python can try the reflected comparison.
        """
        if isinstance(other, Mod):
            # The result must be returned explicitly: falling through yields
            # None, which also corrupts `!=` and the `<=`/`>=` methods that
            # functools.total_ordering derives from this one.
            return self.val == other.val
        elif isinstance(other, int):
            return self.val == other
        else:
            return NotImplemented

    def __lt__(self, other):
        """Order by stored value; same operand rules as ``__eq__``."""
        if isinstance(other, Mod):
            return self.val < other.val
        elif isinstance(other, int):
            return self.val < other
        else:
            return NotImplemented

    def _check_operand(self, other):
        """Raise ``TypeError`` unless *other* is an ``int`` or a ``Mod``."""
        if not isinstance(other, (int, Mod)):
            raise TypeError('Only integer and Mod operands are supported')

    def __pow__(self, other):
        """Return ``self ** other`` reduced modulo ``self.mod`` as a ``Mod``.

        A negative exponent asks ``pow`` for a modular inverse (Python 3.8+).
        ``pow`` raises ``ValueError`` when the base is not invertible.
        """
        self._check_operand(other)
        # We use the built-in modular exponentiation function, this way we can avoid working with huge numbers.
        # `__class__` is the lexically enclosing class (Mod), so subclasses
        # also get a plain Mod back.
        return __class__(pow(self.val, int(other), self.mod), self.mod)

    def __neg__(self):
        # For val == 0 this builds Mod(mod, mod), which __init__ reduces to 0.
        return Mod(self.mod - self.val, self.mod)

    def __pos__(self):
        return self # The unary plus operator does nothing.

    def __abs__(self):
        return self # The value is always kept non-negative, so the abs function should do nothing.
 
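# Helper factories that build Mod's binary operator methods from a template.
# They are separate functions so that each returned closure captures its own
# `op_fun`; closures defined directly inside a loop would all late-bind to the
# operator from the final iteration.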
def _make_op(opname):
    op_fun = getattr(operator, opname)  # Fetch the operator by name from the operator module
    def op(self, other):
        self._check_operand(other)
        return Mod(op_fun(self.val, int(other)) % self.mod, self.mod)
    return op
 
def _make_reflected_op(opname):
    op_fun = getattr(operator, opname)
    def op(self, other):
        self._check_operand(other)
        return Mod(op_fun(int(other), self.val) % self.mod, self.mod)
    return op
 
# Attach +, - and * to Mod, each together with its reflected form
# ('__add__' -> '__radd__', ...) so that int-on-the-left expressions work too.
for opname in ('__add__', '__sub__', '__mul__'):
    reflected_opname = '__r' + opname[2:]
    setattr(Mod, opname, _make_op(opname))
    setattr(Mod, reflected_opname, _make_reflected_op(opname))

class Uint8(Mod):
    """Unsigned 8-bit value: a ``Mod`` whose modulus is fixed at ``2 ** 8``.

    Only construction and ``repr`` are specialised.  Binary arithmetic and
    negation inherited from ``Mod`` return plain ``Mod`` instances.
    """

    __slots__ = []

    def __repr__(self):
        return 'Uint8({})'.format(self.val)

    def __init__(self, val):
        super().__init__(val, 2 ** 8)

class Uint16(Mod):
    """Unsigned 16-bit value: a ``Mod`` whose modulus is fixed at ``2 ** 16``.

    Only construction and ``repr`` are specialised.  Binary arithmetic and
    negation inherited from ``Mod`` return plain ``Mod`` instances.
    """

    __slots__ = []

    def __repr__(self):
        return 'Uint16({})'.format(self.val)

    def __init__(self, val):
        super().__init__(val, 2 ** 16)

class Uint32(Mod):
    """Unsigned 32-bit value: a ``Mod`` whose modulus is fixed at ``2 ** 32``.

    Only construction and ``repr`` are specialised.  Binary arithmetic and
    negation inherited from ``Mod`` return plain ``Mod`` instances.
    """

    __slots__ = []

    def __repr__(self):
        return 'Uint32({})'.format(self.val)

    def __init__(self, val):
        super().__init__(val, 2 ** 32)

class Uint64(Mod):
    """Unsigned 64-bit value: a ``Mod`` whose modulus is fixed at ``2 ** 64``.

    Only construction and ``repr`` are specialised.  Binary arithmetic and
    negation inherited from ``Mod`` return plain ``Mod`` instances.
    """

    __slots__ = []

    def __repr__(self):
        return 'Uint64({})'.format(self.val)

    def __init__(self, val):
        super().__init__(val, 2 ** 64)


def simulate_int_type(int_type: str):
    """Return the class that simulates the unsigned integer type *int_type*.

    Recognised names are ``'uint8_t'``, ``'uint16_t'``, ``'uint32_t'`` and
    ``'uint64_t'``, mapping to ``Uint8``, ``Uint16``, ``Uint32`` and
    ``Uint64`` respectively.

    Raises:
        ValueError: If *int_type* is not one of the recognised names.
    """
    candidates = (
        ('uint8_t', Uint8),
        ('uint16_t', Uint16),
        ('uint32_t', Uint32),
        ('uint64_t', Uint64),
    )
    for c_name, simulated in candidates:
        if int_type == c_name:
            return simulated
    raise ValueError('unsupported integer type: {}'.format(int_type))