Source code for titan_pylib.math.mod_matrix

  1from typing import Union, Final
  2
  3_titan_pylib_ModMatrix_MOD: Final[int] = 998244353
  4
  5
class ModMatrix:
    """Dense matrix whose entries are integers modulo ``_titan_pylib_ModMatrix_MOD``.

    Entries are stored row-major in ``self.a``: ``self.n`` rows, each a list of
    ``self.m`` ints, normally kept reduced into ``[0, MOD)``.

    Operator summary:

    * ``A + x``, ``A - x``, ``A * x`` (plus reflected and in-place forms) are
      element-wise, either with an ``int`` scalar or with a ModMatrix of the
      same shape. ``*`` between two matrices is the element-wise (Hadamard)
      product, not the matrix product.
    * ``A @ B`` is the matrix product; ``A ** k`` is the matrix power
      (the identity for ``k <= 0``).
    * ``-A`` negates every entry.

    For operands of any other type, the operators return ``NotImplemented``.
    Python then raises ``TypeError`` unless the other operand handles the
    operation.
    """

    @staticmethod
    def zeros(n: int, m: int) -> "ModMatrix":
        """Return an ``n x m`` matrix filled with 0."""
        return ModMatrix([[0] * m for _ in range(n)], _exter=False)

    @staticmethod
    def ones(n: int, m: int) -> "ModMatrix":
        """Return an ``n x m`` matrix filled with 1."""
        return ModMatrix([[1] * m for _ in range(n)], _exter=False)

    @staticmethod
    def identity(n: int) -> "ModMatrix":
        """Return the ``n x n`` identity matrix."""
        a = [[0] * n for _ in range(n)]
        for i in range(n):
            a[i][i] = 1
        return ModMatrix(a, _exter=False)

    def __init__(self, a: list[list[int]], _exter=True) -> None:
        """Wrap the row-major list ``a``.

        Args:
            a: ``n`` rows of ``m`` ints each. The list is stored by reference,
                not copied. When ``_exter`` is true, its entries are reduced
                modulo MOD in place, so the caller's list is modified.
            _exter: internal flag. Pass ``False`` only when every entry is
                already in ``[0, MOD)``, to skip the reduction pass.
        """
        self.n: int = len(a)
        self.m: int = len(a[0]) if self.n > 0 else 0
        if _exter:
            mod = _titan_pylib_ModMatrix_MOD
            for row in a:
                for j in range(self.m):
                    row[j] %= mod
        self.a: list[list[int]] = a

    def det(self, inplace=False) -> int:
        """Return the determinant modulo MOD.

        Gaussian elimination reduces the matrix to upper-triangular form. The
        determinant of an upper-triangular matrix is the product of its
        diagonal entries; it is negated once per row swap.

        Args:
            inplace: if true, ``self.a`` is used as scratch space and is left
                partially eliminated; otherwise a copy is used.
        """
        assert self.n == self.m
        mod = _titan_pylib_ModMatrix_MOD
        a = self.a if inplace else [row[:] for row in self.a]
        flip = 0
        res = 1
        for i, ai in enumerate(a):
            if ai[i] == 0:
                # Bring up a lower row that has a non-zero entry in column i.
                for j in range(i + 1, self.n):
                    if a[j][i] != 0:
                        a[i], a[j] = a[j], a[i]
                        ai = a[i]
                        flip ^= 1
                        break
                else:
                    # Column i is zero from row i down: the matrix is singular.
                    return 0
            inv = pow(ai[i], -1, mod)
            for j in range(i + 1, self.n):
                aj = a[j]
                freq = aj[i] * inv % mod
                # Starting at column i+1 is enough: later steps never read
                # column i or any column to its left.
                for k in range(i + 1, self.n):
                    aj[k] = (aj[k] - freq * ai[k]) % mod
            res = res * ai[i] % mod
        if flip:
            res = -res % mod
        return res

    def inv(self, inplace=False) -> Union[None, "ModMatrix"]:
        """Return the inverse matrix, or ``None`` if the matrix is singular.

        Uses Gauss-Jordan elimination on ``[A | I]``. Pivot inverses are taken
        as ``x ** (MOD - 2)``, which requires MOD to be prime.

        Args:
            inplace: if true, ``self.a`` is used as scratch space and is left
                reduced (toward the identity); otherwise a copy is used.
        """
        assert self.n == self.m
        n = self.n
        mod = _titan_pylib_ModMatrix_MOD
        a = self.a if inplace else [row[:] for row in self.a]
        r = [[0] * n for _ in range(n)]
        for i in range(n):
            r[i][i] = 1
        for i in range(n):
            ai = a[i]
            ri = r[i]
            if ai[i] == 0:
                # Swap in a lower row with a non-zero pivot (in both halves).
                for j in range(i + 1, n):
                    if a[j][i] != 0:
                        a[i], a[j] = a[j], a[i]
                        r[i], r[j] = r[j], r[i]
                        ai = a[i]
                        ri = r[i]
                        break
                else:
                    return None
            # Scale the pivot row so that the pivot becomes 1.
            tmp = pow(ai[i], mod - 2, mod)
            for j in range(n):
                ai[j] = ai[j] * tmp % mod
                ri[j] = ri[j] * tmp % mod
            # Clear column i from every other row, both above and below.
            for j in range(n):
                if i == j:
                    continue
                aj = a[j]
                rj = r[j]
                aji = aj[i]
                if aji == 0:
                    continue  # Nothing to eliminate in this row.
                for k in range(n):
                    aj[k] = (aj[k] - ai[k] * aji) % mod
                    rj[k] = (rj[k] - ri[k] * aji) % mod
        return ModMatrix(r, _exter=False)

    @classmethod
    def linear_equations(cls, A: "ModMatrix", b: "ModMatrix", inplace=False):
        """Solve ``A @ x == b`` for ``x`` when ``A`` is square and invertible.

        Args:
            A: ``n x n`` coefficient matrix.
            b: ``n x k`` right-hand side, one column per system.
            inplace: forwarded to ``A.inv``; if true, ``A`` is overwritten.

        Returns:
            The ``n x k`` solution matrix, or ``None`` if ``A`` is singular.
            Singular systems are not solved.
        """
        A_inv = A.inv(inplace=inplace)
        if A_inv is None:
            return None
        return A_inv @ b

    def __add__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``self + other`` for an int or a same-shaped ModMatrix."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            res = [row[:] for row in self.a]
            for resi in res:
                for j in range(self.m):
                    val = resi[j] + other
                    resi[j] = val if val < mod else val - mod
            return ModMatrix(res, _exter=False)
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [row[:] for row in self.a]
            for resi, oi in zip(res, other.a):
                for j in range(self.m):
                    val = resi[j] + oi[j]
                    resi[j] = val if val < mod else val - mod
            return ModMatrix(res, _exter=False)
        return NotImplemented

    def __sub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``self - other`` for an int or a same-shaped ModMatrix."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            res = [row[:] for row in self.a]
            for resi in res:
                for j in range(self.m):
                    val = resi[j] - other
                    resi[j] = val + mod if val < 0 else val
            return ModMatrix(res, _exter=False)
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [row[:] for row in self.a]
            for resi, oi in zip(res, other.a):
                for j in range(self.m):
                    val = resi[j] - oi[j]
                    resi[j] = val + mod if val < 0 else val
            return ModMatrix(res, _exter=False)
        return NotImplemented

    def __mul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise product with an int scalar or a same-shaped ModMatrix.

        This is NOT the matrix product; use ``@`` for that.
        """
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            res = [row[:] for row in self.a]
            for resi in res:
                for j in range(self.m):
                    resi[j] = resi[j] * other % mod
            return ModMatrix(res, _exter=False)
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [row[:] for row in self.a]
            for resi, oi in zip(res, other.a):
                for j in range(self.m):
                    resi[j] = resi[j] * oi[j] % mod
            return ModMatrix(res, _exter=False)
        return NotImplemented

    def __matmul__(self, other: "ModMatrix") -> "ModMatrix":
        """Matrix product ``self @ other``; requires ``self.m == other.n``."""
        if not isinstance(other, ModMatrix):
            return NotImplemented
        assert self.m == other.n
        mod = _titan_pylib_ModMatrix_MOD
        res = [[0] * other.m for _ in range(self.n)]
        for si, res_i in zip(self.a, res):
            # i-k-j loop order: walks rows of `other` contiguously.
            for sik, ok in zip(si, other.a):
                if sik == 0:
                    continue  # Contributes nothing to this row.
                for j in range(other.m):
                    res_i[j] = (res_i[j] + ok[j] * sik) % mod
        return ModMatrix(res, _exter=False)

    def __pow__(self, n: int) -> "ModMatrix":
        """Return ``self`` raised to the ``n``-th power by binary exponentiation.

        ``n <= 0`` yields the identity matrix. ``self`` is not modified.
        """
        assert self.n == self.m
        res = ModMatrix.identity(self.n)
        a = self  # `@` always builds a new matrix, so self is never mutated.
        while n > 0:
            if n & 1:
                res = res @ a
            n >>= 1
            if n:
                a = a @ a
        return res

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """Element-wise ``other - self`` (reflected subtraction)."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            res = [row[:] for row in self.a]
            for resi in res:
                for j in range(self.m):
                    val = other - resi[j]
                    resi[j] = val + mod if val < 0 else val
            return ModMatrix(res, _exter=False)
        if isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            res = [row[:] for row in self.a]
            for resi, oi in zip(res, other.a):
                for j in range(self.m):
                    val = oi[j] - resi[j]
                    resi[j] = val + mod if val < 0 else val
            return ModMatrix(res, _exter=False)
        return NotImplemented

    def __iadd__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``self += other``; mutates ``self.a``."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            for si in self.a:
                for j in range(self.m):
                    val = si[j] + other
                    si[j] = val if val < mod else val - mod
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for si, oi in zip(self.a, other.a):
                for j in range(self.m):
                    val = si[j] + oi[j]
                    si[j] = val if val < mod else val - mod
        else:
            return NotImplemented
        return self

    def __isub__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``self -= other``; mutates ``self.a``."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            for si in self.a:
                for j in range(self.m):
                    val = si[j] - other
                    si[j] = val + mod if val < 0 else val
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for si, oi in zip(self.a, other.a):
                for j in range(self.m):
                    val = si[j] - oi[j]
                    si[j] = val + mod if val < 0 else val
        else:
            return NotImplemented
        return self

    def __imul__(self, other: Union[int, "ModMatrix"]) -> "ModMatrix":
        """In-place element-wise ``self *= other``; mutates ``self.a``."""
        mod = _titan_pylib_ModMatrix_MOD
        if isinstance(other, int):
            other %= mod
            for si in self.a:
                for j in range(self.m):
                    si[j] = si[j] * other % mod
        elif isinstance(other, ModMatrix):
            assert self.n == other.n and self.m == other.m
            for si, oi in zip(self.a, other.a):
                for j in range(self.m):
                    si[j] = si[j] * oi[j] % mod
        else:
            return NotImplemented
        return self

    def __imatmul__(self, other: "ModMatrix") -> "ModMatrix":
        """``self @= other`` rebinds to a new product matrix; ``self`` is unchanged."""
        return self.__matmul__(other)

    def __ipow__(self, n: int) -> "ModMatrix":
        """``self **= n`` rebinds to ``self ** n``; ``self`` is unchanged.

        ``n <= 0`` yields the identity matrix.
        """
        return self.__pow__(n)

    def __neg__(self) -> "ModMatrix":
        """Return ``-self`` with every entry negated modulo MOD."""
        mod = _titan_pylib_ModMatrix_MOD
        return ModMatrix([[-x % mod for x in row] for row in self.a], _exter=False)

    def add(self, n: int, m: int, key: int) -> None:
        """Add ``key`` to entry ``(n, m)`` modulo MOD."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = (self.a[n][m] + key) % _titan_pylib_ModMatrix_MOD

    def get(self, n: int, m: int) -> int:
        """Return entry ``(n, m)``."""
        assert 0 <= n < self.n and 0 <= m < self.m
        return self.a[n][m]

    def get_n(self, n: int) -> list[int]:
        """Return row ``n``.

        The returned list is the live row, not a copy; writing to it
        modifies the matrix.
        """
        assert 0 <= n < self.n
        return self.a[n]

    def set(self, n: int, m: int, key: int) -> None:
        """Set entry ``(n, m)`` to ``key`` modulo MOD."""
        assert 0 <= n < self.n and 0 <= m < self.m
        self.a[n][m] = key % _titan_pylib_ModMatrix_MOD

    def tolist(self) -> list[list[int]]:
        """Return a copy of the entries as a list of row lists."""
        return [row[:] for row in self.a]

    def show(self) -> None:
        """Print the matrix one space-separated row per line, then a blank line."""
        for row in self.a:
            print(*row)
        print()

    def __iter__(self):
        """Iterate over the rows; these are the live row lists, not copies.

        Each call returns an independent iterator, so nested loops and
        ``zip(A, A)`` work correctly.
        """
        return iter(self.a)

    def __str__(self):
        return str(self.a)