avl_tree_multiset2

ソースコード

from titan_pylib.data_structures.avl_tree.avl_tree_multiset2 import AVLTreeMultiset2

view on github

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.avl_tree_multiset2 import AVLTreeMultiset2
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
 10#     BSTMultisetArrayBase,
 11# )
 12from typing import TypeVar, Generic, Optional
 13
 14T = TypeVar("T")
 15BST = TypeVar("BST")
 16# protcolで、key,val,left,right を規定
 17
 18
 19class BSTMultisetArrayBase(Generic[BST, T]):
 20
 21    @staticmethod
 22    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
 23        keys, vals = [a[0]], [1]
 24        for i, elm in enumerate(a):
 25            if i == 0:
 26                continue
 27            if elm == keys[-1]:
 28                vals[-1] += 1
 29                continue
 30            keys.append(elm)
 31            vals.append(1)
 32        return keys, vals
 33
 34    @staticmethod
 35    def count(bst: BST, key: T) -> int:
 36        keys, left, right = bst.key, bst.left, bst.right
 37        node = bst.root
 38        while node:
 39            if keys[node] == key:
 40                return bst.val[node]
 41            node = left[node] if key < keys[node] else right[node]
 42        return 0
 43
 44    @staticmethod
 45    def le(bst: BST, key: T) -> Optional[T]:
 46        keys, left, right = bst.key, bst.left, bst.right
 47        res = None
 48        node = bst.root
 49        while node:
 50            if key == keys[node]:
 51                res = key
 52                break
 53            if key < keys[node]:
 54                node = left[node]
 55            else:
 56                res = keys[node]
 57                node = right[node]
 58        return res
 59
 60    @staticmethod
 61    def lt(bst: BST, key: T) -> Optional[T]:
 62        keys, left, right = bst.key, bst.left, bst.right
 63        res = None
 64        node = bst.root
 65        while node:
 66            if key <= keys[node]:
 67                node = left[node]
 68            else:
 69                res = keys[node]
 70                node = right[node]
 71        return res
 72
 73    @staticmethod
 74    def ge(bst: BST, key: T) -> Optional[T]:
 75        keys, left, right = bst.key, bst.left, bst.right
 76        res = None
 77        node = bst.root
 78        while node:
 79            if key == keys[node]:
 80                res = key
 81                break
 82            if key < keys[node]:
 83                res = keys[node]
 84                node = left[node]
 85            else:
 86                node = right[node]
 87        return res
 88
 89    @staticmethod
 90    def gt(bst: BST, key: T) -> Optional[T]:
 91        keys, left, right = bst.key, bst.left, bst.right
 92        res = None
 93        node = bst.root
 94        while node:
 95            if key < keys[node]:
 96                res = keys[node]
 97                node = left[node]
 98            else:
 99                node = right[node]
