red_black_tree_multiset

ソースコード

from titan_pylib.data_structures.red_black_tree.red_black_tree_multiset import RedBlackTreeMultiset

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.red_black_tree.red_black_tree_multiset import RedBlackTreeMultiset
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from abc import ABC, abstractmethod
 12from typing import Iterable, Optional, Iterator, TypeVar, Generic
 13
# Element type of the ordered multiset interface; elements must support `<`.
T = TypeVar("T", bound=SupportsLessThan)
 15
 16
 17class OrderedMultisetInterface(ABC, Generic[T]):
 18
 19    @abstractmethod
 20    def __init__(self, a: Iterable[T]) -> None:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def add(self, key: T, cnt: int) -> None:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def discard(self, key: T, cnt: int) -> bool:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def discard_all(self, key: T) -> bool:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def count(self, key: T) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def remove(self, key: T, cnt: int) -> None:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def le(self, key: T) -> Optional[T]:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def lt(self, key: T) -> Optional[T]:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def ge(self, key: T) -> Optional[T]:
 53        raise NotImplementedError
 54
 55    @abstractmethod
 56    def gt(self, key: T) -> Optional[T]:
 57        raise NotImplementedError
 58
 59    @abstractmethod
 60    def get_max(self) -> Optional[T]:
 61        raise NotImplementedError
 62
 63    @abstractmethod
 64    def get_min(self) -> Optional[T]:
 65        raise NotImplementedError
 66
 67    @abstractmethod
 68    def pop_max(self) -> T:
 69        raise NotImplementedError
 70
 71    @abstractmethod
 72    def pop_min(self) -> T:
 73        raise NotImplementedError
 74
 75    @abstractmethod
 76    def clear(self) -> None:
 77        raise NotImplementedError
 78
 79    @abstractmethod
 80    def tolist(self) -> list[T]:
 81        raise NotImplementedError
 82
 83    @abstractmethod
 84    def __iter__(self) -> Iterator:
 85        raise NotImplementedError
 86
 87    @abstractmethod
 88    def __next__(self) -> T:
 89        raise NotImplementedError
 90
 91    @abstractmethod
 92    def __contains__(self, key: T) -> bool:
 93        raise NotImplementedError
 94
 95    @abstractmethod
 96    def __len__(self) -> int:
 97        raise NotImplementedError
 98
 99    @abstractmethod
