wbt_set

ソースコード

from titan_pylib.data_structures.avl_tree.wbt_set import WBTSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.wbt_set import WBTSet
  2from array import array
  3from typing import Generic, Iterable, TypeVar, Optional, Final
  4
  5T = TypeVar("T")
  6
  7DELTA: Final[int] = 3
  8GAMMA: Final[int] = 2
  9
 10
 11class WBTSet(Generic[T]):
 12
 13    def __init__(self, a: Iterable[T] = []) -> None:
 14        self.root = 0
 15        self.key = [0]
 16        self.size = array("I", bytes(4))
 17        self.left = array("I", bytes(4))
 18        self.right = array("I", bytes(4))
 19        self.end = 1
 20        if not isinstance(a, list):
 21            a = list(a)
 22        if a:
 23            self._build(a)
 24
 25    def reserve(self, n: int) -> None:
 26        if n <= 0:
 27            return
 28        self.key += [0] * n
 29        a = array("I", bytes(4 * n))
 30        self.left += a
 31        self.right += a
 32        self.size += array("I", [1] * n)
 33
 34    def _build(self, a: list[T]) -> None:
 35        left, right, size = self.left, self.right, self.size
 36
 37        def sort(l: int, r: int) -> int:
 38            mid = (l + r) >> 1
 39            node = mid
 40            if l != mid:
 41                left[node] = sort(l, mid)
 42                size[node] += size[left[node]]
 43            if mid + 1 != r:
 44                right[node] = sort(mid + 1, r)
 45                size[node] += size[right[node]]
 46            return node
 47
 48        n = len(a)
 49        if n == 0:
 50            return
 51        if not all(a[i] < a[i + 1] for i in range(n - 1)):
 52            b = sorted(a)
 53            a = [b[0]]
 54            for i in range(1, n):
 55                if b[i] != a[-1]:
 56                    a.append(b[i])
 57        n = len(a)
 58        end = self.end
 59        self.end += n
 60        self.reserve(n)
 61        self.key[end : end + n] = a
 62        self.root = sort(end, n + end)
 63
 64    def _rotate_left(self, node: int) -> int:
 65        left, right, size = self.left, self.right, self.size
 66        u = right[node]
 67        size[u] = size[node]
 68        size[node] -= size[right[u]] + 1
 69        right[node] = left[u]
 70        left[u] = node
 71        return u
 72
 73    def _rotate_right(self, node: int) -> int:
 74        left, right, size = self.left, self.right, self.size
 75        u = left[node]
 76        size[u] = size[node]
 77        size[node] -= size[left[u]] + 1
 78        left[node] = right[u]
 79        right[u] = node
 80        return u
 81
 82    def _make_node(self, key: T) -> int:
 83        end = self.end
 84        if end >= len(self.key):
 85            self.key.append(key)
 86            self.size.append(1)
 87            self.left.append(0)
 88            self.right.append(0)
 89        else:
 90            self.key[end] = key
 91        self.end += 1
 92        return end
 93
 94    def _weight_left(self, node: int) -> int:
 95        return self.size[self.left[node]] + 1
 96
 97    def _weight_right(self, node: int) -> int:
 98        return self.size[self.right[node]] + 1
 99
