mod_matrix

ソースコード

from titan_pylib.math.mod_matrix import ModMatrix

View on GitHub

展開済みコード

  1# from titan_pylib.math.mod_matrix import ModMatrix
  2from typing import Union, Final
  3
  4_titan_pylib_ModMatrix_MOD: Final[int] = 998244353
  5
  6
  7class ModMatrix:
  8
  9    @staticmethod
 10    def zeros(n: int, m: int) -> "ModMatrix":
 11        return ModMatrix([[0] * m for _ in range(n)], _exter=False)
 12
 13    @staticmethod
 14    def ones(n: int, m: int) -> "ModMatrix":
 15        return ModMatrix([[1] * m for _ in range(n)], _exter=False)
 16
 17    @staticmethod
 18    def identity(n: int) -> "ModMatrix":
 19        a = [[0] * n for _ in range(n)]
 20        for i in range(n):
 21            a[i][i] = 1
 22        return ModMatrix(a, _exter=False)
 23
 24    def __init__(self, a: list[list[int]], _exter=True) -> None:
 25        self.n: int = len(a)
 26        self.m: int = len(a[0]) if self.n > 0 else 0
 27        if _exter:
 28            for ai in a:
 29                for j in range(self.m):
 30                    ai[j] %= _titan_pylib_ModMatrix_MOD
 31        self.a: list[list[int]] = a
 32
 33    def det(self, inplace=False) -> int:
 34        # 上三角行列の行列式はその対角成分の総積であることを利用
 35        assert self.n == self.m
 36        a = self.a if inplace else [a[:] for a in self.a]
 37        flip = 0
 38        res = 1
 39        for i, ai in enumerate(a):
 40            if ai[i] == 0:
 41                for j in range(i + 1, self.n):
 42                    if a[j][i] != 0:
 43                        a[i], a[j] = a[j], a[i]
 44                        ai = a[i]
 45                        flip ^= 1
 46                        break
 47                else:
 48                    return 0
 49            inv = pow(ai[i], -1, _titan_pylib_ModMatrix_MOD)
 50            for j in range(i + 1, self.n):
 51                aj = a[j]
 52                freq = aj[i] * inv % _titan_pylib_ModMatrix_MOD
 53                for k in range(i + 1, self.n):  # i+1スタートで十分
 54                    aj[k] = (aj[k] - freq * ai[k]) % _titan_pylib_ModMatrix_MOD
 55            res *= ai[i]
 56            res %= _titan_pylib_ModMatrix_MOD
 57        if flip:
 58            res = -res % _titan_pylib_ModMatrix_MOD
 59        return res
 60
 61    def inv(self, inplace=False) -> Union[None, "ModMatrix"]:
 62        # 掃き出し法の利用
 63        assert self.n == self.m
 64        a = self.a if inplace else [a[:] for a in self.a]
 65        r = [[0] * self.n for _ in range(self.n)]
 66        for i in range(self.n):
 67            r[i][i] = 1
 68        for i in range(self.n):
 69            ai = a[i]
 70            ri = r[i]
 71            if ai[i] == 0:
 72                for j in range(i + 1, self.n):
 73                    if a[j][i] != 0:
 74                        a[i], a[j] = a[j], a[i]
 75                        ai = a[i]
 76                        r[i], r[j] = r[j], r[i]
 77                        ri = r[i]
 78                        break
 79                else:
 80                    return None
 81            tmp = pow(ai[i], _titan_pylib_ModMatrix_MOD - 2, _titan_pylib_ModMatrix_MOD)
 82            for j in range(self.n):
 83                ai[j] = ai[j] * tmp % _titan_pylib_ModMatrix_MOD
 84                ri[j] = ri[j] * tmp % _titan_pylib_ModMatrix_MOD
 85            for j in range(self.n):
 86                if i == j:
 87                    continue
 88                aj = a[j]
 89                rj = r[j]
 90                aji = aj[i]
 91                for k in range(self.n):
 92                    aj[k] = (aj[k] - ai[k] * aji) % _titan_pylib_ModMatrix_MOD
 93                    rj[k] = (rj[k] - ri[k] * aji) % _titan_pylib_ModMatrix_MOD
 94        return ModMatrix(r, _exter=False)
 95
 96    @classmethod
 97    def linear_equations(cls, A: "ModMatrix", b: "ModMatrix", inplace=False):
 98        # A_inv = A.inv(inplace=inplace)
 99        # res = A_inv @ b
