summaryrefslogtreecommitdiff
path: root/lib/modular_arithmetic.py
blob: c5ed1aabb5507aa4d137c1d239acb9d77a53872d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Based on https://rosettacode.org/wiki/Modular_arithmetic#Python
# Licensed under GFDL 1.2 https://www.gnu.org/licenses/old-licenses/fdl-1.2.html
import operator
import functools


@functools.total_ordering
class Mod:
    """A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types.

    Does not support negative values, therefore it cannot be used to emulate signed integers.

    Overloads a==b, a<b, a>b, a+b, a-b, a*b, a**b, and -a

    :param val: stored integer value
    Param mod: modulus
    """

    __slots__ = ["val", "mod"]

    def __init__(self, val, mod):
        if isinstance(val, Mod):
            val = val.val
        if not isinstance(val, int):
            raise ValueError("Value must be integer")
        if not isinstance(mod, int) or mod <= 0:
            raise ValueError("Modulo must be positive integer")
        self.val = val % mod
        self.mod = mod

    def __repr__(self):
        return "Mod({}, {})".format(self.val, self.mod)

    def __int__(self):
        return self.val

    def __eq__(self, other):
        if isinstance(other, Mod):
            self.val == other.val
        elif isinstance(other, int):
            return self.val == other
        else:
            return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Mod):
            return self.val < other.val
        elif isinstance(other, int):
            return self.val < other
        else:
            return NotImplemented

    def _check_operand(self, other):
        if not isinstance(other, (int, Mod)):
            raise TypeError("Only integer and Mod operands are supported")

    def __pow__(self, other):
        self._check_operand(other)
        # We use the built-in modular exponentiation function, this way we can avoid working with huge numbers.
        return __class__(pow(self.val, int(other), self.mod), self.mod)

    def __neg__(self):
        return Mod(self.mod - self.val, self.mod)

    def __pos__(self):
        return self  # The unary plus operator does nothing.

    def __abs__(self):
        return self  # The value is always kept non-negative, so the abs function should do nothing.


# Helper functions to build common operands based on a template.
# They need to be implemented as functions for the closures to work properly.
def _make_op(opname):
    op_fun = getattr(
        operator, opname
    )  # Fetch the operator by name from the operator module

    def op(self, other):
        self._check_operand(other)
        return Mod(op_fun(self.val, int(other)) % self.mod, self.mod)

    return op


def _make_reflected_op(opname):
    op_fun = getattr(operator, opname)

    def op(self, other):
        self._check_operand(other)
        return Mod(op_fun(int(other), self.val) % self.mod, self.mod)

    return op


# Build the actual operator overload methods based on the template.
for opname, reflected_opname in [
    ("__add__", "__radd__"),
    ("__sub__", "__rsub__"),
    ("__mul__", "__rmul__"),
]:
    setattr(Mod, opname, _make_op(opname))
    setattr(Mod, reflected_opname, _make_reflected_op(opname))


class Uint8(Mod):
    __slots__ = []

    def __init__(self, val):
        super().__init__(val, 256)

    def __repr__(self):
        return "Uint8({})".format(self.val)


class Uint16(Mod):
    __slots__ = []

    def __init__(self, val):
        super().__init__(val, 65536)

    def __repr__(self):
        return "Uint16({})".format(self.val)


class Uint32(Mod):
    __slots__ = []

    def __init__(self, val):
        super().__init__(val, 4294967296)

    def __repr__(self):
        return "Uint32({})".format(self.val)


class Uint64(Mod):
    __slots__ = []

    def __init__(self, val):
        super().__init__(val, 18446744073709551616)

    def __repr__(self):
        return "Uint64({})".format(self.val)


def simulate_int_type(int_type: str) -> Mod:
    """
    Return `Mod` subclass for given `int_type`

    :param int_type: uint8_t / uint16_t / uint32_t / uint64_t
    :returns: `Mod` subclass, e.g. Uint8
    """
    if int_type == "uint8_t":
        return Uint8
    if int_type == "uint16_t":
        return Uint16
    if int_type == "uint32_t":
        return Uint32
    if int_type == "uint64_t":
        return Uint64
    raise ValueError("unsupported integer type: {}".format(int_type))