100    def ave_height(self):
101        if not self.root:
102            return 0
103        left, right, size, keys = self.left, self.right, self.size, self.key
104        ans = 0
105
106        def dfs(node, dep):
107            nonlocal ans
108            ans += dep
109            if left[node]:
110                dfs(left[node], dep + 1)
111            if right[node]:
112                dfs(right[node], dep + 1)
113
114        dfs(self.root, 1)
115        ans /= len(self)
116        return ans
117
118    def debug(self, root):
119        left, right, size, keys = self.left, self.right, self.size, self.key
120
121        def dfs(node, indent):
122            if not node:
123                return
124            s = " " * indent
125            print(f"{s}key={keys[node]}, idx={node}")
126            if left[node]:
127                print(f"{s}left: {keys[left[node]]}, idx={left[node]}")
128                dfs(left[node], indent + 2)
129            if right[node]:
130                print(f"{s}righ: {keys[right[node]]}, idx={right[node]}")
131                dfs(right[node], indent + 2)
132
133        dfs(root, 0)
134
135    def add(self, key: T) -> bool:
136        if self.root == 0:
137            self.root = self._make_node(key)
138            return True
139        left, right, size, keys = self.left, self.right, self.size, self.key
140        node = self.root
141        path = []
142        di = 0
143        while node:
144            if key == keys[node]:
145                return False
146            path.append(node)
147            di <<= 1
148            if key < keys[node]:
149                di |= 1
150                node = left[node]
151            else:
152                node = right[node]
153        # self.debug(self.root)
154        if di & 1:
155            left[path[-1]] = self._make_node(key)
156        else:
157            right[path[-1]] = self._make_node(key)
158        while path:
159            node = path.pop()
160            size[node] += 1
161            di >>= 1
162            wl = self._weight_left(node)
163            wr = self._weight_right(node)
164            if wl * DELTA < wr:
165                # print("wl * DELTA < wr")
166                # self.debug(node)
167                if (
168                    self._weight_left(right[node])
169                    >= self._weight_right(right[node]) * GAMMA
170                ):
171                    right[node] = self._rotate_right(right[node])
172                node = self._rotate_left(node)
173                # self.debug(node)
174                # assert node
175            elif wr * DELTA < wl:
176                # print("wr * DELTA < wl")
177                if (
178                    self._weight_right(left[node])
179                    >= self._weight_left(left[node]) * GAMMA
180                ):
181                    # print("left")
182                    left[node] = self._rotate_left(left[node])
183                # print("right")
184                node = self._rotate_right(node)
185                # assert node
186            if path:
187                if di & 1:
188                    left[path[-1]] = node
189                else:
190                    right[path[-1]] = node
191            else:
192                self.root = node
193        return True
194
195    def remove(self, key: T) -> bool:
196        if self.discard(key):
197            return True
198        raise KeyError(key)
199
200    def discard(self, key: T) -> bool:
201        left, right, size, keys = self.left, self.right, self.size, self.key
202        di = 0
203        path = []
204        node = self.root
205        while node:
206            if key == keys[node]:
207                break
208            path.append(node)
209            di <<= 1
210            if key < keys[node]:
211                di |= 1
212                node = left[node]
213            else:
214                node = right[node]
215        else:
216            return False
217        if left[node] and right[node]:
218            path.append(node)
219            di <<= 1
220            di |= 1
221            lmax = left[node]
222            while right[lmax]:
223                path.append(lmax)
224                di <<= 1
225                lmax = right[lmax]
226            keys[node] = keys[lmax]
227            node = lmax
228        cnode = right[node] if left[node] == 0 else left[node]
229        if path:
230            if di & 1:
231                left[path[-1]] = cnode
232            else:
233                right[path[-1]] = cnode
234        else:
235            self.root = cnode
236            return True
237        while path:
238            node = path.pop()
239            size[node] -= 1
240            di >>= 1
241            wl = self._weight_left(node)
242            wr = self._weight_right(node)
243            if wl * DELTA < wr:
244                if (
245                    self._weight_left(right[node])
246                    >= self._weight_right(right[node]) * GAMMA
247                ):
248                    right[node] = self._rotate_right(right[node])
249                node = self._rotate_left(node)
250            elif wr * DELTA < wl:
251                if (
252                    self._weight_right(left[node])
253                    >= self._weight_left(left[node]) * GAMMA
254                ):
255                    left[node] = self._rotate_left(left[node])
256                node = self._rotate_right(node)
257            if path:
258                if di & 1:
259                    left[path[-1]] = node
260                else:
261                    right[path[-1]] = node
262            else:
263                self.root = node
264        return True
265
266    def le(self, key: T) -> Optional[T]:
267        keys, left, right = self.key, self.left, self.right
268        res = None
269        node = self.root
270        while node:
271            if key == keys[node]:
272                return keys[node]
273            if key < keys[node]:
274                node = left[node]
275            else:
276                res = keys[node]
277                node = right[node]
278        return res
279
280    def lt(self, key: T) -> Optional[T]:
281        keys, left, right = self.key, self.left, self.right
282        res = None
283        node = self.root
284        while node:
285            if key <= keys[node]:
286                node = left[node]
287            else:
288                res = keys[node]
289                node = right[node]
290        return res
291
292    def ge(self, key: T) -> Optional[T]:
293        keys, left, right = self.key, self.left, self.right
294        res = None
295        node = self.root
296        while node:
297            if key == keys[node]:
298                return keys[node]
299            if key < keys[node]:
300                res = keys[node]
301                node = left[node]
302            else:
303                node = right[node]
304        return res
305
306    def gt(self, key: T) -> Optional[T]:
307        keys, left, right = self.key, self.left, self.right
308        res = None
309        node = self.root
310        while node:
311            if key < keys[node]:
312                res = keys[node]
313                node = left[node]
314            else:
315                node = right[node]
316        return res
317
318    def index(self, key: T) -> int:
319        keys, left, right, size = self.key, self.left, self.right, self.size
320        k = 0
321        node = self.root
322        while node:
323            if key == keys[node]:
324                k += size[left[node]]
325                break
326            if key < keys[node]:
327                node = left[node]
328            else:
329                k += size[left[node]] + 1
330                node = right[node]
331        return k
332
333    def index_right(self, key: T) -> int:
334        keys, left, right, size = self.key, self.left, self.right, self.size
335        k, node = 0, self.root
336        while node:
337            if key == keys[node]:
338                k += size[left[node]] + 1
339                break
340            if key < keys[node]:
341                node = left[node]
342            else:
343                k += size[left[node]] + 1
344                node = right[node]
345        return k
346
347    def get_max(self) -> Optional[T]:
348        if not self:
349            return
350        return self[len(self) - 1]
351
352    def get_min(self) -> Optional[T]:
353        if not self:
354            return
355        return self[0]
356
357    def pop(self, k: int = -1) -> T:
358        left, right, size, key = self.left, self.right, self.size, self.key
359        if k < 0:
360            k += size[self.root]
361        assert 0 <= k and k < size[self.root], "IndexError"
362        path = []
363        di = 0
364        node = self.root
365        while True:
366            t = size[left[node]]
367            if t == k:
368                res = key[node]
369                break
370            path.append(node)
371            di <<= 1
372            if t < k:
373                k -= t + 1
374                node = right[node]
375            else:
376                di |= 1
377                node = left[node]
378        if left[node] and right[node]:
379            path.append(node)
380            di <<= 1
381            di |= 1
382            lmax = left[node]
383            while right[lmax]:
384                path.append(lmax)
385                di <<= 1
386                lmax = right[lmax]
387            key[node] = key[lmax]
388            node = lmax
389        cnode = right[node] if left[node] == 0 else left[node]
390        if path:
391            if di & 1:
392                left[path[-1]] = cnode
393            else:
394                right[path[-1]] = cnode
395        else:
396            self.root = cnode
397            return res
398        while path:
399            node = path.pop()
400            size[node] -= 1
401            di >>= 1
402            wl = self._weight_left(node)
403            wr = self._weight_right(node)
404            if wl * DELTA < wr:
405                if (
406                    self._weight_left(right[node])
407                    >= self._weight_right(right[node]) * GAMMA
408                ):
409                    right[node] = self._rotate_right(right[node])
410                node = self._rotate_left(node)
411            elif wr * DELTA < wl:
412                if (
413                    self._weight_right(left[node])
414                    >= self._weight_left(left[node]) * GAMMA
415                ):
416                    left[node] = self._rotate_left(left[node])
417                node = self._rotate_right(node)
418            if path:
419                if di & 1:
420                    left[path[-1]] = node
421                else:
422                    right[path[-1]] = node
423            else:
424                self.root = node
425        return res
426
427    def pop_max(self) -> T:
428        return self.pop()
429
430    def pop_min(self) -> T:
431        return self.pop(0)
432
433    def clear(self) -> None:
434        self.root = 0
435
436    def tolist(self) -> list[T]:
437        left, right, keys = self.left, self.right, self.key
438        node = self.root
439        stack, a = [], []
440        while stack or node:
441            if node:
442                stack.append(node)
443                node = left[node]
444            else:
445                node = stack.pop()
446                a.append(keys[node])
447                node = right[node]
448        return a
449
450    def check(self) -> int:
451        """作業用デバック関数"""
452        if self.root == 0:
453            return 0
454
455        def _balance_check(node: int) -> None:
456            if node == 0:
457                return
458            if not self._weight_left(node) * DELTA >= self._weight_right(node):
459                print(self._weight_left(node), self._weight_right(node), flush=True)
460                assert False, f"self._weight_left() * DELTA >= self._weight_right()"
461            if not self._weight_right(node) * DELTA >= self._weight_left(node):
462                print(self._weight_left(node), self._weight_right(node), flush=True)
463                assert False, f"self._weight_right() * DELTA >= self._weight_left()"
464
465        keys = self.key
466
467        # _size, height
468        def dfs(node) -> tuple[int, int]:
469            _balance_check(node)
470            h = 0
471            s = 1
472            if self.left[node]:
473                assert keys[self.left[node]] < keys[node]
474                ls, lh = dfs(self.left[node])
475                s += ls
476                h = max(h, lh)
477            if self.right[node]:
478                assert keys[node] < keys[self.right[node]]
479                rs, rh = dfs(self.right[node])
480                s += rs
481                h = max(h, rh)
482            assert self.size[node] == s
483            return s, h + 1
484
485        _, h = dfs(self.root)
486        return h
487
488    def __contains__(self, key: T) -> bool:
489        keys, left, right = self.key, self.left, self.right
490        node = self.root
491        while node:
492            if key == keys[node]:
493                return True
494            node = left[node] if key < keys[node] else right[node]
495        return False
496
497    def __getitem__(self, k: int) -> T:
498        left, right, size, key = self.left, self.right, self.size, self.key
499        if k < 0:
500            k += size[self.root]
501        assert (
502            0 <= k and k < size[self.root]
503        ), f"IndexError: WBTSet[{k}], len={len(self)}"
504        node = self.root
505        while True:
506            t = size[left[node]]
507            if t == k:
508                return key[node]
509            if t < k:
510                k -= t + 1
511                node = right[node]
512            else:
513                node = left[node]
514
515    def __iter__(self):
516        self.__iter = 0
517        return self
518
519    def __next__(self):
520        if self.__iter == self.__len__():
521            raise StopIteration
522        res = self[self.__iter]
523        self.__iter += 1
524        return res
525
526    def __reversed__(self):
527        for i in range(self.__len__()):
528            yield self[-i - 1]
529
530    def __len__(self):
531        return self.size[self.root]
532
533    def __bool__(self):
534        return self.root != 0
535
536    def __str__(self):
537        return "{" + ", ".join(map(str, self.tolist())) + "}"
538
539    def __repr__(self):
540        return f"WBTSet({self})"

仕様

class WBTSet(a: Iterable[T] = []) [source]

Bases: Generic[T]

add(key: T) → bool [source]
ave_height() [source]
check() → int [source]

作業用デバッグ関数 (debug helper for development use)

clear() → None [source]
debug(root) [source]
discard(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_right(key: T) → int [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T) → bool [source]
reserve(n: int) → None [source]
tolist() → list[T] [source]