100        # return res
101        pass
102
103    def __add__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
104        if isinstance(other, int):
105            other %= _titan_pylib_ModMatrix_MOD
106            res = [a[:] for a in self.a]
107            for i in range(self.n):
108                resi = res[i]
109                for j in range(self.m):
110                    val = resi[j] + other
111                    resi[j] = (
112                        val
113                        if val < _titan_pylib_ModMatrix_MOD
114                        else val - _titan_pylib_ModMatrix_MOD
115                    )
116            return ModMatrix(res, _exter=False)
117        elif isinstance(other, ModMatrix):
118            assert self.n == other.n and self.m == other.m
119            res = [a[:] for a in self.a]
120            for i in range(self.n):
121                resi = res[i]
122                oi = other.a[i]
123                for j in range(self.m):
124                    val = resi[j] + oi[j]
125                    resi[j] = (
126                        val
127                        if val < _titan_pylib_ModMatrix_MOD
128                        else val - _titan_pylib_ModMatrix_MOD
129                    )
130            return ModMatrix(res, _exter=False)
131        else:
132            raise TypeError
133
134    def __sub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
135        if isinstance(other, int):
136            other %= _titan_pylib_ModMatrix_MOD
137            res = [a[:] for a in self.a]
138            for i in range(self.n):
139                resi = res[i]
140                for j in range(self.m):
141                    val = resi[j] - other
142                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
143            return ModMatrix(res, _exter=False)
144        elif isinstance(other, ModMatrix):
145            assert self.n == other.n and self.m == other.m
146            res = [a[:] for a in self.a]
147            for i in range(self.n):
148                resi = res[i]
149                oi = other.a[i]
150                for j in range(self.m):
151                    val = resi[j] - oi[j]
152                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
153            return ModMatrix(res, _exter=False)
154        else:
155            raise TypeError
156
157    def __mul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
158        if isinstance(other, int):
159            other %= _titan_pylib_ModMatrix_MOD
160            res = [a[:] for a in self.a]
161            for i in range(self.n):
162                resi = res[i]
163                for j in range(self.m):
164                    resi[j] = resi[j] * other % _titan_pylib_ModMatrix_MOD
165            return ModMatrix(res, _exter=False)
166        if isinstance(other, ModMatrix):
167            assert self.n == other.n and self.m == other.m
168            res = [a[:] for a in self.a]
169            for i in range(self.n):
170                resi = res[i]
171                oi = other.a[i]
172                for j in range(self.m):
173                    resi[j] = resi[j] * oi[j] % _titan_pylib_ModMatrix_MOD
174            return ModMatrix(res, _exter=False)
175        raise TypeError
176
177    def __matmul__(self, other: "ModMatrix") -> "ModMatrix":
178        if isinstance(other, ModMatrix):
179            assert self.m == other.n
180            res = [[0] * other.m for _ in range(self.n)]
181            for i in range(self.n):
182                si = self.a[i]
183                res_i = res[i]
184                for k in range(self.m):
185                    ok = other.a[k]
186                    sik = si[k]
187                    for j in range(other.m):
188                        res_i[j] = (res_i[j] + ok[j] * sik) % _titan_pylib_ModMatrix_MOD
189            return ModMatrix(res, _exter=False)
190        raise TypeError
191
192    def __pow__(self, n: int) -> "ModMatrix":
193        assert self.n == self.m
194        res = ModMatrix.identity(self.n)
195        a = ModMatrix([a[:] for a in self.a], _exter=False)
196        while n > 0:
197            if n & 1 == 1:
198                res @= a
199            a @= a
200            n >>= 1
201        return res
202
203    __radd__ = __add__
204    __rmul__ = __mul__
205
206    def __rsub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
207        if isinstance(other, int):
208            other %= _titan_pylib_ModMatrix_MOD
209            res = [a[:] for a in self.a]
210            for i in range(self.n):
211                resi = res[i]
212                for j in range(self.m):
213                    val = other - resi[j]
214                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
215            return ModMatrix(res, _exter=False)
216        elif isinstance(other, ModMatrix):
217            assert self.n == other.n and self.m == other.m
218            res = [a[:] for a in self.a]
219            for i in range(self.n):
220                resi = res[i]
221                oi = other.a[i]
222                for j in range(self.m):
223                    val = oi[j] - resi[j]
224                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
225            return ModMatrix(res, _exter=False)
226        else:
227            raise TypeError
228
229    def __iadd__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
230        if isinstance(other, int):
231            other %= _titan_pylib_ModMatrix_MOD
232            for i in range(self.n):
233                si = self.a[i]
234                for j in range(self.m):
235                    val = si[j] + other
236                    si[j] = (
237                        val
238                        if val < _titan_pylib_ModMatrix_MOD
239                        else val - _titan_pylib_ModMatrix_MOD
240                    )
241        elif isinstance(other, ModMatrix):
242            assert self.n == other.n and self.m == other.m
243            for i in range(self.n):
244                si = self.a[i]
245                oi = other.a[i]
246                for j in range(self.m):
247                    val = si[j] + oi[j]
248                    si[j] = (
249                        val
250                        if val < _titan_pylib_ModMatrix_MOD
251                        else val - _titan_pylib_ModMatrix_MOD
252                    )
253        else:
254            raise TypeError
255        return self
256
257    def __isub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
258        if isinstance(other, int):
259            other %= _titan_pylib_ModMatrix_MOD
260            for i in range(self.n):
261                si = self.a[i]
262                for j in range(self.m):
263                    val = si[j] - other
264                    si[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
265        elif isinstance(other, ModMatrix):
266            assert self.n == other.n and self.m == other.m
267            for i in range(self.n):
268                si = self.a[i]
269                oi = other.a[i]
270                for j in range(self.m):
271                    val = si[j] - oi[j]
272                    si[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
273        else:
274            raise TypeError
275        return self
276
277    def __imul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
278        if isinstance(other, int):
279            other %= _titan_pylib_ModMatrix_MOD
280            for i in range(self.n):
281                si = self.a[i]
282                for j in range(self.m):
283                    si[j] = si[j] * other % _titan_pylib_ModMatrix_MOD
284        elif isinstance(other, ModMatrix):
285            assert self.n == other.n and self.m == other.m
286            for i in range(self.n):
287                si = self.a[i]
288                oi = other.a[i]
289                for j in range(self.m):
290                    si[j] = si[j] * oi[j] % _titan_pylib_ModMatrix_MOD
291        else:
292            raise TypeError
293        return self
294
295    def __imatmul__(self, other: "ModMatrix") -> "ModMatrix":
296        return self.__matmul__(other)
297
298    def __ipow__(self, n: int) -> "ModMatrix":
299        assert self.n == self.m
300        res = ModMatrix.identity(self.n)
301        while n:
302            if n & 1:
303                res @= self
304            self @= self
305            n >>= 1
306        return res
307
308    def __ne__(self) -> "ModMatrix":
309        a = [a[:] for a in self.a]
310        for i in range(self.n):
311            for j in range(self.m):
312                a[i][j] = (-a[i][j]) % _titan_pylib_ModMatrix_MOD
313        return ModMatrix(a, _exter=False)
314
315    def add(self, n: int, m: int, key: int) -> None:
316        assert 0 <= n < self.n and 0 <= m < self.m
317        self.a[n][m] = (self.a[n][m] + key) % _titan_pylib_ModMatrix_MOD
318
319    def get(self, n: int, m: int) -> int:
320        assert 0 <= n < self.n and 0 <= m < self.m
321        return self.a[n][m]
322
323    def get_n(self, n: int) -> list[int]:
324        assert 0 <= n < self.n
325        return self.a[n]
326
327    def set(self, n: int, m: int, key: int) -> None:
328        assert 0 <= n < self.n and 0 <= m < self.m
329        self.a[n][m] = key % _titan_pylib_ModMatrix_MOD
330
331    def tolist(self) -> list[list[int]]:
332        return [a[:] for a in self.a]
333
334    def show(self) -> None:
335        for a in self.a:
336            print(*a)
337        print()
338
339    def __iter__(self):
340        self.__iter = 0
341        return self
342
343    def __next__(self):
344        if self.__iter == self.n:
345            raise StopIteration
346        self.__iter += 1
347        return self.a[self.__iter - 1]
348
349    def __str__(self):
350        return str(self.a)

仕様

class ModMatrix(a: list[list[int]], _exter=True) [source]

Bases: object

add(n: int, m: int, key: int) → None [source]
det(inplace=False) → int [source]
get(n: int, m: int) → int [source]
get_n(n: int) → list[int] [source]
static identity(n: int) → ModMatrix [source]
inv(inplace=False) → None | ModMatrix [source]
classmethod linear_equations(A: ModMatrix, b: ModMatrix, inplace=False) [source]
static ones(n: int, m: int) → ModMatrix [source]
set(n: int, m: int, key: int) → None [source]
show() → None [source]
tolist() → list[list[int]] [source]
static zeros(n: int, m: int) → ModMatrix [source]