from array import array
from typing import Generic, Iterable, Iterator, Optional, TypeVar

from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan
5
# Element type of AVLTreeSet; bounded to types that support ``<`` comparison.
T = TypeVar("T", bound=SupportsLessThan)
7
8
class AVLTreeSet(OrderedSetInterface, Generic[T]):
    """AVLTreeSet

    An ordered set implemented as an AVL tree.

    Nodes are stored in parallel arrays (``key``, ``size``, ``left``,
    ``right``, ``balance``) indexed by node id. Id 0 is a sentinel meaning
    "no node"; its ``size`` stays 0, so ``size[left[x]]`` works even when
    ``x`` has no left child. Every node stores the size of its subtree, which
    makes order-statistic queries (``__getitem__``, ``index``, ``pop(k)``)
    O(log n).

    ``balance[node]`` is ``height(left subtree) - height(right subtree)``.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set containing the elements of ``a`` (duplicates dropped)."""
        self._init_storage()
        # Cursor for the legacy ``__next__`` protocol (see ``__iter__``).
        self.__iter = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _init_storage(self) -> None:
        # Reset all node arrays to contain only the sentinel node 0.
        self.root = 0
        self.key = [0]
        self.size = array("I", bytes(4))
        self.left = array("I", bytes(4))
        self.right = array("I", bytes(4))
        self.balance = array("b", bytes(1))
        # Next node id to hand out.
        self.end = 1

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        Reserved slots are initialised as fresh leaves (size 1, no children,
        balance 0) and are consumed by later insertions. A non-positive ``n``
        is a no-op.
        """
        if n <= 0:
            return
        self.key += [0] * n
        zeros = array("I", bytes(4 * n))
        self.left += zeros
        self.right += zeros
        self.size += array("I", [1] * n)
        self.balance += array("b", bytes(n))

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced tree from ``a``.

        ``a`` is sorted and deduplicated first unless it is already strictly
        increasing. Intended to be called only on a freshly initialised, empty
        tree: it overwrites ``self.root``.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def link(lo: int, hi: int) -> tuple[int, int]:
            # Link node ids [lo, hi) into a balanced subtree rooted at the
            # middle id; return (subtree root, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = link(lo, mid)
                size[mid] += size[left[mid]]
            if mid + 1 != hi:
                right[mid], hr = link(mid + 1, hi)
                size[mid] += size[right[mid]]
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            # Sort, then drop adjacent duplicates.
            b = sorted(a)
            a = [b[0]]
            for x in b[1:]:
                if x != a[-1]:
                    a.append(x)
            n = len(a)
        end = self.end
        self.end += n
        # Freshly reserved slots start with size 1 and no children.
        self.reserve(n)
        self.key[end : end + n] = a
        self.root = link(end, end + n)[0]

    def _rotate_L(self, node: int) -> int:
        # Rotate ``node`` rightwards: its left child ``u`` becomes the new
        # subtree root and ``node`` becomes u's right child. Returns ``u``.
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible during deletion).
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        # Mirror of _rotate_L: the right child ``u`` becomes the new subtree
        # root and ``node`` becomes u's left child. Returns ``u``.
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible during deletion).
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        # Fix balance factors after a double rotation that made ``node`` the
        # new subtree root; ``balance[node]`` still holds its pre-rotation value.
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        # Left-right double rotation: B = left[node], E = right[B]; E becomes
        # the subtree root with B as its left and ``node`` as its right child.
        left, right, size = self.left, self.right, self.size
        B = left[node]
        E = right[B]
        size[E] = size[node]
        size[node] -= size[B] - size[right[E]]
        size[B] -= size[right[E]] + 1
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        # Right-left double rotation: C = right[node], D = left[C]; D becomes
        # the subtree root with ``node`` as its left and C as its right child.
        left, right, size = self.left, self.right, self.size
        C = right[node]
        D = left[C]
        size[D] = size[node]
        size[node] -= size[C] - size[left[D]]
        size[C] -= size[left[D]] + 1
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` and return its node id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reuse a pre-allocated slot; reset every field so no stale
            # links/sizes can leak into the tree.
            self.key[end] = key
            self.size[end] = 1
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was inserted, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        node = self.root
        path = []
        # Bit i of ``di`` (lowest = last step) is 1 when that step went left.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        new_node = 0
        while path:
            node = path.pop()
            size[node] += 1
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                # Subtree height unchanged; no further rebalancing needed.
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
                break
        if new_node:
            # Re-attach the rotated subtree to its parent (or make it the root).
            if path:
                node = path.pop()
                size[node] += 1
                if di & 1:
                    left[node] = new_node
                else:
                    right[node] = new_node
            else:
                self.root = new_node
        # Remaining ancestors only need their subtree sizes bumped.
        for p in path:
            size[p] += 1
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def _erase(self, node: int, path: list[int], di: int) -> None:
        """Unlink ``node`` from the tree and restore the AVL invariants.

        ``path`` holds the ancestors of ``node`` from the root down, and bit i
        of ``di`` (lowest bit = last step) is 1 when that step went left.
        ``path`` is consumed.
        """
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into ``node`` and remove the predecessor instead.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # ``node`` now has at most one child; splice it out.
        child = left[node] if left[node] else right[node]
        if not path:
            self.root = child
            return
        if di & 1:
            left[path[-1]] = child
        else:
            right[path[-1]] = child
        while path:
            new_node = 0
            node = path.pop()
            balance[node] -= 1 if di & 1 else -1
            di >>= 1
            size[node] -= 1
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
            elif balance[node]:
                # Balance went from 0 to +/-1: subtree height unchanged.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    # Non-zero balance after rotation: height did not shrink.
                    break
        # Remaining ancestors only need their subtree sizes decremented.
        for p in path:
            size[p] -= 1

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return True if it was removed, else False."""
        left, right, keys = self.left, self.right, self.key
        path = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._erase(node, path, di)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element <= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element < ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element >= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element > ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if not self:
            return None
        return self[len(self) - 1]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if not self:
            return None
        return self[0]

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the largest).

        Negative ``k`` counts from the end, like ``list.pop``.

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        # Explicit check (not ``assert``): with -O an invalid k would
        # otherwise make the descent below loop forever on sentinel node 0.
        if not 0 <= k < n:
            raise IndexError(f"AVLTreeSet.pop index out of range (len={n})")
        path = []
        di = 0
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                break
            path.append(node)
            di <<= 1
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        # Read the key before _erase, which may overwrite keys[node].
        res = keys[node]
        self._erase(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element (IndexError if empty)."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element (IndexError if empty)."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove all elements and release the node storage."""
        self._init_storage()

    def tolist(self) -> list[T]:
        """Return the elements as a list in ascending order."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def ave_height(self) -> float:
        """Return the average depth of all nodes (root has depth 1); 0 if empty."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        total = 0
        stack = [(self.root, 1)]
        while stack:
            node, dep = stack.pop()
            total += dep
            if left[node]:
                stack.append((left[node], dep + 1))
            if right[node]:
                stack.append((right[node], dep + 1))
        return total / len(self)

    def _get_height(self) -> int:
        """Debug helper: validate the tree and return its height.

        Asserts that every node's stored ``size`` and ``balance`` match its
        actual subtree, and that the keys are strictly increasing in order.
        Returns 0 for an empty tree.
        """
        if self.root == 0:
            return 0
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def dfs(node: int) -> tuple[int, int]:
            # Return (subtree size, subtree height) of ``node``.
            ls, lh = dfs(left[node]) if left[node] else (0, 0)
            rs, rh = dfs(right[node]) if right[node] else (0, 0)
            s = ls + rs + 1
            assert size[node] == s, f"size mismatch at node {node}"
            assert balance[node] == lh - rh, f"balance mismatch at node {node}"
            return s, max(lh, rh) + 1

        _, h = dfs(self.root)
        a = self.tolist()
        assert all(a[i] < a[i + 1] for i in range(len(a) - 1)), "keys out of order"
        return h

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        left, right, size, key = self.left, self.right, self.size, self.key
        n = size[self.root]
        i = k + n if k < 0 else k
        # Explicit check (not ``assert``): with -O an invalid index would
        # otherwise loop forever or return the sentinel key.
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet[{k}], len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                return key[node]
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in ascending order.

        Each call returns an independent generator, so nested loops over the
        same set do not interfere. The legacy ``__next__`` cursor is also
        reset so ``iter(s); next(s)`` keeps working.
        """
        self.__iter = 0
        return self._iter_ascending()

    def _iter_ascending(self) -> Iterator[T]:
        # Index-based, matching the original iteration semantics.
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __next__(self) -> T:
        """Legacy cursor-style iteration, kept for backward compatibility."""
        if self.__iter == len(self):
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        for i in range(len(self)):
            yield self[-i - 1]

    def __len__(self) -> int:
        return self.size[self.root]

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet({self})"