1# from titan_pylib.math.mod_matrix import ModMatrix
2from typing import Union, Final
3
# Modulus applied to every ModMatrix entry: the prime 998244353 (= 119 * 2**23 + 1).
_titan_pylib_ModMatrix_MOD: Final[int] = 998244353


class ModMatrix:
    """Dense matrix of integers reduced modulo ``_titan_pylib_ModMatrix_MOD``.

    Entries live row-major in ``self.a``: a list of ``self.n`` rows, each a
    list of ``self.m`` ints in ``[0, MOD)``.

    Operators:
        ``+``, ``-``, ``*`` act element-wise with another matrix of the same
        shape or with an ``int`` scalar; ``@`` is the matrix product; ``**`` is
        a non-negative integer matrix power; unary ``-`` negates every entry.
        All operators return new matrices except ``+=``, ``-=`` and ``*=``,
        which mutate ``self`` in place.
    """

    @staticmethod
    def zeros(n: int, m: int) -> "ModMatrix":
        """Return an ``n x m`` matrix of zeros."""
        return ModMatrix([[0] * m for _ in range(n)], _exter=False)

    @staticmethod
    def ones(n: int, m: int) -> "ModMatrix":
        """Return an ``n x m`` matrix of ones."""
        return ModMatrix([[1] * m for _ in range(n)], _exter=False)

    @staticmethod
    def identity(n: int) -> "ModMatrix":
        """Return the ``n x n`` identity matrix."""
        a = [[0] * n for _ in range(n)]
        for i in range(n):
            a[i][i] = 1
        return ModMatrix(a, _exter=False)

    def __init__(self, a: list[list[int]], _exter: bool = True) -> None:
        """Wrap ``a`` as a ``len(a) x len(a[0])`` matrix.

        ``a`` is stored by reference, not copied. When ``_exter`` is true (the
        default, meant for caller-supplied data), every entry of ``a`` is
        reduced modulo MOD *in place*. Internal callers that build fresh,
        already-reduced rows pass ``_exter=False`` to skip that pass.
        """
        self.n: int = len(a)
        self.m: int = len(a[0]) if self.n > 0 else 0
        if _exter:
            for ai in a:
                for j in range(self.m):
                    ai[j] %= _titan_pylib_ModMatrix_MOD
        self.a: list[list[int]] = a

    def det(self, inplace: bool = False) -> int:
        """Return the determinant modulo MOD.

        Gaussian elimination brings the matrix to upper-triangular form. The
        determinant of a triangular matrix is the product of its diagonal, and
        every row swap flips the sign.

        Args:
            inplace: if true, eliminate directly in ``self.a`` (destroying its
                contents) instead of working on a copy.
        """
        assert self.n == self.m
        mod = _titan_pylib_ModMatrix_MOD
        a = self.a if inplace else [row[:] for row in self.a]
        flip = 0
        res = 1
        for i, ai in enumerate(a):
            if ai[i] == 0:
                # Look below for a row with a non-zero entry in column i.
                for j in range(i + 1, self.n):
                    if a[j][i] != 0:
                        a[i], a[j] = a[j], a[i]
                        ai = a[i]
                        flip ^= 1
                        break
                else:
                    # Column i is zero from row i down: the matrix is singular.
                    return 0
            inv = pow(ai[i], -1, mod)
            for j in range(i + 1, self.n):
                aj = a[j]
                freq = aj[i] * inv % mod
                # Columns <= i of row j are never read again (only the diagonal
                # matters), so starting at i + 1 is sufficient.
                for k in range(i + 1, self.n):
                    aj[k] = (aj[k] - freq * ai[k]) % mod
            res = res * ai[i] % mod
        if flip:
            res = -res % mod
        return res

    def inv(self, inplace: bool = False) -> Union[None, "ModMatrix"]:
        """Return the inverse matrix modulo MOD, or ``None`` if it is singular.

        Uses Gauss-Jordan elimination: the row operations that reduce ``a`` to
        the identity are mirrored on an identity matrix ``r``, which therefore
        becomes ``a^{-1}``.

        Args:
            inplace: if true, eliminate directly in ``self.a`` (destroying its
                contents) instead of working on a copy.
        """
        assert self.n == self.m
        mod = _titan_pylib_ModMatrix_MOD
        n = self.n
        a = self.a if inplace else [row[:] for row in self.a]
        r = [[0] * n for _ in range(n)]
        for i in range(n):
            r[i][i] = 1
        for i in range(n):
            ai = a[i]
            ri = r[i]
            if ai[i] == 0:
                # Swap in a lower row with a non-zero pivot in column i.
                for j in range(i + 1, n):
                    if a[j][i] != 0:
                        a[i], a[j] = a[j], a[i]
                        r[i], r[j] = r[j], r[i]
                        ai = a[i]
                        ri = r[i]
                        break
                else:
                    return None
            # Scale row i so its pivot becomes 1 (Fermat inverse; MOD is prime).
            tmp = pow(ai[i], mod - 2, mod)
            for j in range(n):
                ai[j] = ai[j] * tmp % mod
                ri[j] = ri[j] * tmp % mod
            # Clear column i from every other row.
            for j in range(n):
                if i == j:
                    continue
                aj = a[j]
                aji = aj[i]
                if aji == 0:
                    # Nothing to eliminate; the row would be left unchanged.
                    continue
                rj = r[j]
                for k in range(n):
                    aj[k] = (aj[k] - ai[k] * aji) % mod
                    rj[k] = (rj[k] - ri[k] * aji) % mod
        return ModMatrix(r, _exter=False)

    @classmethod
    def linear_equations(
        cls, A: "ModMatrix", b: "ModMatrix", inplace: bool = False
    ) -> Union[None, "ModMatrix"]:
        """Solve ``A @ x == b`` for ``x`` modulo MOD.

        Args:
            A: square coefficient matrix.
            b: right-hand side with ``A.n`` rows (may have several columns).
            inplace: forwarded to ``A.inv``; if true, ``A`` is overwritten.

        Returns:
            The unique solution ``x = A^{-1} @ b``, or ``None`` when ``A`` is
            singular.
        """
        A_inv = A.inv(inplace=inplace)
        if A_inv is None:
            return None
        return A_inv @ b

    def __add__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise sum with a scalar or a same-shaped matrix."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                for j in range(self.m):
                    val = resi[j] + other
                    # Both operands are in [0, MOD), so one subtraction suffices.
                    resi[j] = (
                        val
                        if val < _titan_pylib_ModMatrix_MOD
                        else val - _titan_pylib_ModMatrix_MOD
                    )
            return ModMatrix(res, _exter=False)
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                oi = other.a[i]
                for j in range(self.m):
                    val = resi[j] + oi[j]
                    resi[j] = (
                        val
                        if val < _titan_pylib_ModMatrix_MOD
                        else val - _titan_pylib_ModMatrix_MOD
                    )
            return ModMatrix(res, _exter=False)
        else:
            raise TypeError

    def __sub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``self - other`` with a scalar or same-shaped matrix."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                for j in range(self.m):
                    val = resi[j] - other
                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
            return ModMatrix(res, _exter=False)
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                oi = other.a[i]
                for j in range(self.m):
                    val = resi[j] - oi[j]
                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
            return ModMatrix(res, _exter=False)
        else:
            raise TypeError

    def __mul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise (Hadamard) product with a scalar or same-shaped matrix.

        Use ``@`` for the matrix product.
        """
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                for j in range(self.m):
                    resi[j] = resi[j] * other % _titan_pylib_ModMatrix_MOD
            return ModMatrix(res, _exter=False)
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                oi = other.a[i]
                for j in range(self.m):
                    resi[j] = resi[j] * oi[j] % _titan_pylib_ModMatrix_MOD
            return ModMatrix(res, _exter=False)
        raise TypeError

    def __matmul__(self, other: "ModMatrix") -> "ModMatrix":
        """Matrix product ``self @ other`` (requires ``self.m == other.n``)."""
        if isinstance(other, ModMatrix):
            assert self.m == other.n
            res = [[0] * other.m for _ in range(self.n)]
            # i-k-j loop order: the inner loop walks rows of `other`
            # contiguously.
            for i in range(self.n):
                si = self.a[i]
                res_i = res[i]
                for k in range(self.m):
                    ok = other.a[k]
                    sik = si[k]
                    for j in range(other.m):
                        res_i[j] = (res_i[j] + ok[j] * sik) % _titan_pylib_ModMatrix_MOD
            return ModMatrix(res, _exter=False)
        raise TypeError

    def __pow__(self, n: int) -> "ModMatrix":
        """Return the ``n``-th matrix power by binary exponentiation.

        Raises:
            ValueError: if ``n`` is negative (previously this silently returned
                the identity matrix).
        """
        assert self.n == self.m
        if n < 0:
            raise ValueError("negative exponent is not supported")
        res = ModMatrix.identity(self.n)
        # `@` never mutates its operands, so no defensive copy of self is needed.
        a = self
        while n > 0:
            if n & 1:
                res = res @ a
            a = a @ a
            n >>= 1
        return res

    # Addition and element-wise multiplication are commutative.
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``other - self`` (e.g. ``1 - M``)."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                for j in range(self.m):
                    val = other - resi[j]
                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
            return ModMatrix(res, _exter=False)
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [a[:] for a in self.a]
            for i in range(self.n):
                resi = res[i]
                oi = other.a[i]
                for j in range(self.m):
                    val = oi[j] - resi[j]
                    resi[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
            return ModMatrix(res, _exter=False)
        else:
            raise TypeError

    def __iadd__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise addition; mutates and returns ``self``."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            for i in range(self.n):
                si = self.a[i]
                for j in range(self.m):
                    val = si[j] + other
                    si[j] = (
                        val
                        if val < _titan_pylib_ModMatrix_MOD
                        else val - _titan_pylib_ModMatrix_MOD
                    )
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for i in range(self.n):
                si = self.a[i]
                oi = other.a[i]
                for j in range(self.m):
                    val = si[j] + oi[j]
                    si[j] = (
                        val
                        if val < _titan_pylib_ModMatrix_MOD
                        else val - _titan_pylib_ModMatrix_MOD
                    )
        else:
            raise TypeError
        return self

    def __isub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise subtraction; mutates and returns ``self``."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            for i in range(self.n):
                si = self.a[i]
                for j in range(self.m):
                    val = si[j] - other
                    si[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for i in range(self.n):
                si = self.a[i]
                oi = other.a[i]
                for j in range(self.m):
                    val = si[j] - oi[j]
                    si[j] = val + _titan_pylib_ModMatrix_MOD if val < 0 else val
        else:
            raise TypeError
        return self

    def __imul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise multiplication; mutates and returns ``self``."""
        if isinstance(other, int):
            other %= _titan_pylib_ModMatrix_MOD
            for i in range(self.n):
                si = self.a[i]
                for j in range(self.m):
                    si[j] = si[j] * other % _titan_pylib_ModMatrix_MOD
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for i in range(self.n):
                si = self.a[i]
                oi = other.a[i]
                for j in range(self.m):
                    si[j] = si[j] * oi[j] % _titan_pylib_ModMatrix_MOD
        else:
            raise TypeError
        return self

    def __imatmul__(self, other: "ModMatrix") -> "ModMatrix":
        """``self @= other``: returns a new product matrix (not in place)."""
        return self.__matmul__(other)

    def __ipow__(self, n: int) -> "ModMatrix":
        """``self **= n``: returns a new power matrix (not in place).

        Delegates to ``__pow__``. The previous ``while n:`` loop never
        terminated for negative ``n``, because ``-1 >> 1 == -1``.
        """
        return self.__pow__(n)

    def __neg__(self) -> "ModMatrix":
        """Unary minus: negate every entry modulo MOD.

        This was previously misnamed ``__ne__``, which broke ``!=`` and left
        ``-M`` unsupported.
        """
        a = [row[:] for row in self.a]
        for i in range(self.n):
            for j in range(self.m):
                a[i][j] = (-a[i][j]) % _titan_pylib_ModMatrix_MOD
        return ModMatrix(a, _exter=False)

    def add(self, n: int, m: int, key: int) -> None:
        """Add ``key`` (mod MOD) to entry ``(n, m)`` in place."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = (self.a[n][m] + key) % _titan_pylib_ModMatrix_MOD

    def get(self, n: int, m: int) -> int:
        """Return entry ``(n, m)``."""
        assert 0 <= n < self.n and 0 <= m < self.m
        return self.a[n][m]

    def get_n(self, n: int) -> list[int]:
        """Return row ``n`` itself (not a copy); mutating it mutates the matrix."""
        assert 0 <= n < self.n
        return self.a[n]

    def set(self, n: int, m: int, key: int) -> None:
        """Set entry ``(n, m)`` to ``key`` reduced modulo MOD."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = key % _titan_pylib_ModMatrix_MOD

    def tolist(self) -> list[list[int]]:
        """Return a copy of the entries as a list of row lists."""
        return [a[:] for a in self.a]

    def show(self) -> None:
        """Print one row per line, followed by a blank line."""
        for a in self.a:
            print(*a)
        print()

    def __iter__(self):
        """Iterate over the rows (the row lists themselves, not copies).

        Each call returns an independent iterator, so nested loops over the
        same matrix and ``zip(M, M)`` behave correctly. The previous design
        kept a single shared cursor on the matrix object.
        """
        return iter(self.a)

    def __str__(self):
        return str(self.a)