red_black_tree_set
ソースコード
from titan_pylib.data_structures.red_black_tree.red_black_tree_set import RedBlackTreeSet
展開済みコード
# from titan_pylib.data_structures.red_black_tree.red_black_tree_set import RedBlackTreeSet
# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural type: any value that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
        ...
from abc import ABC, abstractmethod
from typing import Iterable, Optional, Iterator, TypeVar, Generic

T = TypeVar("T", bound=SupportsLessThan)


class OrderedSetInterface(ABC, Generic[T]):
    """Abstract base class for ordered sets of mutually comparable keys.

    Every method is abstract and its default body raises
    ``NotImplementedError``; a subclass must override all of them before it
    can be instantiated.
    """

    # ---- construction -------------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the set from the elements of ``a``."""
        raise NotImplementedError()

    # ---- mutation -----------------------------------------------------

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; return whether the set changed."""
        raise NotImplementedError()

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was present."""
        raise NotImplementedError()

    @abstractmethod
    def remove(self, key: T) -> None:
        """Remove ``key``; fail if it is not present."""
        raise NotImplementedError()

    # ---- neighbour queries ----------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        raise NotImplementedError()

    # ---- extremes -------------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` when empty."""
        raise NotImplementedError()

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` when empty."""
        raise NotImplementedError()

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest key."""
        raise NotImplementedError()

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest key."""
        raise NotImplementedError()

    # ---- whole-set operations ---------------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every key."""
        raise NotImplementedError()

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return all keys in ascending order as a list."""
        raise NotImplementedError()

    # ---- Python protocols ---------------------------------------------------

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Start iterating over the keys in ascending order."""
        raise NotImplementedError()

    @abstractmethod
    def __next__(self) -> T:
        """Return the next key of the current iteration."""
        raise NotImplementedError()

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is stored."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored keys."""
        raise NotImplementedError()

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the set is non-empty."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the keys."""
        raise NotImplementedError()

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError()
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
# from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
from typing import TypeVar, Generic, Optional

T = TypeVar("T")
Node = TypeVar("Node")
# TODO: describe the required node attributes (key, left, right) with a Protocol.


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree set implementations.

    Every method is a ``staticmethod`` that receives the root node of a tree.
    Nodes must expose ``key``, ``left`` and ``right``.  The rank helpers
    (``index``, ``index_right``, ``kth_elm``) also read ``size``, the number of
    keys in a node's subtree.  Most helpers treat ``None`` as "no child", while
    ``contains``, ``get_min``, ``get_max`` and ``tolist`` test node truthiness.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return the keys of ``a`` in ascending order without duplicates.

        A strictly increasing ``a`` is returned as-is (same object); otherwise
        a freshly sorted and de-duplicated list is built.
        """
        if all(a[j - 1] < a[j] for j in range(1, len(a))):
            return a
        ordered = sorted(a)
        unique = [ordered[0]]
        for value in ordered:
            if not unique[-1] == value:
                unique.append(value)
        return unique

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` is stored in the tree rooted at ``node``."""
        cur = node
        while cur:
            here = cur.key
            if key == here:
                return True
            if key < here:
                cur = cur.left
            else:
                cur = cur.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` when ``node`` is falsy."""
        if not node:
            return None
        leftmost = node
        while leftmost.left:
            leftmost = leftmost.left
        return leftmost.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` when ``node`` is falsy."""
        if not node:
            return None
        rightmost = node
        while rightmost.right:
            rightmost = rightmost.right
        return rightmost.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key`` or ``None``.

        On an exact match the query object ``key`` itself is returned.
        """
        best = None
        cur = node
        while cur is not None:
            if key == cur.key:
                return key
            if key < cur.key:
                cur = cur.left
                continue
            best = cur.key
            cur = cur.right
        return best

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key strictly below ``key``, or ``None``."""
        best = None
        cur = node
        while cur is not None:
            if key <= cur.key:
                cur = cur.left
                continue
            best = cur.key
            cur = cur.right
        return best

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key`` or ``None``.

        On an exact match the query object ``key`` itself is returned.
        """
        best = None
        cur = node
        while cur is not None:
            if key == cur.key:
                return key
            if key < cur.key:
                best = cur.key
                cur = cur.left
            else:
                cur = cur.right
        return best

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key strictly above ``key``, or ``None``."""
        best = None
        cur = node
        while cur is not None:
            if not key < cur.key:
                cur = cur.right
                continue
            best = cur.key
            cur = cur.left
        return best

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return how many stored keys are strictly less than ``key``."""
        rank = 0
        cur = node
        while cur is not None:
            left = cur.left
            if key == cur.key:
                if left is not None:
                    rank += left.size
                return rank
            if key < cur.key:
                cur = left
            else:
                rank += 1 if left is None else left.size + 1
                cur = cur.right
        return rank

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return how many stored keys are less than or equal to ``key``."""
        rank = 0
        cur = node
        while cur is not None:
            if key == cur.key:
                return rank + (1 if cur.left is None else cur.left.size + 1)
            if key < cur.key:
                cur = cur.left
            else:
                rank += 1 if cur.left is None else cur.left.size + 1
                cur = cur.right
        return rank

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in ascending order (iterative in-order traversal)."""
        out = []
        pending = []
        cur = node
        while True:
            # Descend as far left as possible, remembering the path.
            while cur:
                pending.append(cur)
                cur = cur.left
            if not pending:
                return out
            cur = pending.pop()
            out.append(cur.key)
            cur = cur.right

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed).

        A negative ``k`` counts from the end, using ``_len`` as the tree size.
        ``k`` is assumed to be in range; no bounds check is performed.
        """
        idx = k + _len if k < 0 else k
        cur = node
        while True:
            left = cur.left
            left_size = 0 if left is None else left.size
            if idx < left_size:
                cur = left
            elif idx == left_size:
                return cur.key
            else:
                idx -= left_size + 1
                cur = cur.right
from typing import Iterable, Iterator, Optional, TypeVar, Generic, Sequence

T = TypeVar("T", bound=SupportsLessThan)


class RedBlackTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of unique keys backed by a red-black tree.

    Intended as a Python counterpart of C++ ``std::set``.  Besides the
    key-based API it exposes ``Node`` handles (``find``, ``le_iter``, ...)
    that can be stepped with ``+``/``-`` and deleted with ``discard_iter``.

    Node colours live in ``Node.col``: ``1`` means red, ``0`` black.  The
    smallest and largest nodes are cached in ``min_node``/``max_node``
    (``None`` while the set is empty), so ``get_min``/``get_max`` are O(1).
    """

    class Node:
        """Tree node of ``RedBlackTreeSet``; doubles as a bidirectional cursor.

        ``node + k`` / ``node - k`` return the node ``k`` positions later /
        earlier in key order, and ``+=`` / ``-=`` rebind the name the same
        way.  A walk that steps just past either end yields ``None``; stepping
        further trips an assertion.  One step may climb or descend the tree
        (worst case O(log N)); ``k`` steps perform ``k`` single steps.
        """

        def __init__(self, key: T) -> None:
            self.key = key
            self.left = RedBlackTreeSet.NIL
            self.right = RedBlackTreeSet.NIL
            self.par = RedBlackTreeSet.NIL
            self.col = 0  # 0 = black, 1 = red

        @property
        def count(self) -> int:
            """Number of copies of ``key`` held by this node; always ``1`` in a set."""
            return 1

        def _min(self) -> "RedBlackTreeSet.Node":
            """Return the leftmost node of the subtree rooted here."""
            now = self
            while now.left:
                now = now.left
            return now

        def _max(self) -> "RedBlackTreeSet.Node":
            """Return the rightmost node of the subtree rooted here."""
            now = self
            while now.right:
                now = now.right
            return now

        def _next(self):
            """Return the in-order successor, or ``None`` if this is the last node."""
            now = self
            pre = RedBlackTreeSet.NIL
            flag = now.right is pre  # True when there is no right subtree
            # Without a right subtree, climb while we arrive from a right
            # child; the first ancestor reached from its left is the successor.
            while now.right is pre:
                pre, now = now, now.par
            if not now:
                return None  # climbed past the root
            return now if flag and pre is now.left else now.right._min()

        def _prev(self):
            """Return the in-order predecessor, or ``None`` if this is the first node."""
            now, pre = self, RedBlackTreeSet.NIL
            flag = now.left is pre  # True when there is no left subtree
            # Mirror image of ``_next``.
            while now.left is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.right else now.left._max()

        def __iadd__(self, other: int):
            """Move ``other`` nodes forward (``None`` just past the last node)."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__iadd__() Error"
                res = res._next()
            return res

        def __isub__(self, other: int):
            """Move ``other`` nodes backward (``None`` just before the first node)."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__isub__() Error"
                res = res._prev()
            return res

        def __add__(self, other: int):
            """Return the node ``other`` positions later (``None`` just past the end)."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__add__() Error"
                res = res._next()
            return res

        def __sub__(self, other: int):
            """Return the node ``other`` positions earlier (``None`` just before the start)."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__sub__() Error"
                res = res._prev()
            return res

        def __str__(self):
            if self.left is RedBlackTreeSet.NIL and self.right is RedBlackTreeSet.NIL:
                return f"(key,col,par.key):{self.key, self.col, self.par.key}\n"
            return f"(key,col,par.key):{self.key, self.col, self.par.key},\n left:{self.left},\n right:{self.right}\n"

    class _NILNode:
        """Falsy sentinel shared by every tree for missing children and the root's parent.

        Its colour is black (``col == 0``), so the rebalancing loops can read
        the colour of a missing node without special cases.  Deletion assigns
        ``par`` on this shared object so the fix-up can climb from a missing
        child; ``discard_iter`` resets it afterwards.
        """

        key = None
        left = None
        right = None
        par = None
        col = 0

        def _min(self):
            return None

        def _max(self):
            return None

        def __bool__(self):
            return False

        def __str__(self):
            return "NIL"

    NIL = _NILNode()

    def __init__(self, a: Iterable[T] = []):
        """Build the set from ``a``.

        O(N) when ``a`` is already strictly increasing, O(N log N) otherwise
        (it is sorted and de-duplicated first).  The default ``[]`` is never
        mutated.
        """
        self.node = RedBlackTreeSet.NIL  # root; NIL while empty
        self.size = 0
        self.min_node = None  # cached smallest node, None while empty
        self.max_node = None  # cached largest node, None while empty
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Build a height-balanced tree from the (non-empty) keys of ``a``.

        Each subtree root is the middle element of its range.  Colours
        alternate by depth so that the root is black and the deepest level
        is red (unless the tree has a single level).
        """
        Node = RedBlackTreeSet.Node

        def rec(l: int, r: int, d: int) -> RedBlackTreeSet.Node:
            # Build the subtree for a[l:r]; ``d`` is the depth of its root.
            mid = (l + r) >> 1
            node = Node(a[mid])
            node.col = int((not flag and d & 1) or (flag and d > 1 and not d & 1))
            if l != mid:
                node.left = rec(l, mid, d + 1)
                node.left.par = node
            if mid + 1 != r:
                node.right = rec(mid + 1, r, d + 1)
                node.right.par = node
            return node

        a = BSTSetNodeBase[T, RedBlackTreeSet.Node].sort_unique(a)
        # The tree has len(a).bit_length() levels; its parity selects which
        # depths are red so that the deepest level is the red one.
        flag = len(a).bit_length() & 1
        self.node = rec(0, len(a), 0)
        self.min_node = self.node._min()
        self.max_node = self.node._max()
        self.size = len(a)

    def _rotate_left(self, node: Node) -> None:
        """Rotate left around ``node``; its right child ``u`` takes its place."""
        u = node.right
        p = node.par
        node.right = u.left
        if u.left:
            u.left.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.left:
            p.left = u
        else:
            p.right = u
        u.left = node
        node.par = u

    def _rotate_right(self, node: Node) -> None:
        """Rotate right around ``node``; its left child ``u`` takes its place."""
        u = node.left
        p = node.par
        node.left = u.right
        if u.right:
            u.right.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.right:
            p.right = u
        else:
            p.left = u
        u.right = node
        node.par = u

    def _transplant(self, u: Node, v: Node) -> None:
        """Replace the subtree rooted at ``u`` by the one rooted at ``v`` (``v`` may be NIL)."""
        if not u.par:
            self.node = v
        elif u is u.par.left:
            u.par.left = v
        else:
            u.par.right = v
        v.par = u.par

    def _get_min(self, node: Node) -> Node:
        """Return the leftmost node below ``node``."""
        while node.left:
            node = node.left
        return node

    def _get_max(self, node: Node) -> Node:
        """Return the rightmost node below ``node``."""
        while node.right:
            node = node.right
        return node

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if it was added, ``False`` if already present.

        The new node is coloured red and the loop recolours / rotates upward
        until no red node has a red parent; finally the root is made black.
        """
        if not self.node:
            node = RedBlackTreeSet.Node(key)
            self.node = node
            self.min_node = node
            self.max_node = node
            self.size = 1
            return True
        pnode = RedBlackTreeSet.NIL
        node = self.node
        while node:
            pnode = node
            if key == node.key:
                return False
            node = node.left if key < node.key else node.right
        self.size += 1
        z = RedBlackTreeSet.Node(key)
        if key < self.min_node.key:
            self.min_node = z
        if key > self.max_node.key:
            self.max_node = z
        z.par = pnode
        if not pnode:
            self.node = z
        elif key < pnode.key:
            pnode.left = z
        else:
            pnode.right = z
        z.col = 1
        while z.par.col:
            g = z.par.par
            if z.par is g.left:
                y = g.right  # uncle
                if y.col:
                    # Red uncle: recolour and continue two levels up.
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    # Black uncle: at most two rotations finish the repair.
                    if z is z.par.right:
                        z = z.par
                        self._rotate_left(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_right(g)
                    break
            else:
                y = g.left  # uncle
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.left:
                        z = z.par
                        self._rotate_right(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_left(g)
                    break
        self.node.col = 0
        return True

    def discard_iter(self, node: Node) -> None:
        """Delete ``node``, which must be a node currently in this tree.

        Other nodes stay valid as cursors: when ``node`` has two children its
        in-order successor *node object* is moved into its place (keys are
        never copied between nodes).  Worst case O(log n).

        Args:
            node (Node): the node to delete.
        """
        assert isinstance(node, RedBlackTreeSet.Node)
        self.size -= 1
        # Update the cached extremes while ``node`` still has its links.
        if node.key == self.min_node.key:
            self.min_node = node._next()
        if node.key == self.max_node.key:
            self.max_node = node._prev()
        y = node
        y_col = y.col  # colour of the node physically removed from its slot
        if not node.left:
            x = node.right
            self._transplant(node, node.right)
        elif not node.right:
            x = node.left
            self._transplant(node, node.left)
        else:
            y = self._get_min(node.right)
            y_col = y.col
            x = y.right
            if y.par is node:
                x.par = y  # x may be NIL; the fix-up needs its parent
            else:
                self._transplant(y, y.right)
                y.right = node.right
                y.right.par = y
            self._transplant(node, y)
            y.left = node.left
            y.left.par = y
            y.col = node.col
        if not y_col:
            # A black node left its slot: restore equal black heights.
            self._delete_fixup(x)
        # The shared sentinel may now point into this tree; drop that
        # reference so discarded or cleared trees can be garbage-collected.
        RedBlackTreeSet.NIL.par = None

    def _delete_fixup(self, x: Node) -> None:
        """Rebalance after removing a black node; ``x`` carries the extra black."""
        while x is not self.node and not x.col:
            if x is x.par.left:
                y = x.par
                w = y.right  # sibling
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_left(y)
                    w = y.right
                if not (w.left.col or w.right.col):
                    w.col = 1
                    x = y
                else:
                    if not w.right.col:
                        w.left.col = 0
                        w.col = 1
                        self._rotate_right(w)
                        w = y.right
                    w.col = y.col
                    y.col = 0
                    w.right.col = 0
                    self._rotate_left(x.par)
                    x = self.node
            else:
                y = x.par
                w = y.left  # sibling
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_right(y)
                    w = y.left
                if not (w.right.col or w.left.col):
                    w.col = 1
                    x = y
                else:
                    if not w.left.col:
                        w.right.col = 0
                        w.col = 1
                        self._rotate_left(w)
                        w = y.left
                    w.col = y.col
                    y.col = 0
                    w.left.col = 0
                    self._rotate_right(y)
                    x = self.node
        x.col = 0

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was present."""
        node = self.node
        while node:
            if key == node.key:
                break
            node = node.left if key < node.key else node.right
        else:
            return False
        self.discard_iter(node)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError(key)`` if it is not present."""
        if not self.discard(key):
            raise KeyError(key)

    def count(self, key: T) -> int:
        """Return ``1`` if ``key`` is present, otherwise ``0``."""
        return 1 if self.find(key) else 0

    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` if the set is empty. O(1)."""
        if self.max_node is None:
            return None
        return self.max_node.key

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` if the set is empty. O(1)."""
        if self.min_node is None:
            return None
        return self.min_node.key

    def get_max_iter(self) -> Optional[Node]:
        """Return the node holding the largest key, or ``None`` if empty. O(1)."""
        return self.max_node

    def get_min_iter(self) -> Optional[Node]:
        """Return the node holding the smallest key, or ``None`` if empty. O(1)."""
        return self.min_node

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None``."""
        res = self.le_iter(key)
        return None if res is None else res.key

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None``."""
        res = self.lt_iter(key)
        return None if res is None else res.key

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None``."""
        res = self.ge_iter(key)
        return None if res is None else res.key

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None``."""
        res = self.gt_iter(key)
        return None if res is None else res.key

    def le_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``<= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def lt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``< key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def ge_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``>= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def gt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``> key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def find(self, key: T) -> Optional[Node]:
        """Return the node holding ``key``, or ``None`` if absent. :math:`O(\\log{n})`."""
        node = self.node
        while node:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def tolist(self) -> list[T]:
        """Return all keys in ascending order."""
        return BSTSetNodeBase[T, RedBlackTreeSet.Node].tolist(self.node)

    def pop_max(self) -> T:
        """Remove and return the largest key (asserts the set is non-empty)."""
        assert self.node, f"IndexError: pop_max() from empty {self.__class__.__name__}."
        node = self.max_node
        self.discard_iter(node)
        return node.key

    def pop_min(self) -> T:
        """Remove and return the smallest key (asserts the set is non-empty)."""
        assert self.node, f"IndexError: pop_min() from empty {self.__class__.__name__}."
        node = self.min_node
        self.discard_iter(node)
        return node.key

    def clear(self) -> None:
        """Remove every key."""
        self.node = RedBlackTreeSet.NIL
        self.size = 0
        self.min_node = None
        self.max_node = None

    def __iter__(self) -> Iterator[T]:
        """Iterate over the keys in ascending order.

        Each call returns an independent generator, so nested loops over the
        same set work.  The successor is located before a key is yielded, so
        discarding the key that was just yielded is safe.  ``self.it`` is
        still reset for callers that drive ``__next__`` on the set itself.
        """
        self.it = self.min_node
        return self._iter_from(self.min_node)

    @staticmethod
    def _iter_from(node: Optional["RedBlackTreeSet.Node"]) -> Iterator[T]:
        """Yield keys in order starting at ``node``, advancing before each yield."""
        while node:
            nxt = node._next()
            yield node.key
            node = nxt

    def __next__(self):
        """Return the next key of the cursor reset by ``__iter__``."""
        if not self.it:
            raise StopIteration
        res = self.it.key
        self.it += 1
        return res

    def __bool__(self):
        return self.node is not RedBlackTreeSet.NIL

    def __contains__(self, key: T):
        node = self.node
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __len__(self):
        return self.size

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
仕様
- class RedBlackTreeSet(a: Iterable[T] = [])[source]
  Bases: OrderedSetInterface, Generic[T]
  赤黒木です。集合です。std::set も怖くない。
  - NIL = <titan_pylib.data_structures.red_black_tree.red_black_tree_set.RedBlackTreeSet._NILNode object>
  - class Node(key: T)[source]
    Bases: object
    RedBlackTreeSet で使用される節点クラスです。双方向に進められます。1 だけ進める場合、計算量は平均 O(1) 、最悪 O(log N) です。k だけ進める場合、だいたい k 倍になります(ホント?)。
    - property count: int
      保持している key の個数です。1 を返します。
  - discard_iter(node: Node) → None [source]
    node を削除します。償却 \(O(1)\) らしいです。
    - Parameters:
      node (Node) – 削除する node です。