1# from titan_pylib.data_structures.set.mex_multiset import MexMultiset
2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
4# SegmentTreeInterface,
5# )
6from abc import ABC, abstractmethod
7from typing import TypeVar, Generic, Union, Iterable, Callable
8
9T = TypeVar("T")
10
11
12class SegmentTreeInterface(ABC, Generic[T]):
13
14 @abstractmethod
15 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
16 raise NotImplementedError
17
18 @abstractmethod
19 def set(self, k: int, v: T) -> None:
20 raise NotImplementedError
21
22 @abstractmethod
23 def get(self, k: int) -> T:
24 raise NotImplementedError
25
26 @abstractmethod
27 def prod(self, l: int, r: int) -> T:
28 raise NotImplementedError
29
30 @abstractmethod
31 def all_prod(self) -> T:
32 raise NotImplementedError
33
34 @abstractmethod
35 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
36 raise NotImplementedError
37
38 @abstractmethod
39 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
40 raise NotImplementedError
41
42 @abstractmethod
43 def tolist(self) -> list[T]:
44 raise NotImplementedError
45
46 @abstractmethod
47 def __getitem__(self, k: int) -> T:
48 raise NotImplementedError
49
50 @abstractmethod
51 def __setitem__(self, k: int, v: T) -> None:
52 raise NotImplementedError
53
54 @abstractmethod
55 def __str__(self):
56 raise NotImplementedError
57
58 @abstractmethod
59 def __repr__(self):
60 raise NotImplementedError
61from typing import Generic, Iterable, TypeVar, Callable, Union
62
63T = TypeVar("T")
64
65
66class SegmentTree(SegmentTreeInterface, Generic[T]):
67 """セグ木です。非再帰です。"""
68
69 def __init__(
70 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
71 ) -> None:
72 """``SegmentTree`` を構築します。
73 :math:`O(n)` です。
74
75 Args:
76 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
77 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
78 op (Callable[[T, T], T]): 2項演算の関数です。
79 e (T): 単位元です。
80 """
81 self._op = op
82 self._e = e
83 if isinstance(n_or_a, int):
84 self._n = n_or_a
85 self._log = (self._n - 1).bit_length()
86 self._size = 1 << self._log
87 self._data = [e] * (self._size << 1)
88 else:
89 n_or_a = list(n_or_a)
90 self._n = len(n_or_a)
91 self._log = (self._n - 1).bit_length()
92 self._size = 1 << self._log
93 _data = [e] * (self._size << 1)
94 _data[self._size : self._size + self._n] = n_or_a
95 for i in range(self._size - 1, 0, -1):
96 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
97 self._data = _data
98
99 def set(self, k: int, v: T) -> None:
100 """一点更新です。
101 :math:`O(\\log{n})` です。
102
103 Args:
104 k (int): 更新するインデックスです。
105 v (T): 更新する値です。
106
107 制約:
108 :math:`-n \\leq n \\leq k < n`
109 """
110 assert (
111 -self._n <= k < self._n
112 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113 if k < 0:
114 k += self._n
115 k += self._size
116 self._data[k] = v
117 for _ in range(self._log):
118 k >>= 1
119 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121 def get(self, k: int) -> T:
122 """一点取得です。
123 :math:`O(1)` です。
124
125 Args:
126 k (int): インデックスです。
127
128 制約:
129 :math:`-n \\leq n \\leq k < n`
130 """
131 assert (
132 -self._n <= k < self._n
133 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134 if k < 0:
135 k += self._n
136 return self._data[k + self._size]
137
138 def prod(self, l: int, r: int) -> T:
139 """区間 ``[l, r)`` の総積を返します。
140 :math:`O(\\log{n})` です。
141
142 Args:
143 l (int): インデックスです。
144 r (int): インデックスです。
145
146 制約:
147 :math:`0 \\leq l \\leq r \\leq n`
148 """
149 assert (
150 0 <= l <= r <= self._n
151 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152 l += self._size
153 r += self._size
154 lres = self._e
155 rres = self._e
156 while l < r:
157 if l & 1:
158 lres = self._op(lres, self._data[l])
159 l += 1
160 if r & 1:
161 rres = self._op(self._data[r ^ 1], rres)
162 l >>= 1
163 r >>= 1
164 return self._op(lres, rres)
165
166 def all_prod(self) -> T:
167 """区間 ``[0, n)`` の総積を返します。
168 :math:`O(1)` です。
169 """
170 return self._data[1]
171
172 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174 assert (
175 0 <= l <= self._n
176 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177 # assert f(self._e), \
178 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179 if l == self._n:
180 return self._n
181 l += self._size
182 s = self._e
183 while True:
184 while l & 1 == 0:
185 l >>= 1
186 if not f(self._op(s, self._data[l])):
187 while l < self._size:
188 l <<= 1
189 if f(self._op(s, self._data[l])):
190 s = self._op(s, self._data[l])
191 l |= 1
192 return l - self._size
193 s = self._op(s, self._data[l])
194 l += 1
195 if l & -l == l:
196 break
197 return self._n
198
199 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201 assert (
202 0 <= r <= self._n
203 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204 # assert f(self._e), \
205 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206 if r == 0:
207 return 0
208 r += self._size
209 s = self._e
210 while True:
211 r -= 1
212 while r > 1 and r & 1:
213 r >>= 1
214 if not f(self._op(self._data[r], s)):
215 while r < self._size:
216 r = r << 1 | 1
217 if f(self._op(self._data[r], s)):
218 s = self._op(self._data[r], s)
219 r ^= 1
220 return r + 1 - self._size
221 s = self._op(self._data[r], s)
222 if r & -r == r:
223 break
224 return 0
225
226 def tolist(self) -> list[T]:
227 """リストにして返します。
228 :math:`O(n)` です。
229 """
230 return [self.get(i) for i in range(self._n)]
231
232 def show(self) -> None:
233 """デバッグ用のメソッドです。"""
234 print(
235 f"<{self.__class__.__name__}> [\n"
236 + "\n".join(
237 [
238 " "
239 + " ".join(
240 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241 )
242 for i in range(self._log + 1)
243 ]
244 )
245 + "\n]"
246 )
247
248 def __getitem__(self, k: int) -> T:
249 assert (
250 -self._n <= k < self._n
251 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252 return self.get(k)
253
254 def __setitem__(self, k: int, v: T):
255 assert (
256 -self._n <= k < self._n
257 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258 self.set(k, v)
259
260 def __len__(self) -> int:
261 return self._n
262
263 def __str__(self) -> str:
264 return str(self.tolist())
265
266 def __repr__(self) -> str:
267 return f"{self.__class__.__name__}({self})"
268from typing import Iterable
269
270
271class MexMultiset:
272 """``MexMultiset`` です。
273
274 各操作は `log` がつきますが、ANDセグ木の ``log`` で割と軽いです。
275 """
276
277 def __init__(self, u: int, a: Iterable[int] = []) -> None:
278 """``[0, u)`` の範囲の mex を計算する ``MexMultiset`` を構築します。
279
280 時間・空間共に :math:`O(u)` です。
281
282 Args:
283 u (int): 値の上限です。
284 """
285 data = [0] * (u + 1)
286 init_data = [1] * (u + 1)
287 for e in a:
288 if e <= u:
289 data[e] += 1
290 init_data[e] = 0
291 self.u: int = u
292 self.data: list[int] = data
293 self.seg: SegmentTree[int] = SegmentTree(init_data, op=lambda s, t: s | t, e=0)
294
295 def add(self, key: int) -> None:
296 """``key`` を追加します。
297
298 :math:`O(\\log{n})` です。
299 """
300 if key > self.u:
301 return
302 if self.data[key] == 0:
303 self.seg[key] = 0
304 self.data[key] += 1
305
306 def remove(self, key: int) -> None:
307 """``key`` を削除します。 ``key`` は存在していなければなりません。
308
309 :math:`O(\\log{n})` です。
310 """
311 if key > self.u:
312 return
313 if self.data[key] == 1:
314 self.seg[key] = 1
315 self.data[key] -= 1
316
317 def mex(self) -> int:
318 """mex を返します。
319
320 :math:`O(\\log{n})` です。
321 """
322 return self.seg.max_right(0, lambda lr: lr == 0)