avl_tree_set¶
ソースコード¶
from titan_pylib.data_structures.avl_tree.avl_tree_set import AVLTreeSet
展開済みコード¶
# from titan_pylib.data_structures.avl_tree.avl_tree_set import AVLTreeSet
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Any class defining ``__lt__`` matches this protocol; it serves as the
    bound of the element type variable used by the ordered-set classes.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from abc import ABC, abstractmethod
from typing import Iterable, Optional, Iterator, TypeVar, Generic

T = TypeVar("T", bound=SupportsLessThan)


class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface of a sorted set.

    The element type ``T`` is bound to ``SupportsLessThan``, i.e. elements
    must be comparable with ``<``. Every method is abstract; the default
    bodies raise ``NotImplementedError`` so a stray ``super()`` call from a
    subclass fails loudly instead of returning ``None``.
    """

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the set from the elements of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if it was not already present."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return ``True`` if it was removed."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not in the set."""
        raise NotImplementedError

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if none exists."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if none exists."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if none exists."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if none exists."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if the set is empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if the set is empty."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest element."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the elements as a list in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Advance the set's own iteration cursor (``iter(s)`` then ``next(s)``)."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is in the set."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the set is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the elements."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation of the set."""
        raise NotImplementedError
from array import array
from typing import Generic, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as an array-based AVL tree.

    Nodes are integer ids into parallel arrays: ``key[i]`` is the element,
    ``left[i]`` / ``right[i]`` are child ids, ``size[i]`` is the number of
    nodes in the subtree rooted at ``i`` and ``balance[i]`` is
    ``height(left subtree) - height(right subtree)``. Id ``0`` is a null
    sentinel with ``size[0] == 0``, so ``size[left[i]]`` is valid even when
    ``i`` has no left child. Subtree sizes give rank queries (``index``,
    ``index_right``) and positional access (``s[k]``, ``pop(k)``).

    Slots of removed nodes are never reused; ``clear`` releases all storage.
    """

    def __init__(self, a: Iterable[T] = []) -> None:
        """Create a set holding the distinct elements of ``a``.

        ``a`` is only read, never mutated, so the shared default is harmless.
        """
        self._reset_storage()
        # Cursor of the legacy ``next(s)`` protocol; see ``__next__``.
        self.__iter = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _reset_storage(self) -> None:
        """Drop all nodes, leaving only the null sentinel at id 0."""
        self.root = 0
        self.key = [0]
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        # Next unused node id.
        self.end = 1

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes.

        Reserved slots start as detached leaves (size 1, no children,
        balance 0). The arrays are extended in place, so references to them
        held by a caller stay valid.
        """
        self.key += [0] * n
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.size += array("I", [1]) * n
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Link the distinct elements of ``a`` into a perfectly balanced tree.

        Only called from ``__init__`` on an empty set: it replaces ``root``.
        Strictly increasing input is used as is; otherwise it is sorted and
        de-duplicated first.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Root the node ids [lo, hi) at their midpoint; return (root, height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build(lo, mid)
                size[mid] += size[left[mid]]
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
                size[mid] += size[right[mid]]
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        if not a:
            return
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, len(b)):
                if b[i] != a[-1]:
                    a.append(b[i])
        n = len(a)
        start = self.end
        self.end += n
        # Relies on the arrays having exactly ``start`` entries here.
        self.reserve(n)
        self.key[start : start + n] = a
        self.root = build(start, start + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Rebalance a left-heavy ``node`` (balance +2) by a single rotation.

        The left child ``u`` becomes the subtree root with ``node`` as its
        right child; returns ``u``. Callers use ``_rotate_LR`` instead when
        ``u`` leans right, so here ``balance[u]`` is +1 or 0.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Mirror of ``_rotate_L`` for a right-heavy ``node`` (balance -2)."""
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation that lifted ``node`` to the root.

        ``node``'s old balance tells which new child ended up one level
        shorter on its inner side; ``node`` itself becomes balanced.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation for a left-heavy ``node`` whose left child leans right.

        ``E = right[left[node]]`` becomes the subtree root with the old left
        child ``B`` and ``node`` as its children; returns ``E``.
        """
        left, right, size = self.left, self.right, self.size
        B = left[node]
        E = right[B]
        size[E] = size[node]
        size[node] -= size[B] - size[right[E]]
        size[B] -= size[right[E]] + 1
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Mirror of ``_rotate_LR``: lifts ``D = left[right[node]]``; returns ``D``."""
        left, right, size = self.left, self.right, self.size
        C = right[node]
        D = left[C]
        size[D] = size[node]
        size[node] -= size[C] - size[left[D]]
        size[C] -= size[left[D]] + 1
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a detached leaf holding ``key`` and return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Slot pre-allocated by ``reserve``: reset every field so the
            # node is a clean leaf whatever the slot held before.
            self.key[end] = key
            self.size[end] = 1
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if it was not already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        node = self.root
        path = []
        # Branch directions taken along ``path`` as bits: 1 = went left,
        # the least-significant bit is the most recent step.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        new_node = 0
        while path:
            node = path.pop()
            size[node] += 1
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                # Subtree height unchanged: ancestors keep their balance.
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
                break
        if new_node:
            # Re-attach the rotated subtree; an insertion needs at most one.
            if path:
                node = path.pop()
                size[node] += 1
                if di & 1:
                    left[node] = new_node
                else:
                    right[node] = new_node
            else:
                self.root = new_node
        for p in path:
            size[p] += 1
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise ``KeyError`` if it is not in the set.

        Returns ``True`` on success (kept for backward compatibility).
        """
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return ``True`` if it was removed."""
        left, right, keys = self.left, self.right, self.key
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._erase(node, path, di)
        return True

    def _erase(self, node: int, path: list[int], di: int) -> None:
        """Unlink ``node`` from the tree and restore sizes and AVL balance.

        ``path`` lists ``node``'s ancestors from the root down (consumed by
        this call). ``di`` encodes the branch taken at each of them: bit 1 =
        went left, least-significant bit = the step into ``node``.
        """
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        if left[node] and right[node]:
            # Two children: copy in the in-order predecessor (the rightmost
            # node of the left subtree) and delete that node instead.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        cnode = right[node] if left[node] == 0 else left[node]
        if not path:
            self.root = cnode
            return
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        while path:
            new_node = 0
            node = path.pop()
            balance[node] -= 1 if di & 1 else -1
            di >>= 1
            size[node] -= 1
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
            elif balance[node]:
                # Balance became +-1: subtree height did not shrink.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    # Rotation kept the subtree height: stop propagating.
                    break
        for p in path:
            size[p] -= 1

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            # Only ``<`` is used, as guaranteed by ``SupportsLessThan``.
            if keys[node] < key:
                res = keys[node]
                node = right[node]
            else:
                node = left[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self[len(self) - 1]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self[0]

    def pop(self, k: int = -1) -> T:
        """Remove and return the element of 0-based rank ``k``.

        Negative ``k`` counts from the largest element, as for ``list.pop``.

        Raises:
            IndexError: if ``k`` is out of range (always, on an empty set).
        """
        left, right, size = self.left, self.right, self.size
        n = size[self.root]
        i = k + n if k < 0 else k
        # Explicit check: an ``assert`` vanishes under ``python -O`` and the
        # descent below would then spin forever on the null sentinel.
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet.pop({k}): index out of range, len={n}")
        path = []
        di = 0
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                break
            path.append(node)
            di <<= 1
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        # Read before ``_erase``: it may overwrite this slot's key.
        res = self.key[node]
        self._erase(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element; ``IndexError`` if empty."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element; ``IndexError`` if empty."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove every element and release the node storage."""
        self._reset_storage()

    def tolist(self) -> list[T]:
        """Return the elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def ave_height(self) -> float:
        """Return the mean node depth (the root has depth 1); 0 if empty."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        total = 0
        stack = [(self.root, 1)]
        while stack:
            node, dep = stack.pop()
            total += dep
            if left[node]:
                stack.append((left[node], dep + 1))
            if right[node]:
                stack.append((right[node], dep + 1))
        return total / len(self)

    def _get_height(self) -> int:
        """Debug helper: check ``size``, ``balance`` and key order, return the height.

        Raises ``AssertionError`` on the first violated invariant; returns 0
        for an empty set.
        """
        if self.root == 0:
            return 0
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def dfs(node: int) -> tuple[int, int]:
            # Return (subtree size, subtree height) and validate ``node``.
            ls = lh = rs = rh = 0
            if left[node]:
                ls, lh = dfs(left[node])
            if right[node]:
                rs, rh = dfs(right[node])
            assert size[node] == ls + rs + 1
            assert balance[node] == lh - rh
            assert -1 <= balance[node] <= 1
            return ls + rs + 1, max(lh, rh) + 1

        _, h = dfs(self.root)
        a = self.tolist()
        assert all(a[i] < a[i + 1] for i in range(len(a) - 1))
        return h

    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is in the set."""
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the element of 0-based rank ``k`` (negative counts from the end).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        left, right, size, key = self.left, self.right, self.size, self.key
        n = size[self.root]
        i = k + n if k < 0 else k
        # Explicit check (not ``assert``) so ``python -O`` cannot turn an
        # out-of-range index into an infinite loop.
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet[{k}]: index out of range, len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                return key[node]
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self) -> Iterator[T]:
        """Return a new ascending iterator, independent of any other iteration.

        Elements are fetched by rank at each step. Mutating the set while
        iterating therefore does not raise, but elements may be skipped or
        repeated.
        """
        self.__iter = 0  # also restarts the legacy ``next(s)`` cursor
        return self._iter_ascending()

    def _iter_ascending(self) -> Iterator[T]:
        """Yield elements by increasing rank, re-reading the length each step."""
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __next__(self) -> T:
        """Legacy cursor protocol required by ``OrderedSetInterface``.

        After ``iter(s)``, ``next(s)`` yields the elements in ascending
        order. Prefer the iterator returned by ``iter(s)``, which does not
        share state.
        """
        if self.__iter >= len(self):
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        """Yield the elements in descending order."""
        for i in range(len(self)):
            yield self[-i - 1]

    def __len__(self) -> int:
        """Return the number of elements."""
        return self.size[self.root]

    def __bool__(self) -> bool:
        """Return whether the set is non-empty."""
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet({self})"
仕様¶
- class AVLTreeSet(a: Iterable[T] = [])[source]¶
  Bases: OrderedSetInterface, Generic[T]
  集合としての AVL 木です。配列を用いてノードを表現しています。size を持ちます。