wbt_multiset

ソースコード

from titan_pylib.data_structures.wbt.wbt_multiset import WBTMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.wbt.wbt_multiset import WBTMultiset
  2# from titan_pylib.data_structures.wbt._wbt_multiset_node import _WBTMultisetNode
  3# from titan_pylib.data_structures.wbt._wbt_node_base import _WBTNodeBase
  4from typing import Generic, TypeVar, Optional, Final
  5
  6T = TypeVar("T")
  7
  8
  9class _WBTNodeBase(Generic[T]):
 10    """WBTノードのベースクラス
 11    size, par, left, rightをもつ
 12    """
 13
 14    __slots__ = "_size", "_par", "_left", "_right"
 15    DELTA: Final[int] = 3
 16    GAMMA: Final[int] = 2
 17
 18    def __init__(self) -> None:
 19        self._size: int = 1
 20        self._par: Optional[_WBTNodeBase[T]] = None
 21        self._left: Optional[_WBTNodeBase[T]] = None
 22        self._right: Optional[_WBTNodeBase[T]] = None
 23
 24    def _rebalance(self) -> "_WBTNodeBase[T]":
 25        """根までを再構築する
 26
 27        Returns:
 28            _WBTNodeBase[T]: 根ノード
 29        """
 30        node = self
 31        while True:
 32            node._update()
 33            wl, wr = node._weight_left(), node._weight_right()
 34            if wl * _WBTNodeBase.DELTA < wr:
 35                if (
 36                    node._right._weight_left()
 37                    >= node._right._weight_right() * _WBTNodeBase.GAMMA
 38                ):
 39                    node._right = node._right._rotate_right()
 40                node = node._rotate_left()
 41            elif wr * _WBTNodeBase.DELTA < wl:
 42                if (
 43                    node._left._weight_right()
 44                    >= node._left._weight_left() * _WBTNodeBase.GAMMA
 45                ):
 46                    node._left = node._left._rotate_left()
 47                node = node._rotate_right()
 48            if not node._par:
 49                return node
 50            node = node._par
 51
 52    def _copy_from(self, other: "_WBTNodeBase[T]") -> None:
 53        self._size = other._size
 54        if other._left:
 55            other._left._par = self
 56        if other._right:
 57            other._right._par = self
 58        if other._par:
 59            if other._par._left is other:
 60                other._par._left = self
 61            else:
 62                other._par._right = self
 63        self._par = other._par
 64        self._left = other._left
 65        self._right = other._right
 66
 67    def _weight_left(self) -> int:
 68        return self._left._size + 1 if self._left else 1
 69
 70    def _weight_right(self) -> int:
 71        return self._right._size + 1 if self._right else 1
 72
 73    def _update(self) -> None:
 74        self._size = (
 75            1
 76            + (self._left._size if self._left else 0)
 77            + (self._right._size if self._right else 0)
 78        )
 79
 80    def _rotate_right(self) -> "_WBTNodeBase[T]":
 81        u = self._left
 82        u._size = self._size
 83        self._size -= u._left._size + 1 if u._left else 1
 84        u._par = self._par
 85        self._left = u._right
 86        if u._right:
 87            u._right._par = self
 88        u._right = self
 89        self._par = u
 90        if u._par:
 91            if u._par._left is self:
 92                u._par._left = u
 93            else:
 94                u._par._right = u
 95        return u
 96
 97    def _rotate_left(self) -> "_WBTNodeBase[T]":
 98        u = self._right
 99        u._size = self._size
100        self._size -= u._right._size + 1 if u._right else 1
101        u._par = self._par
102        self._right = u._left
103        if u._left:
104            u._left._par = self
105        u._left = self
106        self._par = u
107        if u._par:
108            if u._par._left is self:
109                u._par._left = u
110            else:
111                u._par._right = u
112        return u
113
114    def _balance_check(self) -> None:
115        if not self._weight_left() * _WBTNodeBase.DELTA >= self._weight_right():
116            print(self._weight_left(), self._weight_right(), flush=True)
117            print(self)
118            assert False, f"self._weight_left() * DELTA >= self._weight_right()"
119        if not self._weight_right() * _WBTNodeBase.DELTA >= self._weight_left():
120            print(self._weight_left(), self._weight_right(), flush=True)
121            print(self)
122            assert False, f"self._weight_right() * DELTA >= self._weight_left()"
123
124    def _min(self) -> "_WBTNodeBase[T]":
125        node = self
126        while node._left:
127            node = node._left
128        return node
129
130    def _max(self) -> "_WBTNodeBase[T]":
131        node = self
132        while node._right:
133            node = node._right
134        return node
135
136    def _next(self) -> Optional["_WBTNodeBase[T]"]:
137        if self._right:
138            return self._right._min()
139        now, pre = self, None
140        while now and now._right is pre:
141            now, pre = now._par, now
142        return now
143
144    def _prev(self) -> Optional["_WBTNodeBase[T]"]:
145        if self._left:
146            return self._left._max()
147        now, pre = self, None
148        while now and now._left is pre:
149            now, pre = now._par, now
150        return now
151
152    def __add__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
153        node = self
154        for _ in range(other):
155            node = node._next()
156        return node
157
158    def __sub__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
159        node = self
160        for _ in range(other):
161            node = node._prev()
162        return node
163
164    __iadd__ = __add__
165    __isub__ = __sub__
166
167    def __str__(self) -> str:
168        # if self._left is None and self._right is None:
169        #     return f"key:{self._key, self._size}\n"
170        # return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"
171        return str(self._key)
172
173    __repr__ = __str__
from typing import TypeVar, Optional