100    def __bool__(self) -> bool:
101        raise NotImplementedError
102
103    @abstractmethod
104    def __str__(self) -> str:
105        raise NotImplementedError
106
107    @abstractmethod
108    def __repr__(self) -> str:
109        raise NotImplementedError
110# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
111#     BSTMultisetNodeBase,
112# )
113from typing import TypeVar, Generic, Optional
114
# Key type of the BST helpers (unconstrained here).
T = TypeVar("T")
# Node type of the BST helpers (unconstrained here).
Node = TypeVar("Node")
# TODO: specify the required node attributes key, val, left, right via a Protocol.
118
119
class BSTMultisetNodeBase(Generic[T, Node]):
    """Static helpers shared by multiset BSTs whose nodes carry a multiplicity.

    Nodes are expected to expose ``key``, ``val`` (the multiplicity of
    ``key``), ``left`` and ``right``. ``index``/``index_right`` also read
    ``valsize`` on child nodes (presumably the total multiplicity of that
    subtree -- confirm against the concrete node class).

    ``count``, ``le``, ``lt``, ``ge`` and ``gt`` stop only at ``None``
    children. The remaining helpers stop at any falsy node.
    """

    @staticmethod
    def count(node: Node, key: T) -> int:
        """Return the multiplicity of ``key`` in the subtree rooted at ``node``."""
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key under ``node``, or ``None`` if it is empty."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key under ``node``, or ``None`` if it is empty."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` occurs in the subtree rooted at ``node``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in order, each repeated ``val`` times (iterative in-order walk)."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in key order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``.

        Returns ``(keys, vals)``, where ``keys[i]`` is the value of the i-th
        run of equal adjacent elements and ``vals[i]`` is its length. An
        empty input yields ``([], [])``; the previous version raised
        ``IndexError`` on it.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None``."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None``."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None``."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None``."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) strictly less than ``key``."""
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # Everything in the left subtree plus this node is < key.
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) less than or equal to ``key``."""
        k = 0
        while node:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k
276from typing import Iterable, Optional, TypeVar, Generic
277
# Rebinds T for RedBlackTreeMultiset: keys must support `<`.
T = TypeVar("T", bound=SupportsLessThan)
279
280
class RedBlackTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset implemented as a red-black tree.

    Each distinct key occupies one node, which also stores its multiplicity
    (``cnt``). ``size`` is the sum of all multiplicities. ``min_node`` and
    ``max_node`` cache the nodes holding the smallest and largest keys; they
    are ``None`` when the multiset is empty.

    Missing children and the root's parent are represented by the falsy
    sentinel ``_NIL``, as in CLRS. Deletion assigns ``_NIL.par``
    temporarily. Because ``_NIL`` is a class attribute, that mutable state is
    shared by every instance of this class.
    """

    class Node:
        """Tree node: one distinct key, its multiplicity and its colour."""

        def __init__(self, key: T, cnt: int = 1):
            self.key: T = key
            self.left = RedBlackTreeMultiset._NIL
            self.right = RedBlackTreeMultiset._NIL
            self.par = RedBlackTreeMultiset._NIL
            self.col: int = 0  # 0 = black, 1 = red
            self.cnt: int = cnt

        @property
        def count(self) -> int:
            """Multiplicity of ``key`` stored in this node."""
            return self.cnt

        def _min(self) -> "RedBlackTreeMultiset.Node":
            """Return the leftmost node of the subtree rooted here."""
            now = self
            while now.left:
                now = now.left
            return now

        def _max(self) -> "RedBlackTreeMultiset.Node":
            """Return the rightmost node of the subtree rooted here."""
            now = self
            while now.right:
                now = now.right
            return now

        def _next(self) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the in-order successor (next larger key), or ``None``."""
            now = self
            pre = RedBlackTreeMultiset._NIL
            # Without a right subtree, the successor is the first ancestor
            # entered from its left child, so climb while coming from the right.
            flag = now.right is pre
            while now.right is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.left else now.right._min()

        def _prev(self) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the in-order predecessor (next smaller key), or ``None``."""
            now, pre = self, RedBlackTreeMultiset._NIL
            flag = now.left is pre
            while now.left is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.right else now.left._max()

        def __add__(self, other: int) -> Optional["RedBlackTreeMultiset.Node"]:
            """Step ``other`` distinct keys forward; ``None`` once past the end."""
            res = self
            for _ in range(other):
                res = res._next()
                if res is None:
                    break
            return res

        def __sub__(self, other: int) -> Optional["RedBlackTreeMultiset.Node"]:
            """Step ``other`` distinct keys backward; ``None`` once past the start."""
            res = self
            for _ in range(other):
                res = res._prev()
                if res is None:
                    break
            return res

        __iadd__ = __add__
        __isub__ = __sub__

        def __str__(self):
            if (not self.left) and (not self.right):
                return f"(key,col,par.key):{self.key, self.col, self.par.key}\n"
            return f"(key,col,par.key):{self.key, self.col, self.par.key},\n left:{self.left},\n right:{self.right}\n"

    class _NILNode:
        """Falsy black sentinel used for absent children and the root's parent."""

        key = None
        left = None
        right = None
        par = None
        col = 0
        cnt = 0

        def _min(self):
            return None

        def _max(self):
            return None

        def __bool__(self):
            return False

        def __str__(self):
            return "_NIL"

    _NIL = _NILNode()

    def __init__(self, a: Iterable[T] = []):
        """Create a multiset holding the elements of ``a``.

        The default list is only read, never mutated.
        """
        self.node = RedBlackTreeMultiset._NIL
        self.size = 0
        self.min_node = None
        self.max_node = None
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the empty tree with a balanced tree built from non-empty ``a``.

        Distinct keys are placed by recursive midpoint splitting. Colours are
        assigned by depth, using a pattern chosen from the parity of the
        number of distinct keys' bit length.
        """
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        make_node = RedBlackTreeMultiset.Node
        x, y = BSTMultisetNodeBase[T, RedBlackTreeMultiset.Node]._rle(a)
        flag = len(x).bit_length() & 1

        def build_range(lo: int, hi: int, d: int):
            # Build the subtree for distinct keys x[lo:hi], rooted at depth d.
            mid = (lo + hi) >> 1
            node = make_node(x[mid], y[mid])
            node.col = int((not flag and d & 1) or (flag and d > 1 and not d & 1))
            if lo != mid:
                node.left = build_range(lo, mid, d + 1)
                node.left.par = node
            if mid + 1 != hi:
                node.right = build_range(mid + 1, hi, d + 1)
                node.right.par = node
            return node

        self.node = build_range(0, len(x), 0)
        self.min_node = self.node._min()
        self.max_node = self.node._max()
        self.size = len(a)

    def _rotate_left(self, node: Node) -> None:
        """Left-rotate around ``node``; its right child takes its place."""
        u = node.right
        p = node.par
        node.right = u.left
        if u.left:
            u.left.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.left:
            p.left = u
        else:
            p.right = u
        u.left = node
        node.par = u

    def _rotate_right(self, node: Node) -> None:
        """Right-rotate around ``node``; its left child takes its place."""
        u = node.left
        p = node.par
        node.left = u.right
        if u.right:
            u.right.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.right:
            p.right = u
        else:
            p.left = u
        u.right = node
        node.par = u

    def _transplant(self, u: Node, v: Node) -> None:
        """Replace the subtree rooted at ``u`` with ``v``; ``v`` may be ``_NIL``."""
        if not u.par:
            self.node = v
        elif u is u.par.left:
            u.par.left = v
        else:
            u.par.right = v
        # Also written when v is _NIL; the delete fixup relies on it.
        v.par = u.par

    def _get_min(self, node: Node) -> Node:
        """Return the leftmost node under ``node``."""
        while node.left:
            node = node.left
        return node

    def _get_max(self, node: Node) -> Node:
        """Return the rightmost node under ``node``."""
        while node.right:
            node = node.right
        return node

    def add(self, key: T, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``; a non-positive ``cnt`` does nothing."""
        if cnt <= 0:
            return
        self.size += cnt
        if not self.node:
            node = RedBlackTreeMultiset.Node(key, cnt)
            self.node = node
            self.min_node = node
            self.max_node = node
            return
        pnode = RedBlackTreeMultiset._NIL
        node = self.node
        while node:
            pnode = node
            if key == node.key:
                node.cnt += cnt
                return
            node = node.left if key < node.key else node.right
        z = RedBlackTreeMultiset.Node(key, cnt)
        if key < self.min_node.key:
            self.min_node = z
        if key > self.max_node.key:
            self.max_node = z
        z.par = pnode
        if not pnode:
            self.node = z
        elif key < pnode.key:
            pnode.left = z
        else:
            pnode.right = z
        z.col = 1
        self._insert_fixup(z)

    def _insert_fixup(self, z: Node) -> None:
        """Restore red-black invariants after inserting red node ``z`` (CLRS RB-INSERT-FIXUP)."""
        while z.par.col == 1:
            g = z.par.par
            if z.par is g.left:
                y = g.right
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.right:
                        z = z.par
                        self._rotate_left(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_right(g)
                    break
            else:
                y = g.left
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.left:
                        z = z.par
                        self._rotate_right(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_left(g)
                    break
        self.node.col = 0

    def discard_iter(self, node: Node) -> None:
        """Remove ``node`` and every copy of its key from the tree.

        ``node`` must belong to this tree, for example a node returned by
        ``find`` or one of the ``*_iter`` methods.
        """
        assert isinstance(node, RedBlackTreeMultiset.Node)
        self.size -= node.cnt
        # Refresh the cached extremes while the neighbours are still reachable.
        if node.key == self.min_node.key:
            self.min_node = node._next()
        if node.key == self.max_node.key:
            self.max_node = node._prev()
        y = node
        y_col = y.col
        if not node.left:
            x = node.right
            self._transplant(node, node.right)
        elif not node.right:
            x = node.left
            self._transplant(node, node.left)
        else:
            # Two children: the successor y takes node's place.
            y = self._get_min(node.right)
            y_col = y.col
            x = y.right
            if y.par is node:
                x.par = y  # x may be _NIL; the fixup reads its parent
            else:
                self._transplant(y, y.right)
                y.right = node.right
                y.right.par = y
            self._transplant(node, y)
            y.left = node.left
            y.left.par = y
            y.col = node.col
        if y_col:
            # Removing or moving a red node keeps every invariant intact.
            return
        self._delete_fixup(x)

    def _delete_fixup(self, x: Node) -> None:
        """Restore red-black invariants after a black removal; ``x`` carries the extra black (CLRS RB-DELETE-FIXUP)."""
        while x is not self.node and not x.col:
            y = x.par
            if x is y.left:
                w = y.right
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_left(y)
                    w = y.right
                if not (w.left.col or w.right.col):
                    w.col = 1
                    x = y
                else:
                    if not w.right.col:
                        w.left.col = 0
                        w.col = 1
                        self._rotate_right(w)
                        w = y.right
                    w.col = y.col
                    y.col = 0
                    w.right.col = 0
                    self._rotate_left(y)
                    x = self.node
            else:
                w = y.left
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_right(y)
                    w = y.left
                if not (w.right.col or w.left.col):
                    w.col = 1
                    x = y
                else:
                    if not w.left.col:
                        w.right.col = 0
                        w.col = 1
                        self._rotate_left(w)
                        w = y.left
                    w.col = y.col
                    y.col = 0
                    w.left.col = 0
                    self._rotate_right(y)
                    x = self.node
        x.col = 0

    def discard(self, key: T, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``.

        Returns False if ``key`` is absent, and True otherwise.
        """
        assert cnt >= 0
        node = self.find(key)
        if node is None:
            return False
        if node.cnt > cnt:
            node.cnt -= cnt
            self.size -= cnt
            return True
        self.discard_iter(node)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``. Returns False if ``key`` is absent."""
        node = self.find(key)
        if node is None:
            return False
        self.discard_iter(node)
        return True

    def remove(self, key: T, cnt: int = 1) -> None:
        """Remove ``cnt`` copies of ``key``.

        Raises ``KeyError`` if fewer than ``cnt`` copies are present.
        """
        c = self.count(key)
        if c < cnt:
            raise KeyError(key)
        self.discard(key, cnt)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``, or 0 if it is absent."""
        node = self.find(key)
        return 0 if node is None else node.cnt

    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` when empty."""
        return None if self.max_node is None else self.max_node.key

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` when empty."""
        return None if self.min_node is None else self.min_node.key

    def get_max_iter(self) -> Optional[Node]:
        """Return the node holding the largest key, or ``None``."""
        return self.max_node

    def get_min_iter(self) -> Optional[Node]:
        """Return the node holding the smallest key, or ``None``."""
        return self.min_node

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None``."""
        res = self.le_iter(key)
        return res.key if res else None

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None``."""
        res = self.lt_iter(key)
        return res.key if res else None

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None``."""
        res = self.ge_iter(key)
        return res.key if res else None

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None``."""
        res = self.gt_iter(key)
        return res.key if res else None

    def le_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``<= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def lt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``< key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def ge_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``>= key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def gt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``> key``, or ``None``."""
        res, node = None, self.node
        while node:
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def find(self, key: T) -> Optional[Node]:
        """Return the node holding ``key``, or ``None`` if it is absent."""
        node = self.node
        while node:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def tolist(self) -> list[T]:
        """Return all elements in sorted order, repeating each key ``cnt`` times."""
        node = self.node
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.extend([node.key] * node.cnt)
                node = node.right
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it (asserts non-empty)."""
        assert self.node, "IndexError: pop_max() from empty RedBlackTreeMultiset"
        node = self.max_node
        if node.cnt == 1:
            self.discard_iter(node)
        else:
            node.cnt -= 1
            self.size -= 1
        return node.key

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it (asserts non-empty)."""
        assert self.node, "IndexError: pop_min() from empty RedBlackTreeMultiset"
        node = self.min_node
        if node.cnt == 1:
            self.discard_iter(node)
        else:
            node.cnt -= 1
            self.size -= 1
        return node.key

    def clear(self) -> None:
        """Remove every element."""
        self.node = RedBlackTreeMultiset._NIL
        self.size = 0
        self.min_node = None
        self.max_node = None

    def __iter__(self):
        """Start an ascending iteration that yields each key ``cnt`` times.

        The cursor is stored on the instance, so nested iterations over the
        same multiset interfere with each other.
        """
        self.it = self.min_node
        self.cnt = 0
        return self

    def __next__(self):
        # Guard first: an empty tree (or an exhausted cursor) has it = None.
        # The previous version raised AttributeError here.
        it = self.it
        if it is None:
            raise StopIteration
        if self.cnt >= it.cnt:
            it = it._next()
            self.it = it
            self.cnt = 0
            if it is None:
                raise StopIteration
        self.cnt += 1
        return it.key

    def __bool__(self):
        return self.node is not RedBlackTreeMultiset._NIL

    def __contains__(self, key: T):
        return self.find(key) is not None

    def __len__(self):
        return self.size

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return "RedBlackTreeMultiset(" + str(self.tolist()) + ")"

仕様

class RedBlackTreeMultiset(a: Iterable[T] = [])[source]

Bases: OrderedMultisetInterface, Generic[T]

class Node(key: T, cnt: int = 1)[source]

Bases: object

property count: int
add(key: T, cnt: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, cnt: int = 1) → bool [source]
discard_all(key: T) → bool [source]
discard_iter(node: Node) → None [source]
find(key: T) → Node | None [source]
ge(key: T) → T | None [source]
ge_iter(key: T) → Node | None [source]
get_max() → T | None [source]
get_max_iter() → Node | None [source]
get_min() → T | None [source]
get_min_iter() → Node | None [source]
gt(key: T) → T | None [source]
gt_iter(key: T) → Node | None [source]
le(key: T) → T | None [source]
le_iter(key: T) → Node | None [source]
lt(key: T) → T | None [source]
lt_iter(key: T) → Node | None [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, cnt: int = 1) → None [source]
tolist() → list[T] [source]