avl_tree_multiset3

ソースコード

from titan_pylib.data_structures.avl_tree.avl_tree_multiset3 import AVLTreeMultiset3

View on GitHub

展開済みコード

  # from titan_pylib.data_structures.avl_tree.avl_tree_multiset3 import AVLTreeMultiset3
# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, Protocol, TypeVar


class SupportsLessThan(Protocol):
    """Structural type: any object that can be compared with ``<``."""

    def __lt__(self, other) -> bool: ...


T = TypeVar("T", bound=SupportsLessThan)


class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract API shared by the ordered-multiset data structures."""

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        raise NotImplementedError

    @abstractmethod
    def count(self, key: T) -> int:
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        raise NotImplementedError

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator:
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError


class AVLTreeMultiset3(OrderedMultisetInterface, Generic[T]):
    """Multiset implemented as an AVL tree built from ``class Node()`` objects.

    Each distinct key is stored once, together with its multiplicity
    ``val``. Every node also caches:

    * ``size``    -- number of nodes (distinct keys) in its subtree,
    * ``valsize`` -- total multiplicity (number of elements) in its subtree,
    * ``balance`` -- height(left) - height(right), kept within [-1, 1].

    Several internal routines take a ``path`` (the ancestors of a node, from
    the root down) together with an int ``di`` that encodes the direction
    taken at each ancestor. The lowest bit is the deepest step; 1 means the
    walk went left and 0 means it went right.
    """

    class Node:
        """Tree node holding one distinct key and its multiplicity."""

        __slots__ = ("key", "val", "valsize", "size", "left", "right", "balance")

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.valsize: int = val  # total multiplicity in this subtree
            self.size: int = 1  # number of nodes in this subtree
            self.left: Optional["AVLTreeMultiset3.Node"] = None
            self.right: Optional["AVLTreeMultiset3.Node"] = None
            self.balance: int = 0  # height(left) - height(right)

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset holding every element of ``a`` (duplicates kept)."""
        self.node: Optional["AVLTreeMultiset3.Node"] = None
        self._build(a)

    def _rle(self, L: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a sorted list into (distinct keys, counts)."""
        keys: list[T] = []
        counts: list[int] = []
        for a in L:
            if keys and a == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(a)
                counts.append(1)
        return keys, counts

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with a perfectly balanced one holding ``a``.

        If ``a`` is empty, the tree is left untouched.
        """
        Node = AVLTreeMultiset3.Node

        def rec(lo: int, hi: int) -> tuple[Node, int]:
            # Build the subtree for keys[lo:hi]; return (root, height).
            mid = (lo + hi) >> 1
            node = Node(keys[mid], counts[mid])
            height = 0
            if lo != mid:
                left, hl = rec(lo, mid)
                node.left = left
                node.size += left.size
                node.valsize += left.valsize
                node.balance = hl
                height = hl
            if mid + 1 != hi:
                right, hr = rec(mid + 1, hi)
                node.right = right
                node.size += right.size
                node.valsize += right.valsize
                node.balance -= hr
                height = max(height, hr)
            return node, height + 1

        sorted_a = sorted(a)
        if not sorted_a:
            return
        keys, counts = self._rle(sorted_a)
        self.node = rec(0, len(keys))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Single rotation for a left-heavy ``node``; return the new subtree root."""
        u = node.left
        u.size = node.size
        u.valsize = node.valsize
        if u.left is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.left.size + 1
            node.valsize -= u.left.valsize + u.val
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 (only possible after a deletion).
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Single rotation for a right-heavy ``node``; return the new subtree root."""
        u = node.right
        u.size = node.size
        u.valsize = node.valsize
        if u.right is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.right.size + 1
            node.valsize -= u.right.valsize + u.val
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 (only possible after a deletion).
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Fix the balances after a double rotation that made ``node`` the root."""
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Left-right double rotation; return the new subtree root."""
        B = node.left
        E = B.right
        E.size = node.size
        E.valsize = node.valsize
        if E.right is None:
            node.size -= B.size
            node.valsize -= B.valsize
            B.size -= 1
            B.valsize -= E.val
        else:
            node.size -= B.size - E.right.size
            node.valsize -= B.valsize - E.right.valsize
            B.size -= E.right.size + 1
            B.valsize -= E.right.valsize + E.val
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Right-left double rotation; return the new subtree root."""
        C = node.right
        D = C.left
        D.size = node.size
        D.valsize = node.valsize
        if D.left is None:
            node.size -= C.size
            node.valsize -= C.valsize
            C.size -= 1
            C.valsize -= D.val
        else:
            node.size -= C.size - D.left.size
            node.valsize -= C.valsize - D.left.valsize
            C.size -= D.left.size + 1
            C.valsize -= D.left.valsize + D.val
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return (key, multiplicity) of the k-th element, counting duplicates.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("index out of range")
        node = self.node
        while True:
            # Elements before node.key in this subtree, plus node.key's own copies.
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                return node.key, node.val
            if k < t:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return (key, multiplicity) of the k-th distinct key.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: ``k`` is out of range.
        """
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("index out of range")
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key, node.val
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1

    def _discard(self, node: Node, path: list[Node], di: int) -> bool:
        """Unlink ``node`` from the tree and rebalance.

        ``node.val`` must already be 1. ``path`` and ``di`` describe the walk
        from the root down to ``node`` (see the class docstring); ``path`` is
        consumed. Always returns True.
        """
        # fdi has a 1 bit for each ancestor lying strictly between ``node`` and
        # its in-order predecessor ``lmax``. Such ancestors lose all ``lmax_val``
        # copies; every other ancestor loses exactly one element.
        fdi = 0
        lmax_val = 0  # only read while an fdi bit is set
        if node.left is not None and node.right is not None:
            # Two children: copy the predecessor's payload into ``node`` and
            # unlink the predecessor (which has no right child) instead.
            path.append(node)
            di = di << 1 | 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                fdi = fdi << 1 | 1
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.right if node.left is None else node.left
        if not path:
            self.node = cnode
            return True
        if di & 1:
            path[-1].left = cnode
        else:
            path[-1].right = cnode
        # Walk upward while the subtree height keeps shrinking, rotating as needed.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Height unchanged: stop rebalancing.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    break
        # The remaining ancestors only need their cached counts updated.
        while path:
            pnode = path.pop()
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def _remove_one(self, node: Node, path: list[Node], di: int) -> None:
        """Remove a single copy of ``node.key``; ``path``/``di`` as in ``_discard``."""
        if node.val == 1:
            self._discard(node, path, di)
        else:
            node.val -= 1
            node.valsize -= 1
            for p in path:
                p.valsize -= 1

    def _iter_nodes(self, reverse: bool = False) -> Iterator[Node]:
        """Yield the nodes in ascending key order (descending if ``reverse``)."""
        stack = []
        node = self.node
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.right if reverse else node.left
            else:
                node = stack.pop()
                yield node
                node = node.left if reverse else node.right

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current count, every copy is removed and
        the key disappears from the tree. A non-positive ``val`` changes
        nothing.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        path = []
        di = 0
        node = self.node
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            if key < node.key:
                di = di << 1 | 1
                node = node.left
            else:
                di <<= 1
                node = node.right
        else:
            return False
        if val <= 0:
            return True
        if val < node.val:
            node.val -= val
            node.valsize -= val
            for p in path:
                p.valsize -= val
            return True
        # Every copy goes away. First drop the count to 1, because _discard
        # subtracts exactly one element from each ancestor's valsize.
        extra = node.val - 1
        if extra:
            node.val = 1
            node.valsize -= extra
            for p in path:
                p.valsize -= extra
        self._discard(node, path, di)
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key``, if present."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard``, but raise ``KeyError`` if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``. A non-positive ``val`` is a no-op."""
        if val <= 0:
            return
        if self.node is None:
            self.node = AVLTreeMultiset3.Node(key, val)
            return
        pnode = self.node
        di = 0
        path = []
        while pnode is not None:
            if key == pnode.key:
                # Existing key: only the multiplicities change, the shape does not.
                pnode.val += val
                pnode.valsize += val
                for p in path:
                    p.valsize += val
                return
            path.append(pnode)
            if key < pnode.key:
                di = di << 1 | 1
                pnode = pnode.left
            else:
                di <<= 1
                pnode = pnode.right
        leaf = AVLTreeMultiset3.Node(key, val)
        if di & 1:
            path[-1].left = leaf
        else:
            path[-1].right = leaf
        new_node = None
        # Walk upward while the subtree height grows; at most one rotation is needed.
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.valsize += val
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
                break
            if pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.node = new_node
        for p in path:
            p.size += 1
            p.valsize += val

    def count(self, key: T) -> int:
        """Return the number of copies of ``key`` (0 if absent)."""
        node = self.node
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key <= ``key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                return node.key
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key < ``key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key >= ``key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                return node.key
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key > ``key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements (counting duplicates) < ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements (counting duplicates) <= ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys < ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # The left subtree plus this node: each node is one distinct key.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys <= ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.node is None:
            return
        node = self.node
        while node.left is not None:
            node = node.left
        return node.key

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.node is None:
            return
        node = self.node
        while node.right is not None:
            node = node.right
        return node.key

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th element, counting duplicates.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: the multiset is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.node
        path = []
        if k == n - 1:
            # Fast path for the maximum: walk straight down to the right.
            while node.right is not None:
                path.append(node)
                node = node.right
            x = node.key
            self._remove_one(node, path, 0)
            return x
        di = 0
        while True:
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                x = node.key
                break
            path.append(node)
            if k < t:
                di = di << 1 | 1
                node = node.left
            else:
                di <<= 1
                node = node.right
                k -= t
        self._remove_one(node, path, di)
        return x

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key."""
        assert self
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key."""
        assert self
        node = self.node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        x = node.key
        # Every step taken was to the left, so all direction bits are 1.
        self._remove_one(node, path, (1 << len(path)) - 1)
        return x

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield (key, multiplicity) pairs in ascending key order."""
        for node in self._iter_nodes():
            yield node.key, node.val

    def keys(self) -> Iterator[T]:
        """Yield the distinct keys in ascending order."""
        for node in self._iter_nodes():
            yield node.key

    def values(self) -> Iterator[int]:
        """Yield the multiplicities in ascending key order."""
        for node in self._iter_nodes():
            yield node.val

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove every element."""
        self.node = None

    def get_elm(self, k: int) -> T:
        """Return the k-th distinct key (negative ``k`` counts from the end)."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """Return all elements, with duplicates, in ascending order."""
        a: list[T] = []
        for node in self._iter_nodes():
            a.extend([node.key] * node.val)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return the (key, multiplicity) pairs in ascending key order."""
        return [(node.key, node.val) for node in self._iter_nodes()]

    def __getitem__(self, k: int) -> T:
        return self._kth_elm(k)[0]

    def __contains__(self, key: T):
        node = self.node
        while node is not None:
            if node.key == key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __iter__(self):
        # NOTE: the cursor lives on the instance, so nested iteration over the
        # same multiset is not supported.
        self.__iter = 0
        return self

    def __next__(self):
        if self.__iter == len(self):
            raise StopIteration
        key, _ = self._kth_elm(self.__iter)
        self.__iter += 1
        return key

    def __reversed__(self):
        for node in self._iter_nodes(reverse=True):
            for _ in range(node.val):
                yield node.key

    def __len__(self):
        return 0 if self.node is None else self.node.valsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeMultiset3({self.tolist()})"

仕様

class AVLTreeMultiset3(a: Iterable[T] = [])[source]

Bases: OrderedMultisetInterface, Generic[T]

多重集合としての AVL 木です。 class Node() を用いています。

class Node(key: T, val: int)[source]

Bases: object

add(key: T, val: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, val: int = 1) → bool [source]
discard_all(key: T) → None [source]
ge(key: T) → T | None [source]
get_elm(k: int) → T [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_keys(key: T) → int [source]
index_right(key: T) → int [source]
index_right_keys(key: T) → int [source]
items() → Iterator[tuple[T, int]] [source]
keys() → Iterator[T] [source]
le(key: T) → T | None [source]
len_elm() → int [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, val: int = 1) → None [source]
show() → None [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]
values() → Iterator[int] [source]