[docs]
def ext_gcd(a: int, b: int) -> tuple[int, int, int]:
    """Extended Euclid: return ``(d, x, y)`` with ``a*x + b*y == d`` / O(log(min(|a|, |b|)))

    For non-negative inputs ``d`` is ``gcd(a, b)``; the base case returns ``a``
    unchanged, so with negative inputs ``d`` may be negative
    (e.g. ``ext_gcd(-4, 0) == (-4, 1, 0)``). Whenever a recursion level yields a
    negative ``x``, the pair is shifted once by ``(+b // d, -a // d)``, which keeps
    the identity ``a*x + b*y == d`` intact.

    Args:
        a (int):
        b (int):

    Returns:
        tuple[int, int, int]: (gcd, x, y)
    """
    if not b:
        return a, 1, 0
    g, s, t = ext_gcd(b, a % b)
    # From b*s + (a % b)*t == g and a % b == a - (a // b)*b:
    #   a*t + b*(s - (a // b)*t) == g
    x, y = t, s - (a // b) * t
    if x < 0:
        x, y = x + b // g, y - a // g
    return g, x, y
20
21
[docs]
def linear_indeterminate_equation(a: int, b: int, c: int) -> tuple[int, int, int]:
    """Return one integer solution of ``ax + by = c``.

    Args:
        a (int): coefficient of ``x``.
        b (int): coefficient of ``y``.
        c (int): right-hand side.

    Returns:
        tuple[int, int, int]: ``(g, x, y)`` where ``g`` is the gcd value reported by
        ``ext_gcd(a, b)`` and ``a*x + b*y == c``. Returns ``(None, None, None)``
        when no integer solution exists (``c`` is not a multiple of ``g``).
        For ``a == b == 0`` the result is ``(0, 0, 0)`` if ``c == 0``, otherwise
        ``(None, None, None)``.
    """
    if a == 0 and b == 0:
        # ext_gcd(0, 0) reports g == 0, so ``c % g`` would raise ZeroDivisionError.
        # 0x + 0y = c holds for every (x, y) when c == 0 and for none otherwise.
        return (0, 0, 0) if c == 0 else (None, None, None)
    g, x, y = ext_gcd(a, b)
    if c % g != 0:
        return None, None, None
    # ext_gcd gives a*x + b*y == g; scaling by c // g yields a solution for c.
    scale = c // g
    return g, x * scale, y * scale
29
30
[docs]
def crt(B: list[int], M: list[int]) -> tuple[int, int]:
    """Chinese remainder theorem / O(n log(lcm(M)))

    Merges the simultaneous congruences
    ```
    a == B[0] (mod M[0])
    a == B[1] (mod M[1])
    ...
    ```
    into a single one, ``a == r (mod lcm(M))``, and returns ``(r, lcm(M))``.
    For positive moduli the returned ``r`` lies in ``[0, lcm(M))``.

    Returns:
        tuple[int, int]: ``(r, m)``; ``(0, -1)`` (i.e. ``m == -1``) means the
        system has no solution.
    """
    assert len(B) == len(M)
    rem, modulus = 0, 1
    for i, target in enumerate(B):
        m_i = M[i]
        g, coef, _ = ext_gcd(modulus, m_i)
        gap = target - rem
        # The merged congruence is solvable only if gcd(modulus, m_i) divides the gap.
        if gap % g != 0:
            return (0, -1)
        step = gap // g * coef % (m_i // g)
        rem += modulus * step
        modulus *= m_i // g
    return (rem, modulus)
55
56
57import math
58
59
[docs]
def lcm(a: int, b: int) -> int:
    """Return the least common multiple of ``a`` and ``b``.

    ``lcm(0, 0)`` is 0. Previously it raised ZeroDivisionError because
    ``math.gcd(0, 0) == 0``. For negative inputs the sign follows
    ``a // gcd(a, b) * b``, as before.
    """
    if a == 0 or b == 0:
        return 0
    # Divide before multiplying so the intermediate value stays small.
    return a // math.gcd(a, b) * b
62
63
[docs]
def lcm_mul(A: list[int]) -> int:
    """Return the least common multiple of every element of ``A``.

    Folds ``lcm`` over ``A`` from left to right, starting from 1.
    ``A`` must be non-empty (checked with ``assert``).
    """
    assert len(A) > 0
    running = 1
    for value in A:
        running = lcm(running, value)
    return running
70
71
[docs]
def totient_function(n: int) -> int:
    """Euler's totient: how many of 1..n are coprime to n / O(√N)

    Trial-divides ``n`` by 2, 3, ...; for each prime factor ``p`` found, the
    running answer is multiplied by ``(1 - 1/p)`` using exact integer arithmetic.
    """
    assert n > 0
    result = n
    p = 2
    while p * p <= n:
        if n % p:
            p += 1
            continue
        result -= result // p
        # Strip every copy of p so later candidates see only the cofactor.
        while n % p == 0:
            n //= p
        p += 1
    # Whatever remains above 1 is a single prime factor larger than the sqrt bound.
    if n > 1:
        result -= result // n
    return result
86
87
# Default modulus used by fastpow. It must be an int: it is the right operand of
# ``%`` below (as the string "998244353" it made every fastpow call raise TypeError).
mod = 998244353


def fastpow(a: int, b: int) -> int:
    """Return ``a ** b % mod`` by binary exponentiation, using the module-level ``mod`` / O(log b)

    Args:
        a (int): base (any integer; it is reduced modulo ``mod`` first).
        b (int): exponent; must be non-negative.

    Returns:
        int: ``a ** b % mod``.

    Raises:
        ValueError: if ``b`` is negative. ``b >>= 1`` never reaches 0 for a
            negative ``b``, so the loop would otherwise never terminate.
    """
    if b < 0:
        raise ValueError("fastpow: exponent must be non-negative")
    res = 1
    a %= mod
    while b:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res
99
100
[docs]
def modinv(a, mod):
    """Return the multiplicative inverse of ``a`` modulo ``mod`` (extended Euclid).

    Args:
        a (int): value to invert; may be negative or larger than ``mod``.
        mod (int): modulus, assumed positive; it need not be prime.

    Returns:
        int: ``x`` with ``0 <= x < mod`` and ``a * x % mod == 1 % mod``.

    Raises:
        ValueError: if ``gcd(a, mod) != 1``, i.e. no inverse exists.
    """
    b = mod
    # Invariant: x * a_original ≡ a and u * a_original ≡ b (mod mod).
    # Only the coefficient of the original ``a`` is needed, so the coefficient of
    # ``mod`` is not tracked.
    x, u = 1, 0
    while b:
        k = a // b
        x, u = u, x - k * u
        a, b = b, a % b
    # For positive mod the loop ends with a == gcd(a_original, mod).
    if a != 1:
        raise ValueError("modinv: a and mod are not coprime, no inverse exists")
    return x % mod
113
114
[docs]
def isqrt(n: int) -> int:
    """Return ``floor(sqrt(n))`` for a non-negative integer ``n`` (integer Newton iteration).

    Raises:
        AssertionError: if ``n`` is negative (when assertions are enabled).
    """
    assert n >= 0
    if n == 0:
        return 0
    # Start at 2 ** ceil(bit_length / 2), which is >= sqrt(n), so the iteration
    # decreases monotonically down to floor(sqrt(n)). The former
    # ``1 << (bit_length + 1) >> 1`` evaluated to 2 ** bit_length (about n),
    # which wasted roughly bit_length / 2 extra halving steps.
    x = 1 << ((n.bit_length() + 1) >> 1)
    y = (x + n // x) >> 1
    while y < x:
        x, y = y, (y + n // y) >> 1
    return x
124
125
# Section: lcm_mod -- least common multiple of a list of integers, reduced modulo ``mod``.
"Return LCM % mod"
127from collections import Counter
128from titan_pylib.math.divisors import Osa_k
129
130
[docs]
def lcm_mod(o: Osa_k, A: list, mod: int) -> int:
    """Return ``lcm(A) % mod`` without building the full (possibly huge) LCM.

    Every element is factorised with ``o.p_factorization``. For each prime the
    largest exponent seen across ``A`` is kept, and the answer is the product of
    ``pow(p, e, mod)``, reduced modulo ``mod`` after every step.

    NOTE(review): assumes ``o.p_factorization(a)`` yields the prime factors of
    ``a`` with multiplicity, so ``Counter`` maps prime -> exponent. Confirm
    against ``Osa_k``.
    """
    max_exp = Counter()
    for value in A:
        factor_counts = Counter(o.p_factorization(value))
        for prime, exp in factor_counts.items():
            max_exp[prime] = max(max_exp[prime], exp)
    result = 1
    for prime, exp in max_exp.items():
        result = result * pow(prime, exp, mod) % mod
    # Final reduction also covers an empty product when mod == 1.
    return result % mod
141
142
143# ----------------------- #
144
# NOTE(review): the two strings below are stale. div_mod performs modular
# division, a * b^(-1) mod p (not floor division a // b), and it is not O(1):
# the modular pow() it calls takes time that grows with log(mod).
"Return (a // b) % mod"
"O(1), mod: prime"
147
148
[docs]
def div_mod(a: int, b: int, mod: int) -> int:
    """Return ``a * b^(-1) % mod``: modular division, not floor division.

    The inverse comes from ``pow(b, -1, mod)`` (Python 3.8+). It is exact for any
    modulus coprime with ``b``, prime or not, and costs O(log mod).

    Raises:
        ValueError: if ``b`` has no inverse modulo ``mod`` (e.g. ``b % mod == 0``).
    """
    # The former Fermat form pow(b, mod - 2, mod) silently returned a wrong value
    # (0 for b ≡ 0) instead of signalling that the division is undefined.
    return (a % mod) * pow(b, -1, mod) % mod
152
153
154# ----------------------- #
155
156
[docs]
def large_pow(a, b, c, mod):
    """Return ``a ** (b ** c) % mod`` without expanding ``b ** c``. ``mod`` must be prime.

    When ``a`` is coprime to ``mod``, Fermat's little theorem lets the exponent
    ``b ** c`` be reduced modulo ``mod - 1``. Assumes ``b`` and ``c`` are
    non-negative.
    """
    if a % mod == 0:
        # Fermat does not apply here: a**e % mod is 0 for e >= 1 but 1 for e == 0.
        # b ** c == 0 exactly when b == 0 and c > 0 (since 0 ** 0 == 1).
        return 1 if b == 0 and c > 0 else 0
    return pow(a, pow(b, c, mod - 1), mod)
162
163
164# ----------------------- #
165
166
[docs]
def mat_mul(A: list, B: list, mod: int) -> list:
    """Return the matrix product ``A @ B`` with every entry reduced modulo ``mod``.

    ``A`` is an l x m and ``B`` an m x n row-major list of lists (sizes taken from
    ``A[0]`` and ``B[0]``); the result is a new l x n list of lists.
    """
    rows, inner = len(A), len(A[0])
    cols = len(B[0])
    assert inner == len(B)
    product = []
    for i in range(rows):
        a_row = A[i]
        product.append(
            [sum(a_row[k] * B[k][j] for k in range(inner)) % mod for j in range(cols)]
        )
    return product
175
176
[docs]
def mat_powmod(A: list, n: int, mod: int) -> list:
    """Return ``A ** n`` for a square matrix ``A``, entries reduced modulo ``mod``.

    Uses binary exponentiation built on ``mat_mul``.

    Raises:
        ValueError: if ``n`` is negative (matrix inverses are not computed).
    """
    if n < 0:
        # Previously the loop was skipped and the identity was returned silently.
        raise ValueError("mat_powmod: exponent must be non-negative")
    size = len(A)
    res = [[0] * size for _ in range(size)]
    for i in range(size):
        # Reduce the identity as well, so n == 0 with mod == 1 yields zeros,
        # consistent with every other entry this function returns.
        res[i][i] = 1 % mod
    while n > 0:
        if n & 1:
            res = mat_mul(A, res, mod)
        A = mat_mul(A, A, mod)
        n >>= 1
    return res
187
188
189# ----------------------- #