wbt_set¶
ソースコード¶
from titan_pylib.data_structures.wbt.wbt_set import WBTSet
展開済みコード¶
1# from titan_pylib.data_structures.wbt.wbt_set import WBTSet
2# from titan_pylib.data_structures.wbt._wbt_set_node import _WBTSetNode
3# from titan_pylib.data_structures.wbt._wbt_node_base import _WBTNodeBase
4from typing import Generic, TypeVar, Optional, Final
5
6T = TypeVar("T")
7
8
9class _WBTNodeBase(Generic[T]):
10 """WBTノードのベースクラス
11 size, par, left, rightをもつ
12 """
13
14 __slots__ = "_size", "_par", "_left", "_right"
15 DELTA: Final[int] = 3
16 GAMMA: Final[int] = 2
17
18 def __init__(self) -> None:
19 self._size: int = 1
20 self._par: Optional[_WBTNodeBase[T]] = None
21 self._left: Optional[_WBTNodeBase[T]] = None
22 self._right: Optional[_WBTNodeBase[T]] = None
23
24 def _rebalance(self) -> "_WBTNodeBase[T]":
25 """根までを再構築する
26
27 Returns:
28 _WBTNodeBase[T]: 根ノード
29 """
30 node = self
31 while True:
32 node._update()
33 wl, wr = node._weight_left(), node._weight_right()
34 if wl * _WBTNodeBase.DELTA < wr:
35 if (
36 node._right._weight_left()
37 >= node._right._weight_right() * _WBTNodeBase.GAMMA
38 ):
39 node._right = node._right._rotate_right()
40 node = node._rotate_left()
41 elif wr * _WBTNodeBase.DELTA < wl:
42 if (
43 node._left._weight_right()
44 >= node._left._weight_left() * _WBTNodeBase.GAMMA
45 ):
46 node._left = node._left._rotate_left()
47 node = node._rotate_right()
48 if not node._par:
49 return node
50 node = node._par
51
52 def _copy_from(self, other: "_WBTNodeBase[T]") -> None:
53 self._size = other._size
54 if other._left:
55 other._left._par = self
56 if other._right:
57 other._right._par = self
58 if other._par:
59 if other._par._left is other:
60 other._par._left = self
61 else:
62 other._par._right = self
63 self._par = other._par
64 self._left = other._left
65 self._right = other._right
66
67 def _weight_left(self) -> int:
68 return self._left._size + 1 if self._left else 1
69
70 def _weight_right(self) -> int:
71 return self._right._size + 1 if self._right else 1
72
73 def _update(self) -> None:
74 self._size = (
75 1
76 + (self._left._size if self._left else 0)
77 + (self._right._size if self._right else 0)
78 )
79
80 def _rotate_right(self) -> "_WBTNodeBase[T]":
81 u = self._left
82 u._size = self._size
83 self._size -= u._left._size + 1 if u._left else 1
84 u._par = self._par
85 self._left = u._right
86 if u._right:
87 u._right._par = self
88 u._right = self
89 self._par = u
90 if u._par:
91 if u._par._left is self:
92 u._par._left = u
93 else:
94 u._par._right = u
95 return u
96
97 def _rotate_left(self) -> "_WBTNodeBase[T]":
98 u = self._right
99 u._size = self._size
100 self._size -= u._right._size + 1 if u._right else 1
101 u._par = self._par
102 self._right = u._left
103 if u._left:
104 u._left._par = self
105 u._left = self
106 self._par = u
107 if u._par:
108 if u._par._left is self:
109 u._par._left = u
110 else:
111 u._par._right = u
112 return u
113
114 def _balance_check(self) -> None:
115 if not self._weight_left() * _WBTNodeBase.DELTA >= self._weight_right():
116 print(self._weight_left(), self._weight_right(), flush=True)
117 print(self)
118 assert False, f"self._weight_left() * DELTA >= self._weight_right()"
119 if not self._weight_right() * _WBTNodeBase.DELTA >= self._weight_left():
120 print(self._weight_left(), self._weight_right(), flush=True)
121 print(self)
122 assert False, f"self._weight_right() * DELTA >= self._weight_left()"
123
124 def _min(self) -> "_WBTNodeBase[T]":
125 node = self
126 while node._left:
127 node = node._left
128 return node
129
130 def _max(self) -> "_WBTNodeBase[T]":
131 node = self
132 while node._right:
133 node = node._right
134 return node
135
136 def _next(self) -> Optional["_WBTNodeBase[T]"]:
137 if self._right:
138 return self._right._min()
139 now, pre = self, None
140 while now and now._right is pre:
141 now, pre = now._par, now
142 return now
143
144 def _prev(self) -> Optional["_WBTNodeBase[T]"]:
145 if self._left:
146 return self._left._max()
147 now, pre = self, None
148 while now and now._left is pre:
149 now, pre = now._par, now
150 return now
151
152 def __add__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
153 node = self
154 for _ in range(other):
155 node = node._next()
156 return node
157
158 def __sub__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
159 node = self
160 for _ in range(other):
161 node = node._prev()
162 return node
163
164 __iadd__ = __add__
165 __isub__ = __sub__
166
167 def __str__(self) -> str:
168 # if self._left is None and self._right is None:
169 # return f"key:{self._key, self._size}\n"
170 # return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"
171 return str(self._key)
172
173 __repr__ = __str__
174from typing import TypeVar, Optional
175
176T = TypeVar("T")
177
178
179class _WBTSetNode(_WBTNodeBase[T]):
180
181 __slots__ = "_key", "_size", "_par", "_left", "_right"
182
183 def __init__(self, key: T) -> None:
184 super().__init__()
185 self._key: T = key
186 self._par: Optional[_WBTSetNode[T]]
187 self._left: Optional[_WBTSetNode[T]]
188 self._right: Optional[_WBTSetNode[T]]
189
190 @property
191 def key(self) -> T:
192 return self._key
193from typing import Generic, TypeVar, Optional, Iterable, Iterator
194
195T = TypeVar("T")
196
197
198class WBTSet(Generic[T]):
199 """重み平衡木で実装された順序付き集合"""
200
201 __slots__ = "_root", "_min", "_max"
202
203 def __init__(self, a: Iterable[T] = []) -> None:
204 """イテラブル ``a`` から ``WBTSet`` を構築します。
205
206 Args:
207 a (Iterable[T], optional): 構築元のイテラブルです。
208
209 計算量:
210
211 ソート済みなら :math:`O(n)` 、そうでないなら :math:`O(n \\log{n})`
212 """
213 self._root: Optional[_WBTSetNode[T]] = None
214 self._min: Optional[_WBTSetNode[T]] = None
215 self._max: Optional[_WBTSetNode[T]] = None
216 self.__build(a)
217
218 def __build(self, a: Iterable[T]) -> None:
219 """再帰的に構築する関数"""
220
221 def build(
222 l: int, r: int, pnode: Optional[_WBTSetNode[T]] = None
223 ) -> _WBTSetNode[T]:
224 if l == r:
225 return None
226 mid = (l + r) // 2
227 node = _WBTSetNode(a[mid])
228 node._left = build(l, mid, node)
229 node._right = build(mid + 1, r, node)
230 node._par = pnode
231 node._update()
232 return node
233
234 a = list(a)
235 if not a:
236 return
237 if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
238 a.sort()
239 new_a = [a[0]]
240 for elm in a:
241 if new_a[-1] == elm:
242 continue
243 new_a.append(elm)
244 a = new_a
245 self._root = build(0, len(a))
246 self._max = self._root._max()
247 self._min = self._root._min()
248
249 def add(self, key: T) -> bool:
250 """既に ``key`` が存在していれば何もせず ``False`` を返し、
251 存在していれば ``key`` を 1 つ追加して ``True`` を返します。
252
253 Args:
254 key (T): 追加するキーです。
255
256 Returns:
257 bool: ``key`` を追加したら ``True`` 、そうでなければ ``False`` を返します。
258
259 計算量:
260 :math:`O(\\log{n})`
261 """
262 if not self._root:
263 self._root = _WBTSetNode(key)
264 self._max = self._root
265 self._min = self._root
266 return True
267 pnode = None
268 node = self._root
269 while node:
270 if key == node._key:
271 return False
272 pnode = node
273 node = node._left if key < node._key else node._right
274 if key < pnode._key:
275 pnode._left = _WBTSetNode(key)
276 if key < self._min._key:
277 self._min = pnode._left
278 pnode._left._par = pnode
279 else:
280 pnode._right = _WBTSetNode(key)
281 if key > self._max._key:
282 self._max = pnode._right
283 pnode._right._par = pnode
284 self._root = pnode._rebalance()
285 return True
286
287 def find_key(self, key: T) -> Optional[_WBTSetNode[T]]:
288 """``key`` が存在すれば ``key`` を指すノードを返します。
289 そうでなければ ``None`` を返します。
290
291 Args:
292 key (T):
293
294 Returns:
295 Optional[_WBTSetNode[T]]:
296
297 計算量:
298 :math:`O(\\log{n})`
299 """
300 node = self._root
301 while node:
302 if key == node._key:
303 return node
304 node = node._left if key < node._key else node._right
305 return None
306
307 def find_order(self, k: int) -> _WBTSetNode[T]:
308 """昇順 ``k`` 番目のノードを返します。
309
310 Args:
311 k (int):
312
313 Returns:
314 _WBTSetNode[T]:
315
316 計算量:
317 :math:`O(\\log{n})`
318
319 制約:
320 :math:`-n \\leq k \\le n`
321 """
322 if k < 0:
323 k += len(self)
324 node = self._root
325 while True:
326 t = node._left._size if node._left else 0
327 if t == k:
328 return node
329 if t < k:
330 k -= t + 1
331 node = node._right
332 else:
333 node = node._left
334
335 def count(self, key: T) -> int:
336 return 1 if self.find_key(key) is not None else 0
337
338 def remove_iter(self, node: _WBTSetNode[T]) -> None:
339 """``node`` を削除します。
340
341 Args:
342 node (_WBTSetNode[T]):
343
344 計算量:
345 :math:`O(\\log{n})`
346 """
347 if node is self._min:
348 self._min = self._min._next()
349 if node is self._max:
350 self._max = self._max._prev()
351 delnode = node
352 pnode, mnode = node._par, None
353 if node._left and node._right:
354 pnode, mnode = node, node._left
355 while mnode._right:
356 pnode, mnode = mnode, mnode._right
357 node = mnode
358 cnode = node._right if not node._left else node._left
359 if cnode:
360 cnode._par = pnode
361 if pnode:
362 if pnode._left is node:
363 pnode._left = cnode
364 else:
365 pnode._right = cnode
366 self._root = pnode._rebalance()
367 else:
368 self._root = cnode
369 if mnode:
370 if self._root is delnode:
371 self._root = mnode
372 mnode._copy_from(delnode)
373 del delnode
374
375 def remove(self, key: T) -> None:
376 """``key`` を削除します。
377
378 Args:
379 key (T): 削除する ``key`` です。
380
381 計算量:
382 :math:`O(\\log{n})`
383
384 Note:
385 ``key`` が存在しない場合、 ``AssertionError`` を出します。
386 """
387 node = self.find_key(key)
388 assert node, f"KeyError: {key} is not exist."
389 self.remove_iter(node)
390
391 def discard(self, key: T) -> bool:
392 """``key`` が存在すれば削除して ``True`` を返します。
393 存在しなければなにもせず ``False`` を返します。
394
395 Args:
396 key (T): 削除する ``key`` です。
397
398 Returns:
399 bool: ``key`` が存在したかどうか
400
401 計算量:
402 :math:`O(\\log{n})`
403 """
404 node = self.find_key(key)
405 if node is None:
406 return False
407 self.remove_iter(node)
408 return True
409
410 def pop(self, k: int = -1) -> T:
411 """``k`` 番目の値を削除して返します。
412 引数指定がない場合は最大の値を削除して返します。
413
414 Args:
415 k (int, optional): 削除するインデックスです。
416
417 Returns:
418 T: ``k`` 番目の値です。
419
420 計算量:
421 :math:`O(\\log{n})`
422 """
423 node = self.find_order(k)
424 key = node._key
425 self.remove_iter(node)
426 return key
427
428 def le_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
429 """``key`` 以下で最大のノードを返します。存在しないときは ``None`` を返します。
430
431 計算量:
432 :math:`O(\\log{n})`
433 """
434 res = None
435 node = self._root
436 while node:
437 if key == node._key:
438 res = node
439 break
440 if key < node._key:
441 node = node._left
442 else:
443 res = node
444 node = node._right
445 return res
446
447 def lt_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
448 """``key`` より小さい値で最大のノードを返します。存在しないときは ``None`` を返します。
449
450 計算量:
451 :math:`O(\\log{n})`
452 """
453 res = None
454 node = self._root
455 while node:
456 if key <= node._key:
457 node = node._left
458 else:
459 res = node
460 node = node._right
461 return res
462
463 def ge_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
464 """``key`` 以上で最小のノードを返します。存在しないときは ``None`` を返します。
465
466 計算量:
467 :math:`O(\\log{n})`
468 """
469 res = None
470 node = self._root
471 while node:
472 if key == node._key:
473 res = node
474 break
475 if key < node._key:
476 res = node
477 node = node._left
478 else:
479 node = node._right
480 return res
481
482 def gt_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
483 """``key`` より大きい値で最小のノードを返します。存在しないときは ``None`` を返します。
484
485 計算量:
486 :math:`O(\\log{n})`
487 """
488 res = None
489 node = self._root
490 while node:
491 if key < node._key:
492 res = node
493 node = node._left
494 else:
495 node = node._right
496 return res
497
498 def le(self, key: T) -> Optional[T]:
499 """``key`` 以下で最大の要素を返します。存在しないときは ``None`` を返します。
500
501 計算量:
502 :math:`O(\\log{n})`
503 """
504 res = None
505 node = self._root
506 while node:
507 if key == node._key:
508 res = key
509 break
510 if key < node._key:
511 node = node._left
512 else:
513 res = node._key
514 node = node._right
515 return res
516
517 def lt(self, key: T) -> Optional[T]:
518 """``key`` より小さい値で最大の要素を返します。存在しないときは ``None`` を返します。
519
520 計算量:
521 :math:`O(\\log{n})`
522 """
523 res = None
524 node = self._root
525 while node:
526 if key <= node._key:
527 node = node._left
528 else:
529 res = node._key
530 node = node._right
531 return res
532
533 def ge(self, key: T) -> Optional[T]:
534 """``key`` 以上で最小の要素を返します。存在しないときは ``None`` を返します。
535
536 計算量:
537 :math:`O(\\log{n})`
538 """
539 res = None
540 node = self._root
541 while node:
542 if key == node._key:
543 res = key
544 break
545 if key < node._key:
546 res = node._key
547 node = node._left
548 else:
549 node = node._right
550 return res
551
552 def gt(self, key: T) -> Optional[T]:
553 """``key`` より大きい値で最小の要素を返します。存在しないときは ``None`` を返します。
554
555 計算量:
556 :math:`O(\\log{n})`
557 """
558 res = None
559 node = self._root
560 while node:
561 if key < node._key:
562 res = node._key
563 node = node._left
564 else:
565 node = node._right
566 return res
567
568 def index(self, key: T) -> int:
569 """``key`` より小さい値を個数を返します。
570
571 Args:
572 key (T):
573
574 Returns:
575 int:
576
577 計算量:
578 :math:`O(\\log{n})`
579 """
580 k = 0
581 node = self._root
582 while node:
583 if key == node._key:
584 k += node._left._size if node._left else 0
585 break
586 if key < node._key:
587 node = node._left
588 else:
589 k += node._left._size + 1 if node._left else 1
590 node = node._right
591 return k
592
593 def index_right(self, key: T) -> int:
594 """``key`` 以下の値を個数を返します。
595
596 Args:
597 key (T):
598
599 Returns:
600 int:
601
602 計算量:
603 :math:`O(\\log{n})`
604 """
605 k = 0
606 node = self._root
607 while node:
608 if key == node._key:
609 k += node._left._size + 1 if node._left else 1
610 break
611 if key < node._key:
612 node = node._left
613 else:
614 k += node._left._size + 1 if node._left else 1
615 node = node._right
616 return k
617
618 def get_min(self) -> T:
619 """最小の要素を返します。
620
621 Returns:
622 T:
623
624 計算量:
625 :math:`O(1)`
626
627 制約:
628 :math:`0 < n`
629 """
630 assert self._min
631 return self._min._key
632
633 def get_max(self) -> T:
634 """最大の要素を返します。
635
636 Returns:
637 T:
638
639 計算量:
640 :math:`O(1)`
641
642 制約:
643 :math:`0 < n`
644 """
645 assert self._max
646 return self._max._key
647
648 def pop_min(self) -> T:
649 """最小の要素を削除して返します。
650
651 Returns:
652 T:
653
654 計算量:
655 :math:`O(\\log{n})`
656
657 制約:
658 :math:`0 < n`
659 """
660 assert self._min
661 key = self._min._key
662 self.remove_iter(self._min)
663 return key
664
665 def pop_max(self) -> T:
666 """最大の要素を削除して返します。
667
668 Returns:
669 T:
670
671 計算量:
672 :math:`O(\\log{n})`
673
674 制約:
675 :math:`0 < n`
676 """
677 assert self._max
678 key = self._max._key
679 self.remove_iter(self._max)
680 return key
681
682 def _check(self) -> int:
683 """作業用デバック関数
684 size,key,balanceをチェックして、正しければ高さを表示する
685 """
686 if self._root is None:
687 # print("ok. 0 (empty)")
688 return 0
689
690 # _size, height
691 def dfs(node: _WBTSetNode[T]) -> tuple[int, int]:
692 h = 0
693 s = 1
694 if node._left:
695 assert node._key > node._left._key
696 ls, lh = dfs(node._left)
697 s += ls
698 h = max(h, lh)
699 if node._right:
700 assert node._key < node._right._key
701 rs, rh = dfs(node._right)
702 s += rs
703 h = max(h, rh)
704 assert node._size == s
705 node._balance_check()
706 return s, h + 1
707
708 _, h = dfs(self._root)
709 # print(f"ok. {h}")
710 return h
711
712 def __contains__(self, key: T) -> bool:
713 """``key`` が存在すれば ``True`` 、そうでなければ ``False`` を返します。
714
715 Args:
716 key (T):
717
718 Returns:
719 bool:
720
721 計算量:
722 :math:`O(\\log{n})`
723 """
724 return self.find_key(key) is not None
725
726 def __getitem__(self, k: int) -> T:
727 """昇順 ``k`` 番目の値を返します。
728
729 Args:
730 k (int):
731
732 Returns:
733 T:
734
735 計算量:
736 k = 0 または k = n-1 の場合: :math:`O(1)`
737 そうでない場合: :math:`O(\\log{n})`
738 """
739 assert (
740 -len(self) <= k < len(self)
741 ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
742 if k < 0:
743 k += len(self)
744 if k == 0:
745 return self.get_min()
746 if k == len(self) - 1:
747 return self.get_max()
748 return self.find_order(k)._key
749
750 # def __delitem__(self, k: int) -> None:
751 # self.remove_iter(self.find_order(k))
752
753 def __len__(self) -> int:
754 """要素数を返します。
755
756 Returns:
757 int:
758
759 計算量:
760 :math:`O(1)`
761 """
762 return self._root._size if self._root else 0
763
764 def __iter__(self) -> Iterator[T]:
765 """昇順に値を返します。
766
767 Yields:
768 Iterator[T]:
769
770 計算量:
771 全体で :math:`O(n)`
772 """
773 stack: list[_WBTSetNode[T]] = []
774 node = self._root
775 while stack or node:
776 if node:
777 stack.append(node)
778 node = node._left
779 else:
780 node = stack.pop()
781 yield node._key
782 node = node._right
783
784 def __reversed__(self) -> Iterator[T]:
785 """降順に値を返します。
786
787 Yields:
788 Iterator[T]:
789
790 計算量:
791 全体で :math:`O(n)`
792 """
793 stack: list[_WBTSetNode[T]] = []
794 node = self._root
795 while stack or node:
796 if node:
797 stack.append(node)
798 node = node._right
799 else:
800 node = stack.pop()
801 yield node._key
802 node = node._left
803
804 def __str__(self) -> str:
805 return "{" + ", ".join(map(str, self)) + "}"
806
807 def __repr__(self) -> str:
808 return f"{self.__class__.__name__}(" + "{" + ", ".join(map(str, self)) + "})"
仕様¶
- class WBTSet(a: Iterable[T] = [])[source]¶
Bases:
Generic
[T
]重み平衡木で実装された順序付き集合
- __contains__(key: T) bool [source]¶
key
が存在すればTrue
、そうでなければFalse
を返します。- Parameters:
key (T)
- Return type:
bool
- 計算量:
\(O(\log{n})\)
- __getitem__(k: int) T [source]¶
昇順
k
番目の値を返します。- Parameters:
k (int)
- Return type:
T
- 計算量:
k = 0 または k = n-1 の場合: \(O(1)\) そうでない場合: \(O(\log{n})\)
- add(key: T) bool [source]¶
既に
key
が存在していれば何もせずFalse
を返し、 存在していればkey
を 1 つ追加してTrue
を返します。- Parameters:
key (T) – 追加するキーです。
- Returns:
key
を追加したらTrue
、そうでなければFalse
を返します。- Return type:
bool
- 計算量:
\(O(\log{n})\)
- discard(key: T) bool [source]¶
key
が存在すれば削除してTrue
を返します。 存在しなければなにもせずFalse
を返します。- Parameters:
key (T) – 削除する
key
です。- Returns:
key
が存在したかどうか- Return type:
bool
- 計算量:
\(O(\log{n})\)
- find_key(key: T) _WBTSetNode[T] | None [source]¶
key
が存在すればkey
を指すノードを返します。 そうでなければNone
を返します。- Parameters:
key (T)
- Return type:
Optional[_WBTSetNode[T]]
- 計算量:
\(O(\log{n})\)
- find_order(k: int) _WBTSetNode[T] [source]¶
昇順
k
番目のノードを返します。- Parameters:
k (int)
- Return type:
_WBTSetNode[T]
- 計算量:
\(O(\log{n})\)
- 制約:
\(-n \leq k \le n\)
- ge_iter(key: T) _WBTSetNode[T] | None [source]¶
key
以上で最小のノードを返します。存在しないときはNone
を返します。- 計算量:
\(O(\log{n})\)
- gt_iter(key: T) _WBTSetNode[T] | None [source]¶
key
より大きい値で最小のノードを返します。存在しないときはNone
を返します。- 計算量:
\(O(\log{n})\)
- index(key: T) int [source]¶
key
より小さい値を個数を返します。- Parameters:
key (T)
- Return type:
int
- 計算量:
\(O(\log{n})\)
- index_right(key: T) int [source]¶
key
以下の値を個数を返します。- Parameters:
key (T)
- Return type:
int
- 計算量:
\(O(\log{n})\)
- le_iter(key: T) _WBTSetNode[T] | None [source]¶
key
以下で最大のノードを返します。存在しないときはNone
を返します。- 計算量:
\(O(\log{n})\)
- lt_iter(key: T) _WBTSetNode[T] | None [source]¶
key
より小さい値で最大のノードを返します。存在しないときはNone
を返します。- 計算量:
\(O(\log{n})\)
- pop(k: int = -1) T [source]¶
k
番目の値を削除して返します。 引数指定がない場合は最大の値を削除して返します。- Parameters:
k (int, optional) – 削除するインデックスです。
- Returns:
k
番目の値です。- Return type:
T
- 計算量:
\(O(\log{n})\)