red_black_tree_set

ソースコード

from titan_pylib.data_structures.red_black_tree.red_black_tree_set import RedBlackTreeSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.red_black_tree.red_black_tree_set import RedBlackTreeSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102# from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
from typing import TypeVar, Generic, Optional

T = TypeVar("T")
Node = TypeVar("Node")
# Node types are expected to provide ``key``, ``left`` and ``right`` (plus
# ``size`` for ``index`` / ``index_right`` / ``kth_elm``).  A missing child
# may be ``None`` or any falsy sentinel object.
# TODO: express this contract as a ``typing.Protocol``.


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree based sets.

    Each helper receives the root ``node`` of a BST with distinct keys and
    walks it without modifying it.  Missing children are detected by
    truthiness, so trees that use ``None`` and trees that use a falsy
    sentinel (such as a red-black tree's NIL node) are both supported.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return the elements of ``a`` sorted ascending with duplicates removed.

        If ``a`` is already strictly increasing it is returned unchanged (no
        copy is made); otherwise a new list is returned.
        """
        if all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            return a
        a = sorted(a)
        new_a = [a[0]]
        for elm in a:
            if new_a[-1] == elm:
                continue
            new_a.append(elm)
        return new_a

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` is stored in the tree rooted at ``node``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``."""
        res = None
        while node:
            if key == node.key:
                # Return the stored element rather than the query object.
                return node.key
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        res = None
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``."""
        res = None
        while node:
            if key == node.key:
                # Return the stored element rather than the query object.
                return node.key
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        res = None
        while node:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of stored keys ``< key`` (requires ``size``)."""
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # Everything in the left subtree plus this node is < key.
                k += node.left.size + 1 if node.left else 1
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of stored keys ``<= key`` (requires ``size``)."""
        k = 0
        while node:
            if key == node.key:
                k += node.left.size + 1 if node.left else 1
                break
            if key < node.key:
                node = node.left
            else:
                k += node.left.size + 1 if node.left else 1
                node = node.right
        return k

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in ascending order (iterative in-order walk)."""
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.append(node.key)
                node = node.right
        return res

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed; requires ``size``).

        Args:
          node: root of the tree.
          k: position; negative values count from the end as for lists.
          _len: number of keys stored in the tree.

        Raises:
          IndexError: if ``k`` is outside ``[-_len, _len)``.
        """
        if k < 0:
            k += _len
        if not 0 <= k < _len:
            raise IndexError("BSTSetNodeBase.kth_elm(): index out of range")
        while True:
            t = node.left.size if node.left else 0
            if t == k:
                return node.key
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
from typing import Iterable, Iterator, Optional, TypeVar, Generic, Sequence

T = TypeVar("T", bound=SupportsLessThan)


class RedBlackTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as a red-black tree (duplicates are not stored).

    Usable much like C++ ``std::set``: besides key-based queries, the
    ``*_iter`` methods return ``Node`` handles that can be stepped with
    ``+`` / ``-``.  The minimum and maximum nodes are cached, so
    ``get_min`` / ``get_max`` are O(1).
    """

    class Node:
        """Tree node of ``RedBlackTreeSet``; also usable as a bidirectional iterator.

        Moving by one step costs O(1) on average and O(log N) in the worst
        case.  Moving by ``k`` steps performs ``k`` single steps.
        """

        def __init__(self, key: T) -> None:
            self.key = key
            self.left = RedBlackTreeSet.NIL
            self.right = RedBlackTreeSet.NIL
            self.par = RedBlackTreeSet.NIL
            self.col = 0  # colour: 0 = black, 1 = red

        @property
        def count(self) -> int:
            """Number of copies of ``key`` held by this node; always ``1``."""
            return 1

        def _min(self) -> "RedBlackTreeSet.Node":
            """Return the leftmost node of the subtree rooted at ``self``."""
            now = self
            while now.left:
                now = now.left
            return now

        def _max(self) -> "RedBlackTreeSet.Node":
            """Return the rightmost node of the subtree rooted at ``self``."""
            now = self
            while now.right:
                now = now.right
            return now

        def _next(self):
            """Return the in-order successor, or ``None`` if ``self`` is the last node."""
            now = self
            pre = RedBlackTreeSet.NIL
            # Without a right subtree the successor is the first ancestor
            # that is reached from its left child.
            flag = now.right is pre
            while now.right is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.left else now.right._min()

        def _prev(self):
            """Return the in-order predecessor, or ``None`` if ``self`` is the first node."""
            now, pre = self, RedBlackTreeSet.NIL
            # Mirror image of ``_next``.
            flag = now.left is pre
            while now.left is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.right else now.left._max()

        def __iadd__(self, other: int):
            """Move ``other`` steps forward; the result is ``None`` past the last node."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__iadd__() Error"
                res = res._next()
            return res

        def __isub__(self, other: int):
            """Move ``other`` steps backward; the result is ``None`` before the first node."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__isub__() Error"
                res = res._prev()
            return res

        def __add__(self, other: int):
            """Return the node ``other`` steps forward, or ``None`` if it does not exist."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__add__() Error"
                res = res._next()
            return res

        def __sub__(self, other: int):
            """Return the node ``other`` steps backward, or ``None`` if it does not exist."""
            res = self
            for _ in range(other):
                assert res is not None, "RedBlackTreeSet Node.__sub__() Error"
                res = res._prev()
            return res

        def __str__(self):
            if self.left is RedBlackTreeSet.NIL and self.right is RedBlackTreeSet.NIL:
                return f"(key,col,par.key):{self.key, self.col, self.par.key}\n"
            return f"(key,col,par.key):{self.key, self.col, self.par.key},\n left:{self.left},\n right:{self.right}\n"

    class _NILNode:
        """Shared black sentinel standing in for missing children and the root's parent.

        It is falsy so that ``while node:`` loops stop at it.  As in CLRS,
        its ``par`` attribute is temporarily assigned during deletion.
        """

        key = None
        left = None
        right = None
        par = None
        col = 0

        def _min(self):
            return None

        def _max(self):
            return None

        def __bool__(self):
            return False

        def __str__(self):
            return "NIL"

    NIL = _NILNode()

    def __init__(self, a: Iterable[T] = []):
        """Build the set from ``a``.

        Runs in O(N) when ``a`` is already strictly increasing and in
        O(N log N) otherwise.
        """
        # ``a`` is only read, never mutated, so the shared ``[]`` default is safe.
        self.node = RedBlackTreeSet.NIL  # root
        self.size = 0
        self.min_node = None  # cached node holding the minimum key
        self.max_node = None  # cached node holding the maximum key
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Replace the contents with a balanced tree built from ``a``."""
        Node = RedBlackTreeSet.Node

        def rec(l: int, r: int, d: int) -> RedBlackTreeSet.Node:
            # Build a balanced subtree over a[l:r]; ``d`` is its depth.
            mid = (l + r) >> 1
            node = Node(a[mid])
            # Colour by depth so that every root-to-NIL path has the same
            # number of black nodes and no red node has a red child.
            node.col = int((not flag and d & 1) or (flag and d > 1 and not d & 1))
            if l != mid:
                node.left = rec(l, mid, d + 1)
                node.left.par = node
            if mid + 1 != r:
                node.right = rec(mid + 1, r, d + 1)
                node.right.par = node
            return node

        a = BSTSetNodeBase[T, RedBlackTreeSet.Node].sort_unique(a)
        flag = len(a).bit_length() & 1  # parity of the tree height
        self.node = rec(0, len(a), 0)
        self.min_node = self.node._min()
        self.max_node = self.node._max()
        self.size = len(a)

    def _rotate_left(self, node: Node) -> None:
        """Rotate ``node`` down-left; its right child takes its place."""
        u = node.right
        p = node.par
        node.right = u.left
        if u.left:
            u.left.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.left:
            p.left = u
        else:
            p.right = u
        u.left = node
        node.par = u

    def _rotate_right(self, node: Node) -> None:
        """Rotate ``node`` down-right; its left child takes its place."""
        u = node.left
        p = node.par
        node.left = u.right
        if u.right:
            u.right.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.right:
            p.right = u
        else:
            p.left = u
        u.right = node
        node.par = u

    def _transplant(self, u: Node, v: Node) -> None:
        """Put subtree ``v`` where ``u`` was (``v`` may be ``NIL``)."""
        if not u.par:
            self.node = v
        elif u is u.par.left:
            u.par.left = v
        else:
            u.par.right = v
        v.par = u.par

    def _get_min(self, node: Node) -> Node:
        """Return the leftmost node below ``node``."""
        while node.left:
            node = node.left
        return node

    def _get_max(self, node: Node) -> Node:
        """Return the rightmost node below ``node``."""
        while node.right:
            node = node.right
        return node

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``False`` if it was already present.  O(log N)."""
        if not self.node:
            node = RedBlackTreeSet.Node(key)
            self.node = node
            self.min_node = node
            self.max_node = node
            self.size = 1
            return True
        pnode = RedBlackTreeSet.NIL
        node = self.node
        while node:
            pnode = node
            if key == node.key:
                return False
            node = node.left if key < node.key else node.right
        self.size += 1
        z = RedBlackTreeSet.Node(key)
        if key < self.min_node.key:
            self.min_node = z
        if key > self.max_node.key:
            self.max_node = z
        z.par = pnode
        if not pnode:
            self.node = z
        elif key < pnode.key:
            pnode.left = z
        else:
            pnode.right = z
        z.col = 1  # new nodes are red
        # Repair red-red violations, walking up from ``z``.
        while z.par.col:
            g = z.par.par
            if z.par is g.left:
                y = g.right  # uncle
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.right:
                        z = z.par
                        self._rotate_left(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_right(g)
                    break
            else:
                y = g.left  # uncle
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.left:
                        z = z.par
                        self._rotate_right(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_left(g)
                    break
        self.node.col = 0
        return True

    def discard_iter(self, node: Node) -> None:
        """Remove ``node`` from the tree.

        Finding the replacement node costs O(log N) in the worst case; the
        rebalancing afterwards (rotations and recolouring) is amortised O(1).

        Args:
          node (Node): the node to remove; it must belong to this set.
        """
        assert isinstance(node, RedBlackTreeSet.Node)
        self.size -= 1
        # Update the cached extremes before the links are rewired.
        if node is self.min_node:
            self.min_node = node._next()
        if node is self.max_node:
            self.max_node = node._prev()
        y = node
        y_col = y.col
        if not node.left:
            x = node.right
            self._transplant(node, node.right)
        elif not node.right:
            x = node.left
            self._transplant(node, node.left)
        else:
            # Two children: splice in the successor ``y`` in node's place.
            y = self._get_min(node.right)
            y_col = y.col
            x = y.right
            if y.par is node:
                x.par = y  # x may be NIL; its parent is needed by the fix-up
            else:
                self._transplant(y, y.right)
                y.right = node.right
                y.right.par = y
            self._transplant(node, y)
            y.left = node.left
            y.left.par = y
            y.col = node.col
        if y_col:
            return  # removing a red node never breaks the invariants
        # ``x`` carries an extra black; push it up or resolve it by rotations.
        while x is not self.node and not x.col:
            if x is x.par.left:
                y = x.par
                w = y.right  # sibling
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_left(y)
                    w = y.right
                if not (w.left.col or w.right.col):
                    w.col = 1
                    x = y
                else:
                    if not w.right.col:
                        w.left.col = 0
                        w.col = 1
                        self._rotate_right(w)
                        w = y.right
                    w.col = y.col
                    y.col = 0
                    w.right.col = 0
                    self._rotate_left(x.par)
                    x = self.node
            else:
                y = x.par
                w = y.left  # sibling
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_right(y)
                    w = y.left
                if not (w.right.col or w.left.col):
                    w.col = 1
                    x = y
                else:
                    if not w.left.col:
                        w.right.col = 0
                        w.col = 1
                        self._rotate_left(w)
                        w = y.left
                    w.col = y.col
                    y.col = 0
                    w.left.col = 0
                    self._rotate_right(y)
                    x = self.node
        x.col = 0

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was present."""
        node = self.node
        while node:
            if key == node.key:
                break
            node = node.left if key < node.key else node.right
        else:
            return False
        self.discard_iter(node)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
          KeyError: if ``key`` is not present.
        """
        if self.discard(key):
            return
        raise KeyError

    def count(self, key: T) -> int:
        """Return ``1`` if ``key`` is present, else ``0``."""
        return 1 if self.find(key) else 0

    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` when empty.  O(1)."""
        if self.max_node is None:
            return
        return self.max_node.key

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` when empty.  O(1)."""
        if self.min_node is None:
            return
        return self.min_node.key

    def get_max_iter(self) -> Optional[Node]:
        """Return the ``Node`` holding the maximum, or ``None`` when empty.  O(1)."""
        return self.max_node

    def get_min_iter(self) -> Optional[Node]:
        """Return the ``Node`` holding the minimum, or ``None`` when empty.  O(1)."""
        return self.min_node

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None``."""
        res = self.le_iter(key)
        return None if res is None else res.key

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None``."""
        res = self.lt_iter(key)
        return None if res is None else res.key

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None``."""
        res = self.ge_iter(key)
        return None if res is None else res.key

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None``."""
        res = self.gt_iter(key)
        return None if res is None else res.key

    def le_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``<= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def lt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``< key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def ge_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``>= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def gt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``> key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def find(self, key: T) -> Optional[Node]:
        """Return the ``Node`` holding ``key``, or ``None`` if absent.
        :math:`O(\\log{n})`.
        """
        node = self.node
        while node:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def tolist(self) -> list[T]:
        """Return all keys in ascending order."""
        return BSTSetNodeBase[T, RedBlackTreeSet.Node].tolist(self.node)

    def pop_max(self) -> T:
        """Remove and return the largest key.

        Raises:
          IndexError: if the set is empty.
        """
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if not self.node:
            raise IndexError(f"pop_max() from empty {self.__class__.__name__}.")
        node = self.max_node
        self.discard_iter(node)
        return node.key

    def pop_min(self) -> T:
        """Remove and return the smallest key.

        Raises:
          IndexError: if the set is empty.
        """
        if not self.node:
            raise IndexError(f"pop_min() from empty {self.__class__.__name__}.")
        node = self.min_node
        self.discard_iter(node)
        return node.key

    def clear(self) -> None:
        """Remove every key."""
        self.node = RedBlackTreeSet.NIL
        self.size = 0
        self.min_node = None
        self.max_node = None

    def __iter__(self) -> Iterator[T]:
        """Iterate over the keys in ascending order.

        Every call returns an independent iterator, so nested loops over the
        same set work correctly.  ``self.it`` is still reset so that code
        driving the set through ``__next__`` directly keeps working.
        """
        self.it = self.min_node
        return self._iter_keys()

    def _iter_keys(self) -> Iterator[T]:
        """Yield the keys from the minimum node onwards."""
        node = self.min_node
        while node:
            yield node.key
            node = node._next()

    def __next__(self):
        if not self.it:
            raise StopIteration
        res = self.it.key
        self.it += 1
        return res

    def __bool__(self):
        return self.node is not RedBlackTreeSet.NIL

    def __contains__(self, key: T):
        node = self.node
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __len__(self):
        return self.size

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"

仕様

class RedBlackTreeSet(a: Iterable[T] = [])[source]

Bases: OrderedSetInterface, Generic[T]

赤黒木です。集合です。

std::set も怖くない。

NIL = <titan_pylib.data_structures.red_black_tree.red_black_tree_set.RedBlackTreeSet._NILNode object>
class Node(key: T)[source]

Bases: object

RedBlackTreeSet で使用される節点クラスです。

双方向に進められます。 1 だけ進める場合、計算量は平均 O(1) 、最悪 O(logN) です。 k だけ進める場合は 1 だけ進める操作を k 回繰り返すため、合計 O(k + logN) です。

__add__(other: int)[source]

次の node を返します。存在しないときは None を返します。

__iadd__(other: int)[source]

node を次 node にします。存在しないときは None になります。

__isub__(other: int)[source]

node を前の node にします。存在しないときは None になります。

__sub__(other: int)[source]

前の node を返します。存在しないときは None を返します。

property count: int

保持している key の個数です。 1 を返します。

add(key: T) bool[source]
clear() None[source]
count(key: T) int[source]
discard(key: T) bool[source]
discard_iter(node: Node) None[source]

node を削除します。 置き換え先(右部分木の最小節点)の探索に最悪 \(O(\log N)\) かかりますが、削除後の再平衡化(回転と再彩色)は償却 \(O(1)\) です。

Parameters:

node (Node) – 削除する node です。

find(key: T) Node | None[source]

key が存在すれば key を指す Node を返します。存在しなければ None を返します。 \(O(\log{n})\) です。

ge(key: T) T | None[source]
ge_iter(key: T) Node | None[source]
get_max() T | None[source]
get_max_iter() Node | None[source]

最大値を指す Node を返します。空であれば None を返します。 \(O(1)\) です。

get_min() T | None[source]
get_min_iter() Node | None[source]

最小値を指す Node を返します。空であれば None を返します。 \(O(1)\) です。

gt(key: T) T | None[source]
gt_iter(key: T) Node | None[source]
le(key: T) T | None[source]
le_iter(key: T) Node | None[source]
lt(key: T) T | None[source]
lt_iter(key: T) Node | None[source]
pop_max() T[source]
pop_min() T[source]
remove(key: T) None[source]
tolist() list[T][source]