T = TypeVar("T")


class _WBTMultisetNode(_WBTNodeBase[T]):
    """WBT node for WBTMultiset: one distinct key plus its multiplicity.

    Besides the inherited ``_size`` (number of nodes in the subtree), it
    maintains ``_count_size``: the sum of ``_count`` over the subtree, i.e.
    the number of elements in the subtree counted with multiplicity.
    """

    # Only the new fields. The base class already declares "_size", "_par",
    # "_left" and "_right"; redeclaring them here would shadow the base slots
    # (undefined behavior per the data-model docs) and waste memory per node.
    __slots__ = "_key", "_count", "_count_size"

    def __init__(self, key: T, count: int) -> None:
        super().__init__()
        self._key: T = key
        self._count: int = count
        self._count_size: int = count
        # Narrow the inherited link types (annotations only, no assignment).
        self._par: Optional[_WBTMultisetNode[T]]
        self._left: Optional[_WBTMultisetNode[T]]
        self._right: Optional[_WBTMultisetNode[T]]

    @property
    def key(self) -> T:
        """The key stored in this node."""
        return self._key

    @property
    def count(self) -> int:
        """Multiplicity of ``key``."""
        return self._count

    def _update(self) -> None:
        """Recompute ``_size`` and ``_count_size`` from the children."""
        super()._update()
        self._count_size = (
            self._count
            + (self._left._count_size if self._left else 0)
            + (self._right._count_size if self._right else 0)
        )

    def _rotate_right(self) -> "_WBTMultisetNode[T]":
        """Rotate right around this node, keeping both ``_size`` and
        ``_count_size`` correct in O(1). Returns the new subtree root."""
        u = self._left
        u._size = self._size
        u._count_size = self._count_size
        # This node loses u and u's left subtree.
        self._size -= u._left._size + 1 if u._left else 1
        self._count_size -= u._left._count_size + u._count if u._left else u._count
        u._par = self._par
        self._left = u._right
        if u._right:
            u._right._par = self
        u._right = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _rotate_left(self) -> "_WBTMultisetNode[T]":
        """Rotate left around this node (mirror of ``_rotate_right``).
        Returns the new subtree root."""
        u = self._right
        u._size = self._size
        u._count_size = self._count_size
        # This node loses u and u's right subtree.
        self._size -= u._right._size + 1 if u._right else 1
        self._count_size -= u._right._count_size + u._count if u._right else u._count
        u._par = self._par
        self._right = u._left
        if u._left:
            u._left._par = self
        u._left = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _copy_from(self, other: "_WBTMultisetNode[T]") -> None:
        """Take ``other``'s place in the tree, including its count and
        count aggregate (the key is not copied)."""
        super()._copy_from(other)
        self._count = other._count
        self._count_size = other._count_size
from typing import Generic, TypeVar, Optional, Iterable, Iterator

T = TypeVar("T")


