avl_tree_set

ソースコード

from titan_pylib.data_structures.avl_tree.avl_tree_set import AVLTreeSet

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.avl_tree_set import AVLTreeSet
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from abc import ABC, abstractmethod
 12from typing import Iterable, Optional, Iterator, TypeVar, Generic
 13
 14T = TypeVar("T", bound=SupportsLessThan)
 15
 16
 17class OrderedSetInterface(ABC, Generic[T]):
 18
 19    @abstractmethod
 20    def __init__(self, a: Iterable[T]) -> None:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def add(self, key: T) -> bool:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def discard(self, key: T) -> bool:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def remove(self, key: T) -> None:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def le(self, key: T) -> Optional[T]:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def lt(self, key: T) -> Optional[T]:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def ge(self, key: T) -> Optional[T]:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def gt(self, key: T) -> Optional[T]:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def get_max(self) -> Optional[T]:
 53        raise NotImplementedError
 54
 55    @abstractmethod
 56    def get_min(self) -> Optional[T]:
 57        raise NotImplementedError
 58
 59    @abstractmethod
 60    def pop_max(self) -> T:
 61        raise NotImplementedError
 62
 63    @abstractmethod
 64    def pop_min(self) -> T:
 65        raise NotImplementedError
 66
 67    @abstractmethod
 68    def clear(self) -> None:
 69        raise NotImplementedError
 70
 71    @abstractmethod
 72    def tolist(self) -> list[T]:
 73        raise NotImplementedError
 74
 75    @abstractmethod
 76    def __iter__(self) -> Iterator:
 77        raise NotImplementedError
 78
 79    @abstractmethod
 80    def __next__(self) -> T:
 81        raise NotImplementedError
 82
 83    @abstractmethod
 84    def __contains__(self, key: T) -> bool:
 85        raise NotImplementedError
 86
 87    @abstractmethod
 88    def __len__(self) -> int:
 89        raise NotImplementedError
 90
 91    @abstractmethod
 92    def __bool__(self) -> bool:
 93        raise NotImplementedError
 94
 95    @abstractmethod
 96    def __str__(self) -> str:
 97        raise NotImplementedError
 98
 99    @abstractmethod
