1# from titan_pylib.data_structures.set.wordsize_tree_multiset import WordsizeTreeMultiset
2# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
3from array import array
4from typing import Iterable, Optional
5
6
7class WordsizeTreeSet:
8 """``[0, u)`` の整数集合を管理する32分木です。
9 空間 :math:`O(u)` であることに注意してください。
10 """
11
12 def __init__(self, u: int, a: Iterable[int] = []) -> None:
13 """:math:`O(u)` です。"""
14 assert u >= 0
15 u += 1 # 念のため
16 self.u = u
17 data = []
18 len_ = 0
19 if a:
20 u >>= 5
21 A = array("I", bytes(4 * (u + 1)))
22 for a_ in a:
23 assert (
24 0 <= a_ < self.u
25 ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
26 if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
27 len_ += 1
28 A[a_ >> 5] |= 1 << (a_ & 31)
29 data.append(A)
30 while u:
31 a = array("I", bytes(4 * ((u >> 5) + 1)))
32 for i in range(u + 1):
33 if A[i]:
34 a[i >> 5] |= 1 << (i & 31)
35 data.append(a)
36 A = a
37 u >>= 5
38 else:
39 while u:
40 u >>= 5
41 data.append(array("I", bytes(4 * (u + 1))))
42 self.data: list[array[int]] = data
43 self.len: int = len_
44 self.len_data: int = len(data)
45
46 def add(self, v: int) -> bool:
47 """整数 ``v`` を個追加します。
48 :math:`O(\\log{u})` です。
49 """
50 assert (
51 0 <= v < self.u
52 ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
53 if self.data[0][v >> 5] >> (v & 31) & 1:
54 return False
55 self.len += 1
56 for a in self.data:
57 a[v >> 5] |= 1 << (v & 31)
58 v >>= 5
59 return True
60
61 def discard(self, v: int) -> bool:
62 """整数 ``v`` を削除します。
63 :math:`O(\\log{u})` です。
64 """
65 assert (
66 0 <= v < self.u
67 ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
68 if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
69 return False
70 self.len -= 1
71 for a in self.data:
72 a[v >> 5] &= ~(1 << (v & 31))
73 v >>= 5
74 if a[v]:
75 break
76 return True
77
78 def remove(self, v: int) -> None:
79 """整数 ``v`` を削除します。
80 :math:`O(\\log{u})` です。
81
82 Note: ``v`` が存在しないとき、例外を投げます。
83 """
84 assert (
85 0 <= v < self.u
86 ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
87 assert self.discard(v), f"ValueError: {v} not in self."
88
89 def ge(self, v: int) -> Optional[int]:
90 """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
91 :math:`O(\\log{u})` です。
92 """
93 assert (
94 0 <= v < self.u
95 ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
96 data = self.data
97 d = 0
98 while True:
99 if d >= self.len_data or v >> 5 >= len(data[d]):
100 return None
101 m = data[d][v >> 5] & ((~0) << (v & 31))
102 if m == 0:
103 d += 1
104 v = (v >> 5) + 1
105 else:
106 v = (v >> 5 << 5) + (m & -m).bit_length() - 1
107 if d == 0:
108 break
109 v <<= 5
110 d -= 1
111 return v
112
113 def gt(self, v: int) -> Optional[int]:
114 """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
115 :math:`O(\\log{u})` です。
116 """
117 assert (
118 0 <= v < self.u
119 ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
120 if v + 1 == self.u:
121 return
122 return self.ge(v + 1)
123
124 def le(self, v: int) -> Optional[int]:
125 """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
126 :math:`O(\\log{u})` です。
127 """
128 assert (
129 0 <= v < self.u
130 ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
131 data = self.data
132 d = 0
133 while True:
134 if v < 0 or d >= self.len_data:
135 return None
136 m = data[d][v >> 5] & ~((~1) << (v & 31))
137 if m == 0:
138 d += 1
139 v = (v >> 5) - 1
140 else:
141 v = (v >> 5 << 5) + m.bit_length() - 1
142 if d == 0:
143 break
144 v <<= 5
145 v += 31
146 d -= 1
147 return v
148
149 def lt(self, v: int) -> Optional[int]:
150 """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
151 :math:`O(\\log{u})` です。
152 """
153 assert (
154 0 <= v < self.u
155 ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
156 if v - 1 == 0:
157 return
158 return self.le(v - 1)
159
160 def get_min(self) -> Optional[int]:
161 """`最小値を返します。存在しないとき、 ``None``を返します。
162 :math:`O(\\log{u})` です。
163 """
164 return self.ge(0)
165
166 def get_max(self) -> Optional[int]:
167 """最大値を返します。存在しないとき、 ``None``を返します。
168 :math:`O(\\log{u})` です。
169 """
170 return self.le(self.u - 1)
171
172 def pop_min(self) -> int:
173 """最小値を削除して返します。
174 :math:`O(\\log{u})` です。
175 """
176 v = self.get_min()
177 assert (
178 v is not None
179 ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
180 self.discard(v)
181 return v
182
183 def pop_max(self) -> int:
184 """最大値を削除して返します。
185 :math:`O(\\log{u})` です。
186 """
187 v = self.get_max()
188 assert (
189 v is not None
190 ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
191 self.discard(v)
192 return v
193
194 def clear(self) -> None:
195 """集合を空にします。
196 :math:`O(n\\log{u})` です。
197 """
198 for e in self:
199 self.discard(e)
200 self.len = 0
201
202 def tolist(self) -> list[int]:
203 """リストにして返します。
204 :math:`O(n\\log{u})` です。
205 """
206 return [x for x in self]
207
208 def __bool__(self):
209 return self.len > 0
210
211 def __len__(self):
212 return self.len
213
214 def __contains__(self, v: int):
215 assert (
216 0 <= v < self.u
217 ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
218 return self.data[0][v >> 5] >> (v & 31) & 1 == 1
219
220 def __iter__(self):
221 self._val = self.ge(0)
222 return self
223
224 def __next__(self):
225 if self._val is None:
226 raise StopIteration
227 pre = self._val
228 self._val = self.gt(pre)
229 return pre
230
231 def __str__(self):
232 return "{" + ", ".join(map(str, self)) + "}"
233
234 def __repr__(self):
235 return f"{self.__class__.__name__}({self.u}, {self})"
236from typing import Iterable, Optional, Iterator
237
238
239class WordsizeTreeMultiset:
240 """``[0, u)`` の整数多重集合を管理する32分木です。
241 空間 :math:`O(u)` であることに注意してください。
242 """
243
244 def __init__(self, u: int, a: Iterable[int] = []) -> None:
245 """:math:`O(u)` です。"""
246 u += 1 # 念のため
247 assert u >= 0
248 self.u = u
249 self.len: int = 0
250 self.st: WordsizeTreeSet = WordsizeTreeSet(u, a)
251 cnt = [0] * (u + 1)
252 for a_ in a:
253 self.len += 1
254 cnt[a_] += 1
255 self.cnt: list[int] = cnt
256
257 def add(self, v: int, cnt: int = 1) -> None:
258 """整数 ``v`` を ``cnt`` 個追加します。
259 :math:`O(\\log{u})` です。
260 """
261 assert (
262 0 <= v < self.u
263 ), f"ValueError: {self.__class__.__name__}.add({v}, {cnt}), u={self.u}"
264 self.len += cnt
265 if self.cnt[v]:
266 self.cnt[v] += cnt
267 else:
268 self.cnt[v] = cnt
269 self.st.add(v)
270
271 def discard(self, v: int, cnt: int = 1) -> bool:
272 """整数 ``v`` を ``cnt`` 個削除します。
273 :math:`O(\\log{u})` です。
274 """
275 assert (
276 0 <= v < self.u
277 ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
278 if self.cnt[v] == 0:
279 return False
280 c = self.cnt[v]
281 if c > cnt:
282 self.cnt[v] -= cnt
283 self.len -= cnt
284 else:
285 self.len -= c
286 self.cnt[v] = 0
287 self.st.discard(v)
288 return True
289
290 def remove(self, v: int) -> None:
291 """整数 ``v`` を削除します。
292 :math:`O(\\log{u})` です。
293
294 Note: ``v`` が存在しないとき、例外を投げます。
295 """
296 assert (
297 0 <= v < self.u
298 ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
299 assert self.discard(v), f"ValueError: {v} not in self."
300
301 def count(self, v: int) -> int:
302 """整数 ``v`` の個数を返します。
303 :math:`O(1)` です。
304 """
305 assert (
306 0 <= v < self.u
307 ), f"ValueError: {self.__class__.__name__}.count({v}), u={self.u}"
308 return self.cnt[v]
309
310 def ge(self, v: int) -> Optional[int]:
311 """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
312 :math:`O(\\log{u})` です。
313 """
314 assert (
315 0 <= v < self.u
316 ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
317 return self.st.ge(v)
318
319 def gt(self, v: int) -> Optional[int]:
320 """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
321 :math:`O(\\log{u})` です。
322 """
323 assert (
324 0 <= v < self.u
325 ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
326 return self.ge(v + 1)
327
328 def le(self, v: int) -> Optional[int]:
329 """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
330 :math:`O(\\log{u})` です。
331 """
332 assert (
333 0 <= v < self.u
334 ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
335 return self.st.le(v)
336
337 def lt(self, v: int) -> Optional[int]:
338 """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
339 :math:`O(\\log{u})` です。
340 """
341 assert (
342 0 <= v < self.u
343 ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
344 return self.le(v - 1)
345
346 def get_min(self) -> Optional[int]:
347 """`最小値を返します。存在しないとき、 ``None``を返します。
348 :math:`O(\\log{u})` です。
349 """
350 return self.st.ge(0)
351
352 def get_max(self) -> Optional[int]:
353 """最大値を返します。存在しないとき、 ``None``を返します。
354 :math:`O(\\log{u})` です。
355 """
356 return self.st.le(self.st.u - 1)
357
358 def pop_min(self) -> int:
359 """最小値を削除して返します。
360 :math:`O(\\log{u})` です。
361 """
362 assert self, f"IndexError: pop_min() from empty {self.__class__.__name__}."
363 x = self.st.get_min()
364 self.discard(x)
365 return x
366
367 def pop_max(self) -> int:
368 """最大値を削除して返します。
369 :math:`O(\\log{u})` です。
370 """
371 assert self, f"IndexError: pop_max() from empty {self.__class__.__name__}."
372 x = self.st.get_max()
373 self.discard(x)
374 return x
375
376 def keys(self) -> Iterator[int]:
377 """集合に含まれている要素(重複無し)を昇順にイテレートします。
378 :math:`O(n\\log{u})` です。
379 """
380 v = self.st.get_min()
381 while v is not None:
382 yield v
383 v = self.st.gt(v)
384
385 def values(self) -> Iterator[int]:
386 """集合に含まれている要素の個数を、要素の昇順にイテレートします。
387 :math:`O(n\\log{u})` です。
388 """
389 v = self.st.get_min()
390 while v is not None:
391 yield self.cnt[v]
392 v = self.st.gt(v)
393
394 def items(self) -> Iterator[tuple[int, int]]:
395 """集合に含まれている要素とその個数を、要素の昇順にイテレートします。
396 :math:`O(n\\log{u})` です。
397 """
398 v = self.st.get_min()
399 while v is not None:
400 yield (v, self.cnt[v])
401 v = self.st.gt(v)
402
403 def clear(self) -> None:
404 """集合を空にします。
405 :math:`O(n\\log{u})` です。
406 """
407 for e in self:
408 self.cnt[e] = 0
409 self.st.discard(e)
410 self.len = 0
411
412 def tolist(self) -> list[int]:
413 """リストにして返します。
414 :math:`O(n\\log{u})` です。
415 """
416 return [x for x in self]
417
418 def __contains__(self, v: int):
419 """:math:`O(1)` です。"""
420 return self.cnt[v] > 0
421
422 def __bool__(self):
423 return self.len > 0
424
425 def __len__(self):
426 return self.len
427
428 def __iter__(self):
429 self.__val = self.st.get_min()
430 self.__valcnt = 1
431 return self
432
433 def __next__(self):
434 if self.__val is None:
435 raise StopIteration
436 pre = self.__val
437 self.__valcnt += 1
438 if self.__valcnt > self.cnt[self.__val]:
439 self.__valcnt = 1
440 self.__val = self.gt(self.__val)
441 return pre
442
443 def __str__(self):
444 return "{" + ", ".join(map(str, self)) + "}"
445
446 def __repr__(self):
447 return (
448 f"{self.__class__.__name__}({self.u}, [" + ", ".join(map(str, self)) + "])"
449 )