class WBTMultiset(Generic[T]):
    """Sorted multiset backed by a weight-balanced tree.

    Every tree node holds one distinct key and its multiplicity, so equal
    elements share a node. Positional queries (``find_order``, ``index``,
    ``__getitem__``, ``pop``, ``len``) count elements with multiplicity
    through the per-node ``_count_size`` aggregate.
    """

    __slots__ = "_root", "_min", "_max"

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset holding the elements of ``a``.

        Args:
            a: initial elements, in any order. The default is an immutable
                empty tuple instead of a shared mutable list.
        """
        self._root: Optional[_WBTMultisetNode[T]] = None
        # Cached leftmost / rightmost nodes, used by get_min / get_max.
        self._min: Optional[_WBTMultisetNode[T]] = None
        self._max: Optional[_WBTMultisetNode[T]] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Build a balanced tree from ``a`` (sorting it first if needed)."""

        def build(
            lo: int, hi: int, pnode: Optional[_WBTMultisetNode[T]] = None
        ) -> Optional[_WBTMultisetNode[T]]:
            # Subtree for keys[lo:hi], rooted at the middle key.
            if lo == hi:
                return None
            mid = (lo + hi) // 2
            node = _WBTMultisetNode(keys[mid], vals[mid])
            node._left = build(lo, mid, node)
            node._right = build(mid + 1, hi, node)
            node._par = pnode
            node._update()
            return node

        a = list(a)
        if not a:
            return
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a.sort()
        # Run-length encode the sorted list into distinct keys and counts.
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        self._root = build(0, len(keys))
        self._max = self._root._max()
        self._min = self._root._min()

    def add(self, key: T, count: int = 1) -> None:
        """Add ``count`` copies of ``key``.

        ``count`` is presumably meant to be positive; it is not validated.
        """
        if not self._root:
            self._root = _WBTMultisetNode(key, count)
            self._max = self._root
            self._min = self._root
            return
        pnode = None
        node = self._root
        while node:
            # Every node on the search path gains `count` elements.
            node._count_size += count
            if key == node._key:
                node._count += count
                return
            pnode = node
            node = node._left if key < node._key else node._right
        if key < pnode._key:
            pnode._left = _WBTMultisetNode(key, count)
            if key < self._min._key:
                self._min = pnode._left
            pnode._left._par = pnode
        else:
            pnode._right = _WBTMultisetNode(key, count)
            if key > self._max._key:
                self._max = pnode._right
            pnode._right._par = pnode
        self._root = pnode._rebalance()

    def find_key(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node holding ``key``, or None if absent."""
        node = self._root
        while node:
            if key == node._key:
                return node
            node = node._left if key < node._key else node._right
        return None

    def find_order(self, k: int) -> _WBTMultisetNode[T]:
        """Return the node holding the k-th smallest element (0-indexed,
        counted with multiplicity). Negative ``k`` counts from the end."""
        n = len(self)
        assert -n <= k < n, f"IndexError: {self.__class__.__name__}[{k}], len={n}"
        if k < 0:
            k += n
        node = self._root
        while True:
            left_size = node._left._count_size if node._left else 0
            if k < left_size:
                node = node._left
            elif k < left_size + node._count:
                return node
            else:
                k -= left_size + node._count
                node = node._right

    def count(self, key: T) -> int:
        """Multiplicity of ``key`` (0 if absent)."""
        node = self.find_key(key)
        return node.count if node is not None else 0

    def remove_iter(self, node: _WBTMultisetNode[T]) -> None:
        """Unlink ``node`` (all copies of its key) from the tree and rebalance.

        Args:
            node: a node currently stored in this multiset.
        """
        # Keep the cached extremes valid before the node disappears.
        if node is self._min:
            self._min = self._min._next()
        if node is self._max:
            self._max = self._max._prev()
        delnode = node
        pnode, mnode = node._par, None
        if node._left and node._right:
            # Two children: physically remove the in-order predecessor
            # (rightmost node of the left subtree) and later move it into
            # delnode's position. delnode temporarily takes the predecessor's
            # count so that aggregates recomputed while rebalancing already
            # reflect the final contents.
            pnode, mnode = node, node._left
            while mnode._right:
                pnode, mnode = mnode, mnode._right
            node._count = mnode._count
            node = mnode
        # `node` now has at most one child: splice it out.
        cnode = node._right if not node._left else node._left
        if cnode:
            cnode._par = pnode
        if pnode:
            if pnode._left is node:
                pnode._left = cnode
            else:
                pnode._right = cnode
            self._root = pnode._rebalance()
        else:
            self._root = cnode
        if mnode:
            # Put the predecessor node where delnode sits (links, size, count).
            if self._root is delnode:
                self._root = mnode
            mnode._copy_from(delnode)

    def _decrement(self, node: _WBTMultisetNode[T], count: int) -> None:
        """Remove ``count`` copies of ``node``'s key; unlink the node when no
        copies remain, otherwise update every ancestor's ``_count_size``."""
        if node._count <= count:
            self.remove_iter(node)
            return
        node._count -= count
        while node:
            node._count_size -= count
            node = node._par

    def remove(self, key: T, count: int = 1) -> None:
        """Remove ``count`` copies of ``key`` (all of them if fewer exist).

        Raises:
            AssertionError: if ``key`` is not present.
        """
        node = self.find_key(key)
        assert node, f"KeyError: {key} is not found."
        self._decrement(node, count)

    def discard(self, key: T, count: int = 1) -> bool:
        """Like ``remove`` but returns False instead of failing when ``key``
        is absent; returns True otherwise."""
        node = self.find_key(key)
        if node is None:
            return False
        self._decrement(node, count)
        return True

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th smallest element (default: the
        largest). Negative ``k`` counts from the end."""
        node = self.find_order(k)
        key = node._key
        self._decrement(node, 1)
        return key

    def le_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the largest key <= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def lt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the largest key < ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def ge_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the smallest key >= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def gt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the smallest key > ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def le(self, key: T) -> Optional[T]:
        """Largest stored key <= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Largest stored key < ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Smallest stored key >= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Smallest stored key > ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def index(self, key: T) -> int:
        """Number of elements (with multiplicity) strictly less than ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size if node._left else 0
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def index_right(self, key: T) -> int:
        """Number of elements (with multiplicity) less than or equal to ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size + node._count if node._left else node._count
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def tolist(self) -> list[T]:
        """All elements in ascending order, with repetitions."""
        return list(self)

    def get_min(self) -> T:
        """Smallest element (the multiset must be non-empty)."""
        assert self._min
        return self._min._key

    def get_max(self) -> T:
        """Largest element (the multiset must be non-empty)."""
        assert self._max
        return self._max._key

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        assert self._min
        key = self._min._key
        self._decrement(self._min, 1)
        return key

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        assert self._max
        key = self._max._key
        self._decrement(self._max, 1)
        return key

    def check(self) -> None:
        """Debug helper: assert BST order, positive counts, the cached
        ``_size`` / ``_count_size`` aggregates, the weight balance of every
        node, and the cached min / max nodes."""
        if self._root is None:
            assert self._min is None and self._max is None
            return

        def dfs(node: _WBTMultisetNode[T]) -> tuple[int, int, int]:
            # Returns (_size, _count_size, height) of node's subtree.
            h = 0
            s = 1
            assert node.count > 0
            cs = node.count
            if node._left:
                assert node._key > node._left._key
                ls, lcs, lh = dfs(node._left)
                s += ls
                cs += lcs
                h = max(h, lh)
            if node._right:
                assert node._key < node._right._key
                rs, rcs, rh = dfs(node._right)
                s += rs
                cs += rcs
                h = max(h, rh)
            assert node._size == s
            assert node._count_size == cs
            node._balance_check()
            return s, cs, h + 1

        dfs(self._root)
        assert self._min is self._root._min()
        assert self._max is self._root._max()

    def __contains__(self, key: T) -> bool:
        return self.find_key(key) is not None

    def __getitem__(self, k: int) -> T:
        """k-th smallest element (counted with multiplicity); negative ``k``
        counts from the end."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        if k == 0:
            return self.get_min()
        if k == len(self) - 1:
            return self.get_max()
        return self.find_order(k)._key

    def __delitem__(self, k: int) -> None:
        """Remove one copy of the k-th smallest element; negative ``k``
        counts from the end."""
        self._decrement(self.find_order(k), 1)

    def __len__(self) -> int:
        """Number of elements, counted with multiplicity."""
        return self._root._count_size if self._root else 0

    def __iter__(self) -> Iterator[T]:
        # Iterative in-order traversal; each key is yielded `count` times.
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._right

    def __reversed__(self) -> Iterator[T]:
        # Reverse in-order traversal; each key is yielded `count` times.
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._right
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._left

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + "["
            + ", ".join(map(str, self.tolist()))
            + "])"
        )

仕様

class WBTMultiset(a: Iterable[T] = [])

Bases: Generic[T]

add(key: T, count: int = 1) → None
check() → None
count(key: T) → int
discard(key: T, count: int = 1) → bool
find_key(key: T) → _WBTMultisetNode[T] | None
find_order(k: int) → _WBTMultisetNode[T]
ge(key: T) → T | None
ge_iter(key: T) → _WBTMultisetNode[T] | None
get_max() → T
get_min() → T
gt(key: T) → T | None
gt_iter(key: T) → _WBTMultisetNode[T] | None
index(key: T) → int
index_right(key: T) → int
le(key: T) → T | None
le_iter(key: T) → _WBTMultisetNode[T] | None
lt(key: T) → T | None
lt_iter(key: T) → _WBTMultisetNode[T] | None
pop(k: int = -1) → T
pop_max() → T
pop_min() → T
remove(key: T, count: int = 1) → None
remove_iter(node: _WBTMultisetNode[T]) → None
tolist() → list[T]