1# from titan_pylib.data_structures.avl_tree.wbt_set import WBTSet
from array import array
from typing import Final, Generic, Iterable, Iterator, Optional, TypeVar
4
T = TypeVar("T")

# Parameters of the weight-balanced tree.  The weight of a subtree is its
# size + 1.  A node is balanced while neither child's weight exceeds DELTA
# times the other's; when it is not, GAMMA decides between a single and a
# double rotation (see ``WBTSet._balance``).
DELTA: Final[int] = 3
GAMMA: Final[int] = 2


class WBTSet(Generic[T]):
    """Ordered set of distinct keys stored in a weight-balanced binary search tree.

    Nodes live in parallel arrays indexed by node id: ``key``, ``size`` (number
    of nodes in the subtree), ``left`` and ``right`` (child ids).  Id 0 is a
    sentinel meaning "no node"; its size is always 0.  Node ids are handed out
    in increasing order from ``end`` and are not reused until ``clear()``.
    Keys must support ``<`` and ``==`` against each other.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set, optionally populated from the iterable *a*.

        Duplicate elements of *a* are stored once.
        """
        self.root = 0
        self.key = [0]
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.end = 1  # id of the next unused node slot
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for *n* additional nodes.

        Reserved slots are initialised as fresh leaves (size 1, no children),
        which is what ``_make_node`` expects when it hands one out.
        """
        if n <= 0:
            return
        self.key += [0] * n
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.size += array("I", [1]) * n

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced tree from *a* and make it the root.

        *a* is sorted and de-duplicated first unless it is already strictly
        increasing.  Called from ``__init__`` on an empty set.
        """
        left, right, size = self.left, self.right, self.size

        def build(lo: int, hi: int) -> int:
            # The middle slot of [lo, hi) becomes the subtree root; both halves
            # are built recursively, so recursion depth is logarithmic.
            mid = (lo + hi) >> 1
            if lo != mid:
                left[mid] = build(lo, mid)
                size[mid] += size[left[mid]]
            if mid + 1 != hi:
                right[mid] = build(mid + 1, hi)
                size[mid] += size[right[mid]]
            return mid

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, n):
                if b[i] != a[-1]:
                    a.append(b[i])
            n = len(a)
        end = self.end
        self.end += n
        # Only allocate the slots that are not already available; the slots
        # [end, end + n) must start as fresh leaves (size 1) for build().
        self.reserve(end + n - len(self.key))
        self.key[end : end + n] = a
        self.root = build(end, end + n)

    def _rotate_left(self, node: int) -> int:
        """Rotate the subtree at *node* left and return the new subtree root.

        The right child becomes the root; both moved nodes get updated sizes.
        """
        left, right, size = self.left, self.right, self.size
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        return u

    def _rotate_right(self, node: int) -> int:
        """Rotate the subtree at *node* right and return the new subtree root.

        The left child becomes the root; both moved nodes get updated sizes.
        """
        left, right, size = self.left, self.right, self.size
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        return u

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding *key* and return its id.

        A slot pre-allocated by ``reserve`` is used when available; otherwise
        all node arrays grow by one.
        """
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
        else:
            self.key[end] = key
        self.end += 1
        return end

    def _weight_left(self, node: int) -> int:
        """Weight (size + 1) of the left subtree of *node*."""
        return self.size[self.left[node]] + 1

    def _weight_right(self, node: int) -> int:
        """Weight (size + 1) of the right subtree of *node*."""
        return self.size[self.right[node]] + 1

    def _balance(self, node: int) -> int:
        """Restore the weight-balance condition at *node*; return the subtree root.

        Expects ``size[node]`` to be up to date.  Performs at most one single or
        double rotation.
        """
        left, right = self.left, self.right
        wl = self._weight_left(node)
        wr = self._weight_right(node)
        if wl * DELTA < wr:
            # Right side too heavy; rotate the right child first (double
            # rotation) when its inner (left) side dominates.
            child = right[node]
            if self._weight_left(child) >= self._weight_right(child) * GAMMA:
                right[node] = self._rotate_right(child)
            node = self._rotate_left(node)
        elif wr * DELTA < wl:
            child = left[node]
            if self._weight_right(child) >= self._weight_left(child) * GAMMA:
                left[node] = self._rotate_left(child)
            node = self._rotate_right(node)
        return node

    def _fix_path(self, path: list[int], di: int, delta: int) -> None:
        """Walk *path* bottom-up, add *delta* to each size and rebalance.

        *path* holds the ancestors of the modified position, root first.  Bit
        ``i`` of *di* (from the least-significant bit) is 1 when the step
        below ``path[-1 - i]`` went to its left child.  Rotated subtrees are
        re-linked into their parent, or become the new root.
        """
        left, right, size = self.left, self.right, self.size
        while path:
            node = path.pop()
            size[node] += delta
            di >>= 1  # LSB now describes the step from path[-1] into node
            node = self._balance(node)
            if path:
                if di & 1:
                    left[path[-1]] = node
                else:
                    right[path[-1]] = node
            else:
                self.root = node

    def _remove_node(self, node: int, path: list[int], di: int) -> None:
        """Unlink *node* from the tree and rebalance its ancestors.

        *path* and *di* describe how *node* was reached, in the same encoding
        as ``_fix_path``.  A node with two children takes the key of its
        in-order predecessor, and that predecessor node is unlinked instead.
        """
        left, right, keys = self.left, self.right, self.key
        if left[node] and right[node]:
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # node has at most one child now; splice that child into its place.
        child = left[node] or right[node]
        if not path:
            self.root = child
            return
        if di & 1:
            left[path[-1]] = child
        else:
            right[path[-1]] = child
        self._fix_path(path, di, -1)

    def ave_height(self) -> float:
        """Return the average node depth (the root has depth 1); 0 if empty."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        total = 0

        def dfs(node: int, dep: int) -> None:
            nonlocal total
            total += dep
            if left[node]:
                dfs(left[node], dep + 1)
            if right[node]:
                dfs(right[node], dep + 1)

        dfs(self.root, 1)
        return total / len(self)

    def debug(self, root: int) -> None:
        """Print the subtree rooted at node id *root*, pre-order and indented."""
        left, right, keys = self.left, self.right, self.key

        def dfs(node: int, indent: int) -> None:
            if not node:
                return
            s = " " * indent
            print(f"{s}key={keys[node]}, idx={node}")
            if left[node]:
                print(f"{s}left: {keys[left[node]]}, idx={left[node]}")
                dfs(left[node], indent + 2)
            if right[node]:
                print(f"{s}righ: {keys[right[node]]}, idx={right[node]}")
                dfs(right[node], indent + 2)

        dfs(root, 0)

    def add(self, key: T) -> bool:
        """Insert *key*; return True if it was added, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, keys = self.left, self.right, self.key
        node = self.root
        path: list[int] = []
        di = 0
        while node:
            if key == keys[node]:
                return False
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        new = self._make_node(key)
        if di & 1:
            left[path[-1]] = new
        else:
            right[path[-1]] = new
        self._fix_path(path, di, 1)
        return True

    def remove(self, key: T) -> bool:
        """Remove *key* and return True; raise KeyError if it is absent."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove *key* if present; return True if it was removed."""
        left, right, keys = self.left, self.right, self.key
        path: list[int] = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._remove_node(node, path, di)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element <= *key*, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element < *key*, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element >= *key*, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element > *key*, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than *key*."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to *key*."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if not self:
            return None
        return self[len(self) - 1]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if not self:
            return None
        return self[0]

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th smallest element (negative k counts from the end).

        Raises:
            IndexError: if *k* is out of range (including on an empty set).
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError(f"WBTSet.pop index out of range: {k}, len={n}")
        path: list[int] = []
        di = 0
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                break
            path.append(node)
            di <<= 1
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        res = keys[node]  # read before _remove_node may overwrite this slot
        self._remove_node(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element; IndexError if empty."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element; IndexError if empty."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove every element and release the node storage."""
        self.root = 0
        self.end = 1
        self.key = [0]
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack: list[int] = []
        a: list[T] = []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def check(self) -> int:
        """Debugging helper for development work.

        Verifies search-tree order, cached subtree sizes and the weight-balance
        condition at every node using ``assert``.  Returns the tree height
        (0 for an empty set).
        """
        if self.root == 0:
            return 0

        def _balance_check(node: int) -> None:
            if node == 0:
                return
            if not self._weight_left(node) * DELTA >= self._weight_right(node):
                print(self._weight_left(node), self._weight_right(node), flush=True)
                assert False, "self._weight_left() * DELTA >= self._weight_right()"
            if not self._weight_right(node) * DELTA >= self._weight_left(node):
                print(self._weight_left(node), self._weight_right(node), flush=True)
                assert False, "self._weight_right() * DELTA >= self._weight_left()"

        keys = self.key

        # Returns (subtree size, subtree height).
        def dfs(node: int) -> tuple[int, int]:
            _balance_check(node)
            h = 0
            s = 1
            if self.left[node]:
                assert keys[self.left[node]] < keys[node]
                ls, lh = dfs(self.left[node])
                s += ls
                h = max(h, lh)
            if self.right[node]:
                assert keys[node] < keys[self.right[node]]
                rs, rh = dfs(self.right[node])
                s += rs
                h = max(h, rh)
            assert self.size[node] == s
            return s, h + 1

        _, h = dfs(self.root)
        return h

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the k-th smallest element (negative k counts from the end).

        Raises:
            IndexError: if *k* is out of range, as the sequence protocol expects.
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError(f"WBTSet index out of range: {k}, len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                return keys[node]
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self) -> Iterator[T]:
        """Return a new iterator over a snapshot of the elements, ascending.

        Every call yields an independent iterator, so nested loops over the
        same set work.
        """
        # Reset the legacy cursor used when callers invoke next() on the set itself.
        self.__iter = 0
        return iter(self.tolist())

    def __next__(self) -> T:
        # Legacy cursor protocol: after iter(s), next(s) walks elements by index.
        if self.__iter == self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        """Return an iterator over a snapshot of the elements, descending."""
        return reversed(self.tolist())

    def __len__(self) -> int:
        return self.size[self.root]

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"WBTSet({self})"