from typing import Final, Union

# Prime modulus used for every entry of a ModMatrix.
_titan_pylib_ModMatrix_MOD: Final[int] = 998244353


class ModMatrix:
    """Dense matrix whose entries are integers modulo 998244353.

    Attributes:
        n: number of rows.
        m: number of columns (0 for a matrix with no rows).
        a: row-major storage, a list of ``n`` lists of ``m`` ints, each kept
            in the range ``[0, 998244353)``.

    ``+``, ``-`` and ``*`` are element-wise and also accept a plain ``int``,
    which is applied to every entry.  ``@`` is the matrix product and ``**``
    is the matrix power.  Shape mismatches are reported with ``assert``.
    """

    @staticmethod
    def zeros(n: int, m: int) -> "ModMatrix":
        """Return an ``n`` x ``m`` matrix filled with 0."""
        return ModMatrix([[0] * m for _ in range(n)], _exter=False)

    @staticmethod
    def ones(n: int, m: int) -> "ModMatrix":
        """Return an ``n`` x ``m`` matrix filled with 1."""
        return ModMatrix([[1] * m for _ in range(n)], _exter=False)

    @staticmethod
    def identity(n: int) -> "ModMatrix":
        """Return the ``n`` x ``n`` identity matrix."""
        rows = [[0] * n for _ in range(n)]
        for d in range(n):
            rows[d][d] = 1
        return ModMatrix(rows, _exter=False)

    def __init__(self, a: list[list[int]], _exter=True) -> None:
        """Build a matrix from a list of rows.

        Args:
            a: the rows of the matrix; all rows must have the same length.
            _exter: internal flag.  When true (the default) the entries are
                copied and reduced modulo 998244353, so the caller's lists
                are neither modified nor shared.  When false, ``a`` is
                adopted as-is and must already hold reduced values.

        Raises:
            ValueError: if ``_exter`` is true and the rows differ in length.
        """
        self.n: int = len(a)
        self.m: int = len(a[0]) if self.n > 0 else 0
        if _exter:
            width = self.m
            if any(len(row) != width for row in a):
                raise ValueError("all rows of a matrix must have the same length")
            mod = _titan_pylib_ModMatrix_MOD
            # Copy while reducing so the caller's data is left untouched.
            a = [[x % mod for x in row] for row in a]
        self.a: list[list[int]] = a

    def det(self, inplace=False) -> int:
        """Return the determinant modulo 998244353.

        The matrix is reduced to upper-triangular form by Gaussian
        elimination; the determinant is then the product of the diagonal,
        negated once per row swap.

        Args:
            inplace: if true, the elimination works directly on this
                matrix's storage and destroys its contents.
        """
        assert self.n == self.m
        mod = _titan_pylib_ModMatrix_MOD
        size = self.n
        rows = self.a if inplace else [row[:] for row in self.a]
        result = 1
        for col in range(size):
            pivot_at = next((r for r in range(col, size) if rows[r][col]), None)
            if pivot_at is None:
                # The whole column is zero below the diagonal: singular.
                return 0
            if pivot_at != col:
                rows[col], rows[pivot_at] = rows[pivot_at], rows[col]
                result = -result  # a row swap flips the sign
            pivot_row = rows[col]
            pivot = pivot_row[col]
            result = result * pivot % mod
            pivot_inv = pow(pivot, -1, mod)
            for r in range(col + 1, size):
                row = rows[r]
                factor = row[col] * pivot_inv % mod
                if factor:
                    # Columns <= col are never read again, so start at col + 1.
                    for k in range(col + 1, size):
                        row[k] = (row[k] - factor * pivot_row[k]) % mod
        return result % mod

    def inv(self, inplace=False) -> Union[None, "ModMatrix"]:
        """Return the inverse matrix, or ``None`` if the matrix is singular.

        Uses Gauss-Jordan elimination on the matrix alongside an identity
        matrix.

        Args:
            inplace: if true, the elimination works directly on this
                matrix's storage and destroys its contents.
        """
        assert self.n == self.m
        mod = _titan_pylib_ModMatrix_MOD
        size = self.n
        work = self.a if inplace else [row[:] for row in self.a]
        result = [[0] * size for _ in range(size)]
        for d in range(size):
            result[d][d] = 1
        for col in range(size):
            pivot_at = next((r for r in range(col, size) if work[r][col]), None)
            if pivot_at is None:
                return None
            if pivot_at != col:
                work[col], work[pivot_at] = work[pivot_at], work[col]
                result[col], result[pivot_at] = result[pivot_at], result[col]
            pivot_row = work[col]
            pivot_out = result[col]
            # Scale the pivot row so that its diagonal entry becomes 1.
            scale = pow(pivot_row[col], -1, mod)
            pivot_row[:] = [x * scale % mod for x in pivot_row]
            pivot_out[:] = [x * scale % mod for x in pivot_out]
            # Clear this column in every other row.
            for r in range(size):
                if r == col:
                    continue
                row = work[r]
                factor = row[col]
                if factor == 0:
                    continue
                out_row = result[r]
                row[:] = [(x - factor * y) % mod for x, y in zip(row, pivot_row)]
                out_row[:] = [(x - factor * y) % mod for x, y in zip(out_row, pivot_out)]
        return ModMatrix(result, _exter=False)

    @classmethod
    def linear_equations(cls, A: "ModMatrix", b: "ModMatrix", inplace=False):
        """Solve ``A @ x = b`` for a square ``A`` and return ``x``.

        Args:
            A: square coefficient matrix.
            b: right-hand side with as many rows as ``A``.
            inplace: forwarded to ``A.inv``; if true, ``A`` is destroyed.

        Returns:
            The solution as a ModMatrix, or ``None`` when ``A`` is singular
            (no unique solution).
        """
        # Check shapes first so an in-place inversion is not wasted on bad input.
        assert A.n == A.m and A.m == b.n
        A_inv = A.inv(inplace=inplace)
        if A_inv is None:
            return None
        return A_inv @ b

    def _overwrite(self, rows: list[list[int]]) -> "ModMatrix":
        """Copy ``rows`` into the existing row lists and return ``self``."""
        for dst, src in zip(self.a, rows):
            dst[:] = src
        return self

    def __add__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise sum with a matrix of the same shape, or with an int."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            k = other % mod
            return ModMatrix(
                [[(x + k) % mod for x in row] for row in self.a], _exter=False
            )
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            return ModMatrix(
                [
                    [(x + y) % mod for x, y in zip(row, other_row)]
                    for row, other_row in zip(self.a, other.a)
                ],
                _exter=False,
            )
        raise TypeError

    def __sub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``self - other`` for a same-shape matrix or an int."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            k = other % mod
            return ModMatrix(
                [[(x - k) % mod for x in row] for row in self.a], _exter=False
            )
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            return ModMatrix(
                [
                    [(x - y) % mod for x, y in zip(row, other_row)]
                    for row, other_row in zip(self.a, other.a)
                ],
                _exter=False,
            )
        raise TypeError

    def __mul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise product with a same-shape matrix, or scaling by an int.

        This is not the matrix product; use ``@`` for that.
        """
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            k = other % mod
            return ModMatrix(
                [[x * k % mod for x in row] for row in self.a], _exter=False
            )
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            return ModMatrix(
                [
                    [x * y % mod for x, y in zip(row, other_row)]
                    for row, other_row in zip(self.a, other.a)
                ],
                _exter=False,
            )
        raise TypeError

    def __matmul__(self, other: "ModMatrix") -> "ModMatrix":
        """Return the matrix product ``self @ other``."""
        if not isinstance(other, ModMatrix):
            raise TypeError
        assert self.m == other.n
        mod = _titan_pylib_ModMatrix_MOD
        # Transpose once so each output entry is a dot product of two sequences.
        columns = list(zip(*other.a))
        product = [
            [sum(x * y for x, y in zip(row, col)) % mod for col in columns]
            for row in self.a
        ]
        return ModMatrix(product, _exter=False)

    def __pow__(self, n: int) -> "ModMatrix":
        """Return the ``n``-th matrix power by binary exponentiation.

        A negative ``n`` raises the inverse matrix to the power ``-n``.

        Raises:
            ValueError: if ``n`` is negative and the matrix is singular.
        """
        assert self.n == self.m
        base = self
        if n < 0:
            base = self.inv()
            if base is None:
                raise ValueError(
                    "a singular matrix cannot be raised to a negative power"
                )
            n = -n
        result = ModMatrix.identity(self.n)
        while n:
            if n & 1:
                result = result @ base
            n >>= 1
            if n:
                base = base @ base
        return result

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``other - self`` for a same-shape matrix or an int."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            k = other % mod
            return ModMatrix(
                [[(k - x) % mod for x in row] for row in self.a], _exter=False
            )
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            return ModMatrix(
                [
                    [(y - x) % mod for x, y in zip(row, other_row)]
                    for row, other_row in zip(self.a, other.a)
                ],
                _exter=False,
            )
        raise TypeError

    def __iadd__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``+=``; the existing row lists are reused."""
        return self._overwrite(self.__add__(other).a)

    def __isub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``-=``; the existing row lists are reused."""
        return self._overwrite(self.__sub__(other).a)

    def __imul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``*=``; the existing row lists are reused."""
        return self._overwrite(self.__mul__(other).a)

    def __imatmul__(self, other: "ModMatrix") -> "ModMatrix":
        """``@=``: returns a new matrix, since the shape may change."""
        return self.__matmul__(other)

    def __ipow__(self, n: int) -> "ModMatrix":
        """``**=``: returns a new matrix with the same result as ``self ** n``."""
        return self.__pow__(n)

    def __neg__(self) -> "ModMatrix":
        """Return the element-wise negation modulo 998244353."""
        mod = _titan_pylib_ModMatrix_MOD
        return ModMatrix([[-x % mod for x in row] for row in self.a], _exter=False)

    def add(self, n: int, m: int, key: int) -> None:
        """Add ``key`` to the entry at row ``n``, column ``m``."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = (self.a[n][m] + key) % _titan_pylib_ModMatrix_MOD

    def get(self, n: int, m: int) -> int:
        """Return the entry at row ``n``, column ``m``."""
        assert 0 <= n < self.n and 0 <= m < self.m
        return self.a[n][m]

    def get_n(self, n: int) -> list[int]:
        """Return row ``n`` itself (not a copy; edits affect the matrix)."""
        assert 0 <= n < self.n
        return self.a[n]

    def set(self, n: int, m: int, key: int) -> None:
        """Store ``key`` (reduced modulo 998244353) at row ``n``, column ``m``."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = key % _titan_pylib_ModMatrix_MOD

    def tolist(self) -> list[list[int]]:
        """Return the entries as a new list of new row lists."""
        return [row[:] for row in self.a]

    def show(self) -> None:
        """Print one space-separated row per line, then a blank line."""
        for row in self.a:
            print(*row)
        print()

    def __iter__(self):
        """Iterate over the rows.

        Each call returns an independent iterator, so nested loops over the
        same matrix work.  The cursor used by ``__next__`` is also reset,
        which keeps a direct ``next(matrix)`` call working as before.
        """
        self.__iter = 0
        return iter(self.a)

    def __next__(self):
        """Legacy protocol: return the next row after ``iter(matrix)``."""
        if self.__iter == self.n:
            raise StopIteration
        self.__iter += 1
        return self.a[self.__iter - 1]

    def __str__(self):
        return str(self.a)