100    def __repr__(self) -> str:
101        raise NotImplementedError
102from array import array
103from typing import Generic, Iterable, TypeVar, Optional
104
T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeSet(OrderedSetInterface, Generic[T]):
    """AVLTreeSet
    An AVL tree used as an ordered set (each key is stored at most once).

    Nodes live in parallel arrays indexed by node id; id 0 is a sentinel
    meaning "no node" and always has size 0. Every node records the size of
    its subtree, which powers the order-statistic queries ``index``,
    ``index_right``, ``__getitem__`` and ``pop(k)``.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set holding the distinct elements of ``a``.

        Args:
            a: initial elements in any order; duplicates are dropped.
        """
        self._init_storage()
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _init_storage(self) -> None:
        """(Re)initialise the node arrays so that only sentinel node 0 exists."""
        self.root = 0
        # Node i is described by key[i], size[i], left[i], right[i], balance[i].
        self.key = [0]
        self.size = array("I", bytes(4))
        self.left = array("I", bytes(4))
        self.right = array("I", bytes(4))
        # balance[i] = height(left subtree) - height(right subtree).
        self.balance = array("b", bytes(1))
        # Next unused node id; ids only grow until clear() rewinds them.
        self.end = 1

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes.

        New slots are initialised as fresh leaves (size 1, no children,
        balance 0), which is the state ``_make_node`` and ``_build`` expect.
        """
        self.key += [0] * n
        a = array("I", bytes(4 * n))
        self.left += a
        self.right += a
        self.size += array("I", [1] * n)
        self.balance += array("b", bytes(n))

    def _build(self, a: list[T]) -> None:
        """Link the elements of ``a`` into a perfectly balanced tree.

        ``a`` is sorted and deduplicated unless it is already strictly
        increasing. The keys are then stored in consecutive fresh node ids.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def sort(l: int, r: int) -> tuple[int, int]:
            # Link node ids [l, r) into a balanced subtree; return (root, height).
            mid = (l + r) >> 1
            node = mid
            hl, hr = 0, 0
            if l != mid:
                left[node], hl = sort(l, mid)
                size[node] += size[left[node]]
            if mid + 1 != r:
                right[node], hr = sort(mid + 1, r)
                size[node] += size[right[node]]
            balance[node] = hl - hr
            return node, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, n):
                if b[i] != a[-1]:
                    a.append(b[i])
        n = len(a)
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = a
        self.root = sort(end, n + end)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation for a left-heavy ``node`` (balance +2) whose left
        child is not right-heavy. The left child becomes the subtree root and
        its id is returned."""
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        # balance[u] == 1 happens on insert/delete; 0 only on delete.
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Mirror of ``_rotate_L`` for a right-heavy ``node`` (balance -2)."""
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balance factors after a double rotation made ``node`` the
        subtree root; ``balance[node]`` still holds its pre-rotation value."""
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left-heavy ``node`` whose left child is
        right-heavy; that child's right child becomes the subtree root."""
        left, right, size = self.left, self.right, self.size
        B = left[node]
        E = right[B]
        size[E] = size[node]
        # node loses B's subtree but gains E's right subtree.
        size[node] -= size[B] - size[right[E]]
        # B loses E and E's right subtree.
        size[B] -= size[right[E]] + 1
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Mirror of ``_rotate_LR`` for a right-heavy ``node`` whose right
        child is left-heavy."""
        left, right, size = self.left, self.right, self.size
        C = right[node]
        D = left[C]
        size[D] = size[node]
        size[node] -= size[C] - size[left[D]]
        size[C] -= size[left[D]] + 1
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` (reusing a reserved slot if one is
        available) and return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            self.key[end] = key
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if inserted, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        node = self.root
        path = []
        # di records the descent: one bit per step, 1 = went left.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        new_node = 0
        # Walk back up, updating balance until the subtree height stops
        # growing or a rotation restores it.
        while path:
            node = path.pop()
            size[node] += 1
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                node = path.pop()
                size[node] += 1
                if di & 1:
                    left[node] = new_node
                else:
                    right[node] = new_node
            else:
                self.root = new_node
        for p in path:
            size[p] += 1
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return whether it was removed."""
        left, right, keys = self.left, self.right, self.key
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._erase(node, path, di)
        return True

    def _erase(self, node: int, path: list[int], di: int) -> None:
        """Unlink ``node`` and restore AVL balance and subtree sizes.

        Args:
            node: id of the node whose key is being deleted.
            path: ids of the ancestors of ``node``, root first (consumed).
            di: descent bits for ``path``; the lowest bit is 1 when ``node``
                is the left child of ``path[-1]``.
        """
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor's key into node
            # and delete the predecessor (it has no right child) instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # node now has at most one child, which takes its place.
        cnode = right[node] if left[node] == 0 else left[node]
        # The slot becomes unreachable; drop its key reference so the
        # deleted object can be garbage collected.
        keys[node] = 0
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return
        while path:
            new_node = 0
            node = path.pop()
            balance[node] -= 1 if di & 1 else -1
            di >>= 1
            size[node] -= 1
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
            elif balance[node]:
                # Height of this subtree is unchanged: stop propagating.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    # Rotation kept the subtree height: stop propagating.
                    break
        for p in path:
            size[p] -= 1

    def le(self, key: T) -> Optional[T]:
        """Return the largest element <= ``key``, or None."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element < ``key``, or None."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element >= ``key``, or None."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element > ``key``, or None."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if not self:
            return None
        return self[len(self) - 1]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if not self:
            return None
        return self[0]

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (0-indexed).

        Negative ``k`` counts from the largest element, as with ``list.pop``.

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        left, right, size, key = self.left, self.right, self.size, self.key
        n = size[self.root]
        if k < 0:
            k += n
        # An explicit check (not assert) so the walk below can never run
        # off the tree, even under ``python -O``.
        if not 0 <= k < n:
            raise IndexError(f"AVLTreeSet.pop index out of range, len={n}")
        path = []
        di = 0
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                res = key[node]
                break
            path.append(node)
            di <<= 1
            if t < k:
                k -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        self._erase(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element (IndexError if empty)."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element (IndexError if empty)."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove every element and release the node storage."""
        self._init_storage()

    def tolist(self) -> list[T]:
        """Return the elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def ave_height(self) -> float:
        """Return the average node depth (root depth = 1); 0 when empty."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        ans = 0

        def dfs(node, dep):
            nonlocal ans
            ans += dep
            if left[node]:
                dfs(left[node], dep + 1)
            if right[node]:
                dfs(right[node], dep + 1)

        dfs(self.root, 1)
        ans /= len(self)
        return ans

    def _get_height(self) -> int:
        """Debug helper: validate the tree invariants and return its height.

        Asserts that every stored subtree size and balance factor matches the
        actual shape and that the in-order walk is strictly increasing.
        Returns 0 for an empty tree.
        """
        if self.root == 0:
            return 0
        left, right, size, balance = self.left, self.right, self.size, self.balance

        # Returns (subtree size, subtree height).
        def dfs(node: int) -> tuple[int, int]:
            s, lh, rh = 1, 0, 0
            if left[node]:
                ls, lh = dfs(left[node])
                s += ls
            if right[node]:
                rs, rh = dfs(right[node])
                s += rs
            assert size[node] == s
            assert balance[node] == lh - rh
            return s, max(lh, rh) + 1

        _, h = dfs(self.root)
        a = self.tolist()
        assert all(a[i] < a[i + 1] for i in range(len(a) - 1))
        return h

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from
        the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        left, right, size, key = self.left, self.right, self.size, self.key
        n = size[self.root]
        i = k + n if k < 0 else k
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet index out of range: AVLTreeSet[{k}], len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                return key[node]
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self) -> Iterator[T]:
        """Return a new iterator over the elements in ascending order.

        Every call yields an independent iterator, so nested loops over the
        same set work. The legacy cursor is still reset so that code calling
        ``next(s)`` after ``iter(s)`` keeps its old behaviour.
        """
        self.__iter = 0
        return self._iter_ascending()

    def _iter_ascending(self) -> Iterator[T]:
        # Walk by rank (O(log n) per element), like the legacy __next__.
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __next__(self) -> T:
        """Legacy protocol: next element of the cursor reset by ``iter``."""
        if self.__iter == self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        """Yield the elements in descending order."""
        for i in range(len(self)):
            yield self[-i - 1]

    def __len__(self) -> int:
        return self.size[self.root]

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet({self})"

仕様

class AVLTreeSet(a: Iterable[T] = []) [source]

Bases: OrderedSetInterface, Generic[T]

集合としての AVL 木です。 配列を用いてノードを表現しています。 size を持ちます。

add(key: T) → bool [source]
ave_height() [source]
clear() → None [source]
discard(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_right(key: T) → int [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T) → bool [source]
reserve(n: int) → None [source]
tolist() → list[T] [source]