100        return res
101
102    @staticmethod
103    def index(bst: BST, key: T) -> int:
104        keys, left, right, vals, valsize = (
105            bst.key,
106            bst.left,
107            bst.right,
108            bst.val,
109            bst.valsize,
110        )
111        k = 0
112        node = bst.root
113        while node:
114            if key == keys[node]:
115                if left[node]:
116                    k += valsize[left[node]]
117                break
118            if key < keys[node]:
119                node = left[node]
120            else:
121                k += valsize[left[node]] + vals[node]
122                node = right[node]
123        return k
124
125    @staticmethod
126    def index_right(bst: BST, key: T) -> int:
127        keys, left, right, vals, valsize = (
128            bst.key,
129            bst.left,
130            bst.right,
131            bst.val,
132            bst.valsize,
133        )
134        k = 0
135        node = bst.root
136        while node:
137            if key == keys[node]:
138                k += valsize[left[node]] + vals[node]
139                break
140            if key < keys[node]:
141                node = left[node]
142            else:
143                k += valsize[left[node]] + vals[node]
144                node = right[node]
145        return k
146
147    @staticmethod
148    def _kth_elm(bst: BST, k: int) -> tuple[T, int]:
149        left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize
150        if k < 0:
151            k += len(bst)
152        node = bst.root
153        while True:
154            t = vals[node] + valsize[left[node]]
155            if t - vals[node] <= k < t:
156                return bst.key[node], vals[node]
157            if t > k:
158                node = left[node]
159            else:
160                node = right[node]
161                k -= t
162
163    @staticmethod
164    def contains(bst: BST, key: T) -> bool:
165        left, right, keys = bst.left, bst.right, bst.key
166        node = bst.root
167        while node:
168            if keys[node] == key:
169                return True
170            node = left[node] if key < keys[node] else right[node]
171        return False
172
173    @staticmethod
174    def tolist(bst: BST) -> list[T]:
175        left, right, keys, vals = bst.left, bst.right, bst.key, bst.val
176        node = bst.root
177        stack, a = [], []
178        while stack or node:
179            if node:
180                stack.append(node)
181                node = left[node]
182            else:
183                node = stack.pop()
184                key = keys[node]
185                for _ in range(vals[node]):
186                    a.append(key)
187                node = right[node]
188        return a
189from typing import Generic, Iterable, TypeVar, Optional
190from array import array
191
T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeMultiset2(Generic[T]):
    """Multiset backed by an array-based AVL tree.

    Nodes live in parallel arrays (``key``, ``val``, ``left``, ``right``,
    ``balance``) indexed by node id; id 0 is a sentinel meaning "no node".
    Equal elements share a single node whose ``val`` is their multiplicity.
    ``balance[node]`` is height(left subtree) - height(right subtree).
    Subtree sizes are not stored, which keeps updates light; as a consequence
    there are no rank / k-th element queries.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset holding the elements of ``a`` (any order, duplicates allowed)."""
        self.root = 0
        self._len = 0  # total number of elements, counting multiplicity
        # Slot 0 is the null sentinel; real nodes start at id 1.
        self.key = [0]
        self.val = [0]
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # id of the next unused slot
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a leaf holding ``key`` with multiplicity ``val``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reuse a reserved slot or one freed by clear(); reset its links
            # because a slot used before clear() still holds stale values.
            self.key[end] = key
            self.val[end] = val
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def reserve(self, n: int) -> None:
        """Append ``n`` zeroed slots to the node arrays (no-op for ``n <= 0``)."""
        if n <= 0:
            return
        a = [0] * n
        self.key += a
        self.val += a
        # `+=` extends the arrays in place, so existing aliases stay valid.
        self.left += array("I", [0]) * n
        self.right += array("I", [0]) * n
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Bulk-load the elements of ``a`` into the empty tree as a height-balanced tree."""
        left, right, balance = self.left, self.right, self.balance

        def build(l: int, r: int) -> tuple[int, int]:
            # Root the slot range [l, r) at its midpoint; return (node id, height).
            mid = (l + r) >> 1
            hl = hr = 0
            if l != mid:
                left[mid], hl = build(l, mid)
            if mid + 1 != r:
                right[mid], hr = build(mid + 1, r)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        self._len = len(a)
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        x, y = BSTMultisetArrayBase._rle(a)
        n = len(x)
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = x
        self.val[end : end + n] = y
        self.root = build(end, end + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Rotate left-heavy ``node`` clockwise, lifting its left child; return the new root.

        Callers use this only when the left child's balance is 0 or +1.
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0: the subtree keeps its height.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Rotate right-heavy ``node`` counter-clockwise, lifting its right child; return the new root.

        Callers use this only when the right child's balance is 0 or -1.
        """
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0: the subtree keeps its height.
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balance factors after a double rotation that made ``node`` the subtree root.

        ``node``'s balance before this call decides its children's balances;
        ``node`` itself ends balanced.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy; return the new root."""
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy; return the new root."""
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` from the tree and rebalance; always returns True.

        ``path`` holds the ancestors of ``node`` from the root down; bit i of
        ``di`` (counting from the least significant bit) is 1 when the step out
        of ``path[-1 - i]`` went to the left child. ``path`` is consumed.
        """
        left, right, keys, vals, balance = (
            self.left,
            self.right,
            self.key,
            self.val,
            self.balance,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into node, then unlink the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            vals[node] = vals[lmax]
            node = lmax
        # node now has at most one child; splice that child into its place.
        cnode = right[node] if left[node] == 0 else left[node]
        if not path:
            self.root = cnode
            return True
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        # Walk back up; stop as soon as a subtree's height is unchanged.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] < 0
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] > 0
                    else self._rotate_R(pnode)
                )
            elif balance[pnode] != 0:
                # Went from 0 to +-1: this subtree's height did not change.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node] != 0:
                    break
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        When ``val`` is at least the multiplicity of ``key``, every copy is
        removed and its node is unlinked. A non-positive ``val`` removes
        nothing.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        keys, vals, left, right = self.key, self.val, self.left, self.right
        path: list[int] = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        if val <= 0:
            return True
        cnt = vals[node]
        if val >= cnt:
            # All copies go: the node must be unlinked. Leaving it with count 0
            # would keep a phantom key visible to `in`, get_min/get_max, le/ge...
            self._len -= cnt
            self._discard(node, path, di)
        else:
            self._len -= val
            vals[node] = cnt - val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key`` (no-op if absent)."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard`` but raise KeyError when ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``.

        NOTE: ``val`` is assumed to be positive; it is not validated here.
        """
        self._len += val
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, balance = self.left, self.right, self.key, self.balance
        node = self.root
        di = 0  # direction bits, LSB = last step (1 = went left)
        path: list[int] = []
        while node:
            if key == keys[node]:
                self.val[node] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key, val)
        else:
            right[path[-1]] = self._make_node(key, val)
        new_node = 0
        # Walk back up; at most one (single or double) rotation is needed.
        while path:
            node = path.pop()
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        return BSTMultisetArrayBase.count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None."""
        return BSTMultisetArrayBase.le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None."""
        return BSTMultisetArrayBase.lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None."""
        return BSTMultisetArrayBase.ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None."""
        return BSTMultisetArrayBase.gt(self, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it.

        Raises:
            IndexError: if the multiset is empty (previously this corrupted
                sentinel slot 0 and made the length negative).
        """
        if self.root == 0:
            raise IndexError("pop from an empty multiset")
        left, vals, keys = self.left, self.val, self.key
        self._len -= 1
        node = self.root
        path: list[int] = []
        while left[node]:
            path.append(node)
            node = left[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step went left, so all len(path) direction bits are 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            vals[node] -= 1
        return x

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.root == 0:
            raise IndexError("pop from an empty multiset")
        right, vals, keys = self.right, self.val, self.key
        self._len -= 1
        node = self.root
        path: list[int] = []
        while right[node]:
            path.append(node)
            node = right[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step went right, so all direction bits are 0.
            self._discard(node, path, 0)
        else:
            vals[node] -= 1
        return x

    def clear(self) -> None:
        """Remove all elements; allocated slots are kept and reused by later inserts."""
        self.root = 0
        self._len = 0
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        return BSTMultisetArrayBase.tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        left, right, keys, vals = self.left, self.right, self.key, self.val
        node = self.root
        stack: list[int] = []
        a: list[tuple[T, int]] = []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append((keys[node], vals[node]))
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        return BSTMultisetArrayBase.contains(self, key)

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolist()})"

仕様

class AVLTreeMultiset2(a: Iterable[T] = [])[source]

Bases: Generic[T]

多重集合としての AVL 木です。 配列を用いてノードを表現しています。 size を持たないので軽めです。

add(key: T, val: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, val: int = 1) → bool [source]
discard_all(key: T) → None [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, val: int = 1) → None [source]
reserve(n: int) → None [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]