avl_tree_set3

ソースコード

from titan_pylib.data_structures.avl_tree.avl_tree_set3 import AVLTreeSet3

view on github

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.avl_tree_set3 import AVLTreeSet3
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)  # element type; bounded by the SupportsLessThan protocol (must define ``<``)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102from typing import Generic, Iterable, TypeVar, Optional, Sequence
103
104T = TypeVar("T", bound=SupportsLessThan)  # same definition as the TypeVar declared earlier in this expanded listing
105
106
class AVLTreeSet3(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as an AVL tree.

    Each node stores the size of its subtree, which is what makes the
    positional operations (``s[k]``, ``pop(k)``, ``index``) possible.
    Nodes are instances of the nested ``Node`` class.

    Precondition failures (popping from an empty set, index out of range)
    are reported with ``assert`` statements whose message starts with
    ``"IndexError:"``; that convention is kept throughout the class.
    """

    class Node:
        """One tree node: a key plus the AVL bookkeeping fields."""

        def __init__(self, key: T):
            self.key: T = key
            # Number of nodes in the subtree rooted here, this node included.
            self.size: int = 1
            self.left: Optional["AVLTreeSet3.Node"] = None
            self.right: Optional["AVLTreeSet3.Node"] = None
            # height(left) - height(right); see ``_build`` (hl - hr).
            self.balance: int = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = []) -> None:
        """Create the set from the elements of ``a``.

        Args:
            a: Initial elements.  Anything that is not a ``Sequence`` is
                materialised into a list first.  The default list is only
                read, never mutated.
        """
        self.node = None
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Build a balanced tree from the non-empty sequence ``a``.

        If ``a`` is not already strictly increasing it is replaced by
        ``sorted(set(a))`` (so the elements must then be hashable).
        """
        Node = AVLTreeSet3.Node

        def rec(l: int, r: int) -> tuple[AVLTreeSet3.Node, int]:
            # Build the subtree for a[l:r]; return (root, height).
            mid = (l + r) >> 1
            node = Node(a[mid])
            hl, hr = 0, 0
            if l != mid:
                node.left, hl = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right, hr = rec(mid + 1, r)
                node.size += node.right.size
            node.balance = hl - hr
            return node, max(hl, hr) + 1

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        self.node = rec(0, len(a))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Single rotation that promotes ``node.left``; returns the new root."""
        u = node.left
        # u takes over the whole subtree, so it inherits node's size.
        u.size = node.size
        # node loses u itself and u's left subtree.
        node.size -= 1 if u.left is None else u.left.size + 1
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # Reached from ``discard`` when u was perfectly balanced.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Single rotation that promotes ``node.right``; returns the new root."""
        u = node.right
        u.size = node.size
        node.size -= 1 if u.right is None else u.right.size + 1
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Reset balance factors after a double rotation rooted at ``node``.

        The children's new factors depend only on the factor ``node`` had
        before it was promoted; ``node`` itself always ends up at 0.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation promoting ``node.left.right``; returns the new root."""
        B = node.left
        E = B.right
        E.size = node.size
        if E.right is None:
            node.size -= B.size
            B.size -= 1
        else:
            # E.right moves under node; E and E.right leave B's subtree.
            node.size -= B.size - E.right.size
            B.size -= E.right.size + 1
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Double rotation promoting ``node.right.left``; returns the new root."""
        C = node.right
        D = C.left
        D.size = node.size
        if D.left is None:
            node.size -= C.size
            C.size -= 1
        else:
            # D.left moves under node; D and D.left leave C's subtree.
            node.size -= C.size - D.left.size
            C.size -= D.left.size + 1
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> T:
        """Return the ``k``-th smallest key (0-based; negative counts from the end).

        The tree must be non-empty and ``k`` must be in range; callers are
        responsible for checking both.
        """
        if k < 0:
            k += self.node.size
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            ``True`` if the key was inserted, ``False`` if it was already
            present (the tree is left untouched in that case).
        """
        if self.node is None:
            self.node = AVLTreeSet3.Node(key)
            return True
        pnode = self.node
        path = []
        # Bit stack of directions taken: 1 = went left, 0 = went right.
        di = 0
        while pnode is not None:
            if key == pnode.key:
                return False
            elif key < pnode.key:
                path.append(pnode)
                di = (di << 1) | 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeSet3.Node(key)
        else:
            path[-1].right = AVLTreeSet3.Node(key)
        new_node = None
        # Walk back up, fixing sizes and balance factors until the height
        # change is absorbed (balance 0) or repaired by a rotation.
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            # Re-attach the rotated subtree to its parent (or make it the root).
            if path:
                gnode = path.pop()
                gnode.size += 1
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        # Ancestors above the stopping point only need their size bumped.
        for p in path:
            p.size += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            ``True`` if the key was removed, ``False`` if it was not found.
        """
        di = 0
        path = []
        node = self.node
        while node is not None:
            if key == node.key:
                break
            elif key < node.key:
                path.append(node)
                di = (di << 1) | 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
        else:
            # Loop ended without ``break``: the key is not in the tree.
            return False
        if node.left is not None and node.right is not None:
            # Two children: overwrite with the in-order predecessor's key
            # and physically delete the predecessor node instead.
            path.append(node)
            di = (di << 1) | 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        # ``node`` now has at most one child; splice that child into its place.
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            pnode.size -= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Subtree height is unchanged; nothing more to rebalance.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    # The rotation kept the subtree height; stop here.
                    break
        # Remaining ancestors only need their size decremented.
        for p in path:
            p.size -= 1
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored element ``<= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                # Return the stored element, not the caller's probe, so the
                # result is consistent with lt()/gt().
                res = node.key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored element ``< key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored element ``>= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                # Return the stored element, not the caller's probe.
                res = node.key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored element ``> key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of stored elements strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 0 if node.left is None else node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of stored elements less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element.

        Args:
            k: 0-based position; negative values count from the end.
                Defaults to ``-1`` (the maximum).

        The set must be non-empty and ``k`` must satisfy
        ``-len(self) <= k < len(self)``; both are checked with assertions,
        in the same way as ``__getitem__``.
        """
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop({k}), pop({k}) from Empty {self.__class__.__name__}"
        n = self.node.size
        # Without this check an out-of-range k walks off the tree inside
        # _kth_elm and fails with an unhelpful AttributeError on None.
        assert (
            -n <= k < n
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={n}"
        x = self._kth_elm(k)
        self.discard(x)
        return x

    def pop_max(self) -> T:
        """Remove and return the maximum element (set must be non-empty)."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_max(), pop_max from Empty {self.__class__.__name__}"
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the minimum element (set must be non-empty)."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_min(), pop_min from Empty {self.__class__.__name__}"
        return self.pop(0)

    def get_max(self) -> Optional[T]:
        """Return the maximum element, or ``None`` if the set is empty."""
        if self.node is None:
            return None
        return self._kth_elm(-1)

    def get_min(self) -> Optional[T]:
        """Return the minimum element, or ``None`` if the set is empty."""
        if self.node is None:
            return None
        return self._kth_elm(0)

    def clear(self) -> None:
        """Remove every element by dropping the root."""
        self.node = None

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (in-order traversal)."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.append(node.key)
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is stored in the set."""
        node = self.node
        while node is not None:
            if key == node.key:
                return True
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element (negative ``k`` allowed)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), len={len(self)}"
        return self._kth_elm(k)

    def __iter__(self):
        """Start an ascending iteration and return ``self`` as the iterator.

        NOTE: the cursor lives on the instance, so two simultaneous
        iterations over the same set (e.g. nested ``for`` loops) share and
        reset each other's position.
        """
        self.__iter = 0
        return self

    def __next__(self):
        """Return the element at the current cursor and advance it."""
        if self.__iter == self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self):
        """Yield the elements in descending order."""
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __len__(self):
        """Return the number of stored elements."""
        return 0 if self.node is None else self.node.size

    def __bool__(self):
        """Return ``True`` when the set is non-empty."""
        return self.node is not None

    def __str__(self):
        """Render as ``{e1, e2, ...}`` in ascending order."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        """Render as ``AVLTreeSet3({...})``."""
        return f"AVLTreeSet3({str(self)})"

仕様

class AVLTreeSet3(a: Iterable[T] = [])[source]

Bases: OrderedSetInterface, Generic[T]

集合としての AVL木 です。 size を持ちます。 class Node() を用いています。

class Node(key: T)[source]

Bases: object

add(key: T) -> bool [source]
clear() -> None [source]
discard(key: T) -> bool [source]
ge(key: T) -> T | None [source]
get_max() -> T | None [source]
get_min() -> T | None [source]
gt(key: T) -> T | None [source]
index(key: T) -> int [source]
index_right(key: T) -> int [source]
le(key: T) -> T | None [source]
lt(key: T) -> T | None [source]
pop(k: int = -1) -> T [source]
pop_max() -> T [source]
pop_min() -> T [source]
remove(key: T) -> None [source]
tolist() -> list[T] [source]