avl_tree_set2

ソースコード

from titan_pylib.data_structures.avl_tree.avl_tree_set2 import AVLTreeSet2

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.avl_tree_set2 import AVLTreeSet2
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from abc import ABC, abstractmethod
 12from typing import Iterable, Optional, Iterator, TypeVar, Generic
 13
 14T = TypeVar("T", bound=SupportsLessThan)
 15
 16
 17class OrderedSetInterface(ABC, Generic[T]):
 18
 19    @abstractmethod
 20    def __init__(self, a: Iterable[T]) -> None:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def add(self, key: T) -> bool:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def discard(self, key: T) -> bool:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def remove(self, key: T) -> None:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def le(self, key: T) -> Optional[T]:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def lt(self, key: T) -> Optional[T]:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def ge(self, key: T) -> Optional[T]:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def gt(self, key: T) -> Optional[T]:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def get_max(self) -> Optional[T]:
 53        raise NotImplementedError
 54
 55    @abstractmethod
 56    def get_min(self) -> Optional[T]:
 57        raise NotImplementedError
 58
 59    @abstractmethod
 60    def pop_max(self) -> T:
 61        raise NotImplementedError
 62
 63    @abstractmethod
 64    def pop_min(self) -> T:
 65        raise NotImplementedError
 66
 67    @abstractmethod
 68    def clear(self) -> None:
 69        raise NotImplementedError
 70
 71    @abstractmethod
 72    def tolist(self) -> list[T]:
 73        raise NotImplementedError
 74
 75    @abstractmethod
 76    def __iter__(self) -> Iterator:
 77        raise NotImplementedError
 78
 79    @abstractmethod
 80    def __next__(self) -> T:
 81        raise NotImplementedError
 82
 83    @abstractmethod
 84    def __contains__(self, key: T) -> bool:
 85        raise NotImplementedError
 86
 87    @abstractmethod
 88    def __len__(self) -> int:
 89        raise NotImplementedError
 90
 91    @abstractmethod
 92    def __bool__(self) -> bool:
 93        raise NotImplementedError
 94
 95    @abstractmethod
 96    def __str__(self) -> str:
 97        raise NotImplementedError
 98
 99    @abstractmethod
