from array import array
from typing import Final, Generic, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")

# Weight-balance parameters. The weight of a subtree is its size + 1.
# A node is rebalanced when DELTA times the weight of one side is still
# smaller than the weight of the other side; GAMMA then chooses between a
# single rotation and a double rotation.
DELTA: Final[int] = 3
GAMMA: Final[int] = 2


class WBTSet(Generic[T]):
    """Ordered set of distinct keys stored in a weight-balanced binary search tree.

    Nodes live in parallel arrays (``key``, ``size``, ``left``, ``right``)
    indexed by node id. Id 0 is a sentinel meaning "no node"; its size is 0, so
    ``size[left[x]]`` is valid even when ``x`` has no left child. Keys must be
    mutually comparable with ``<``, ``<=`` and ``==``.

    Unlinked nodes are not recycled: their slots stay allocated until
    :meth:`clear` is called.
    """

    class _IndexOutOfRange(IndexError, AssertionError):
        """Out-of-range position passed to ``__getitem__`` or ``pop``.

        Primarily an ``IndexError``. It also derives from ``AssertionError``
        because earlier versions reported this condition with a bare
        ``assert``, so existing ``except AssertionError`` handlers keep working.
        """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set holding the distinct elements of ``a``."""
        self._reset()
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _reset(self) -> None:
        """Drop all nodes and storage, leaving only the sentinel node 0."""
        self.root = 0
        self.key = [0]
        self.size = array("I", bytes(4))  # subtree sizes; size[0] == 0
        self.left = array("I", bytes(4))  # left child ids (0 = none)
        self.right = array("I", bytes(4))  # right child ids (0 = none)
        self.end = 1  # next unused node id

    def reserve(self, n: int) -> None:
        """Append ``n`` pre-allocated node slots (no-op when ``n <= 0``)."""
        if n <= 0:
            return
        self.key += [0] * n
        zeros = array("I", bytes(4 * n))
        self.left += zeros
        self.right += zeros
        # Fresh slots already look like leaves (size 1, no children);
        # _build relies on the size being 1.
        self.size += array("I", [1] * n)

    def _build(self, a: list[T]) -> None:
        """Bulk-load ``a`` into a perfectly balanced tree (used by ``__init__``)."""
        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            # Not strictly increasing: sort and drop duplicates.
            b = sorted(a)
            a = [b[0]]
            for x in b[1:]:
                if x != a[-1]:
                    a.append(x)
            n = len(a)
        # reserve() extends these arrays in place, so the local aliases stay valid.
        left, right, size = self.left, self.right, self.size

        def build(lo: int, hi: int) -> int:
            # Node ids lo..hi-1 hold keys in sorted order; the middle id is the root.
            mid = (lo + hi) >> 1
            if lo != mid:
                left[mid] = build(lo, mid)
                size[mid] += size[left[mid]]
            if mid + 1 != hi:
                right[mid] = build(mid + 1, hi)
                size[mid] += size[right[mid]]
            return mid

        start = self.end
        self.end += n
        self.reserve(n)
        self.key[start : start + n] = a
        self.root = build(start, start + n)

    def _rotate_left(self, node: int) -> int:
        """Rotate ``node`` left; return the new subtree root (its old right child)."""
        left, right, size = self.left, self.right, self.size
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        return u

    def _rotate_right(self, node: int) -> int:
        """Rotate ``node`` right; return the new subtree root (its old left child)."""
        left, right, size = self.left, self.right, self.size
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        return u

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` and return its node id."""
        node = self.end
        if node >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
        else:
            # Reuse a slot pre-allocated by reserve(); reset it defensively.
            self.key[node] = key
            self.size[node] = 1
            self.left[node] = 0
            self.right[node] = 0
        self.end += 1
        return node

    def _weight_left(self, node: int) -> int:
        """Weight (size + 1) of ``node``'s left subtree."""
        return self.size[self.left[node]] + 1

    def _weight_right(self, node: int) -> int:
        """Weight (size + 1) of ``node``'s right subtree."""
        return self.size[self.right[node]] + 1

    def _balance(self, node: int) -> int:
        """Restore the weight-balance bound at ``node``; return the subtree root."""
        left, right = self.left, self.right
        wl = self._weight_left(node)
        wr = self._weight_right(node)
        if wl * DELTA < wr:
            r = right[node]
            # Inner grandchild too heavy: double rotation instead of single.
            if self._weight_left(r) >= self._weight_right(r) * GAMMA:
                right[node] = self._rotate_right(r)
            return self._rotate_left(node)
        if wr * DELTA < wl:
            l = left[node]
            if self._weight_right(l) >= self._weight_left(l) * GAMMA:
                left[node] = self._rotate_left(l)
            return self._rotate_right(node)
        return node

    def _rebalance_path(self, path: list[int], di: int, diff: int) -> None:
        """Walk back up ``path`` adjusting sizes by ``diff`` and rebalancing.

        ``di`` encodes the descent: bit 0 describes the step taken from
        ``path[-1]``, bit 1 the step from ``path[-2]``, and so on (1 = went
        left). Each rebalanced subtree is re-linked into its parent, or becomes
        the new root. ``path`` is consumed.
        """
        left, right, size = self.left, self.right, self.size
        while path:
            node = path.pop()
            size[node] += diff
            di >>= 1  # now bit 0 is the step from the new path[-1] into node
            node = self._balance(node)
            if not path:
                self.root = node
            elif di & 1:
                left[path[-1]] = node
            else:
                right[path[-1]] = node

    def _erase(self, node: int, path: list[int], di: int) -> None:
        """Unlink ``node`` (reached via ``path``/``di``) and rebalance ancestors.

        A node with two children is not unlinked itself: it takes the key of
        its in-order predecessor (the maximum of its left subtree), and that
        predecessor, which has no right child, is unlinked instead.
        """
        left, right, keys = self.left, self.right, self.key
        if left[node] and right[node]:
            path.append(node)
            di = (di << 1) | 1
            pred = left[node]
            while right[pred]:
                path.append(pred)
                di <<= 1
                pred = right[pred]
            keys[node] = keys[pred]
            node = pred
        child = left[node] or right[node]  # at most one child is non-zero
        if not path:
            self.root = child
            return
        if di & 1:
            left[path[-1]] = child
        else:
            right[path[-1]] = child
        self._rebalance_path(path, di, -1)

    def ave_height(self):
        """Return the mean node depth (root depth is 1); 0 for an empty set."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        total = 0
        stack = [(self.root, 1)]
        while stack:
            node, dep = stack.pop()
            total += dep
            if left[node]:
                stack.append((left[node], dep + 1))
            if right[node]:
                stack.append((right[node], dep + 1))
        return total / len(self)

    def debug(self, root):
        """Print the subtree rooted at node id ``root`` in pre-order, indented."""
        left, right, keys = self.left, self.right, self.key

        def dfs(node, indent):
            if not node:
                return
            s = " " * indent
            print(f"{s}key={keys[node]}, idx={node}")
            if left[node]:
                print(f"{s}left: {keys[left[node]]}, idx={left[node]}")
                dfs(left[node], indent + 2)
            if right[node]:
                print(f"{s}righ: {keys[right[node]]}, idx={right[node]}")
                dfs(right[node], indent + 2)

        dfs(root, 0)

    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was added, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, keys = self.left, self.right, self.key
        node = self.root
        path: list[int] = []
        di = 0  # descent bits, see _rebalance_path
        while node:
            if key == keys[node]:
                return False
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        self._rebalance_path(path, di, 1)
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was removed."""
        left, right, keys = self.left, self.right, self.key
        path: list[int] = []
        di = 0  # descent bits, see _rebalance_path
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._erase(node, path, di)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None for an empty set."""
        if not self:
            return None
        return self[len(self) - 1]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None for an empty set."""
        if not self:
            return None
        return self[0]

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range (also an ``AssertionError``,
                for compatibility with earlier versions).
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise self._IndexOutOfRange(f"pop index out of range: k={k}, len={n}")
        path: list[int] = []
        di = 0  # descent bits, see _rebalance_path
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                break
            path.append(node)
            di <<= 1
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        res = keys[node]  # read before _erase may overwrite this slot's key
        self._erase(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove all elements and release node storage."""
        self._reset()

    def _inorder(self, reverse: bool = False) -> Iterator[T]:
        """Yield keys in ascending order (descending if ``reverse``), iteratively."""
        near, far = (self.right, self.left) if reverse else (self.left, self.right)
        keys = self.key
        stack: list[int] = []
        node = self.root
        while stack or node:
            if node:
                stack.append(node)
                node = near[node]
            else:
                node = stack.pop()
                yield keys[node]
                node = far[node]

    def tolist(self) -> list[T]:
        """Return the elements as a sorted list."""
        return list(self._inorder())

    def check(self) -> int:
        """Debug helper: assert the tree invariants and return the tree height.

        For every node it checks the weight-balance bound in both directions,
        strict key ordering against each child, and that the stored subtree
        size matches the actual node count. Returns 0 for an empty set.
        """
        if self.root == 0:
            return 0

        def _balance_check(node: int) -> None:
            if node == 0:
                return
            if not self._weight_left(node) * DELTA >= self._weight_right(node):
                print(self._weight_left(node), self._weight_right(node), flush=True)
                assert False, f"self._weight_left() * DELTA >= self._weight_right()"
            if not self._weight_right(node) * DELTA >= self._weight_left(node):
                print(self._weight_left(node), self._weight_right(node), flush=True)
                assert False, f"self._weight_right() * DELTA >= self._weight_left()"

        keys = self.key

        # Returns (subtree size, subtree height).
        def dfs(node) -> tuple[int, int]:
            _balance_check(node)
            h = 0
            s = 1
            if self.left[node]:
                assert keys[self.left[node]] < keys[node]
                ls, lh = dfs(self.left[node])
                s += ls
                h = max(h, lh)
            if self.right[node]:
                assert keys[node] < keys[self.right[node]]
                rs, rh = dfs(self.right[node])
                s += rs
                h = max(h, rh)
            assert self.size[node] == s
            return s, h + 1

        _, h = dfs(self.root)
        return h

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range (also an ``AssertionError``,
                for compatibility with earlier versions).
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise self._IndexOutOfRange(f"IndexError: WBTSet[{k}], len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                return keys[node]
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self) -> Iterator[T]:
        """Return an independent ascending iterator (safe to nest)."""
        # Reset the legacy cursor so `iter(s); next(s)` keeps working.
        self.__iter = 0
        return self._inorder()

    def __next__(self):
        """Legacy positional cursor; prefer iterating via ``iter(s)``."""
        if self.__iter == self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        return self._inorder(reverse=True)

    def __len__(self):
        return self.size[self.root]

    def __bool__(self):
        return self.root != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"WBTSet({self})"