100    def __repr__(self) -> str:
101        raise NotImplementedError
102from array import array
103from typing import Generic, Iterable, TypeVar, Optional
104
T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeSet2(OrderedSetInterface, Generic[T]):
    """AVLTreeSet2
    An AVL tree used as an ordered set (keys are unique).

    Nodes live in parallel arrays indexed by an integer node id:
    ``key[v]``, ``left[v]``, ``right[v]`` and ``balance[v]``.  Id 0 is a
    sentinel meaning "no node".  ``balance[v]`` is
    height(left subtree) - height(right subtree) and is kept in {-1, 0, 1}.

    Nodes do not store subtree sizes.  This keeps the structure light, but
    it also means there are no index / k-th element queries.
    """

    def __init__(self, a: Iterable[T] = []) -> None:
        """Create a set holding the distinct elements of ``a``.

        The default list is only read, never mutated.
        """
        self.root = 0
        self._len = 0
        self.key = [0]
        # Build from [0] rather than bytes(4): the item size of typecode "I"
        # is platform dependent, so a fixed byte count is not portable.
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # next unused node id
        self.it = None  # cursor used only by the legacy __next__ protocol
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate zeroed storage for ``n`` more nodes (no-op if n <= 0)."""
        if n <= 0:
            return
        self.key += [0] * n
        zeros = array("I", [0]) * n
        # `+=` extends the arrays in place, so existing aliases stay valid.
        self.left += zeros
        self.right += zeros
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Fill an empty tree with the keys of ``a`` as a balanced tree.

        ``a`` is sorted and de-duplicated first unless it is already
        strictly increasing.
        """
        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for x in b[1:]:
                if x != a[-1]:
                    a.append(x)
            n = len(a)
        start = self.end
        self.reserve(n)
        self.end += n
        self._len = n
        self.key[start : start + n] = a
        left, right, balance = self.left, self.right, self.balance

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Turn node ids [lo, hi) into a balanced subtree rooted at the
            # middle id; return (root id, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build(lo, mid)
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        self.root = build(start, start + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation lifting the left child of ``node``; return new root."""
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible during deletion): height is kept.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Single rotation lifting the right child of ``node``; return new root."""
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible during deletion): height is kept.
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation whose new root is ``node``.

        Reads the pre-rotation balance of ``node`` (the pivot) to decide
        which of its two new children ends up one level shorter.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation at ``node``; return the new subtree root."""
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Right-left double rotation at ``node``; return the new subtree root."""
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` and return its id.

        Ids at or beyond ``end`` always have zeroed links: reserve() and
        clear() only provide fresh zeroed storage and ids are never reused,
        so only the key needs to be written.
        """
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            self.key[end] = key
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if inserted, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            self._len = 1
            return True
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        node = self.root
        path = []
        # Bit i of `di` (counting from the least significant bit) records
        # whether the step out of path[-1 - i] went left (1) or right (0).
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        self._len += 1
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        # Retrace towards the root while the subtree height keeps growing.
        new_node = 0
        while path:
            pnode = path.pop()
            balance[pnode] += 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 0:
                # Height unchanged: ancestors are unaffected.
                break
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
                break
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node:
            # After an insertion rotation the subtree regains its old height;
            # only the parent link must be redirected.
            if path:
                gnode = path.pop()
                if di & 1:
                    left[gnode] = new_node
                else:
                    right[gnode] = new_node
            else:
                self.root = new_node
        return True

    def remove(self, key: T) -> bool:
        """Delete ``key`` and return True.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Delete ``key`` if present. Return True if it was deleted."""
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        di = 0  # direction bits, same encoding as in add()
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._len -= 1
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (maximum of the left
            # subtree) into `node` and unlink the predecessor node instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # `node` now has at most one child; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return True
        # Retrace towards the root while the subtree height keeps shrinking.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                # |balance| == 1: this subtree kept its height; stop.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    # Rotation around a child of balance 0 keeps the height.
                    break
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest key <= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key < ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            # Only `<` is required of keys (T is bound to SupportsLessThan).
            if keys[node] < key:
                res = keys[node]
                node = right[node]
            else:
                node = left[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key >= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key > ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the set is empty."""
        if not self:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the set is empty."""
        if not self:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove and return the smallest key.

        Raises:
            IndexError: if the set is empty (previously the sentinel key 0 was
                returned and the length went negative).
        """
        if self.root == 0:
            raise IndexError("pop_min from empty AVLTreeSet2")
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while left[node]:
            path.append(node)
            node = left[node]
        self._len -= 1
        res = keys[node]
        cnode = right[node]
        if not path:
            self.root = cnode
            return res
        left[path[-1]] = cnode
        # Every step on `path` went left, so each parent loses left height.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                left[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def pop_max(self) -> T:
        """Remove and return the largest key.

        Raises:
            IndexError: if the set is empty (previously the sentinel key 0 was
                returned and the length went negative).
        """
        if self.root == 0:
            raise IndexError("pop_max from empty AVLTreeSet2")
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while right[node]:
            path.append(node)
            node = right[node]
        self._len -= 1
        res = keys[node]
        cnode = left[node]  # the maximum has no right child
        if not path:
            self.root = cnode
            return res
        right[path[-1]] = cnode
        # Every step on `path` went right, so each parent loses right height.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] += 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                right[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def clear(self) -> None:
        """Remove every key and release node storage.

        Resets the length too (previously len() kept its old value).
        """
        self.root = 0
        self._len = 0
        self.key = [0]
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1
        self.it = None

    def tolist(self) -> list[T]:
        """Return all keys in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __iter__(self) -> Iterator[T]:
        """Yield the keys in ascending order.

        Each call returns an independent generator, so nested or parallel
        iteration over the same set works (previously every loop shared one
        cursor on the instance).  Each step looks up the successor of the
        last yielded key with ``gt``, so the set may be modified between
        steps and iteration continues from that key.
        """
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __next__(self) -> T:
        """Legacy cursor protocol kept for OrderedSetInterface.

        Returns ``self.it`` and advances it to its successor; raises
        StopIteration once the cursor is None.  ``for`` loops do not use this
        cursor; they use the generator returned by ``__iter__``.
        """
        if self.it is None:
            raise StopIteration
        res = self.it
        self.it = self.gt(res)
        return res

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet2({self})"

仕様

class AVLTreeSet2(a: Iterable[T] = []) [source]

Bases: OrderedSetInterface, Generic[T]

集合としての AVL 木です。配列を用いてノードを表現しています。各ノードが部分木のサイズ (size) を持たないため軽量ですが、その分 k 番目の要素の取得などのインデックス操作はできません。

add(key: T) → bool [source]
clear() → None [source]
discard(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T) → bool [source]
reserve(n: int) → None [source]
tolist() → list[T] [source]