segki_set

ソースコード

from titan_pylib.data_structures.set.segki_set import SegkiSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.segki_set import SegkiSet
  2from typing import Optional, Iterable
  3
  4
  5class SegkiSet:
  6
  7    # 0以上u未満の整数が載る集合
  8    # セグ木的な構造、各Nodeはその子孫のOR値を保持(ORではなくSUMならBITと同じ感じ)
  9    #
 10    # 空間: O(u)
 11    # add, discard, predecessor, successor: O(logu)
 12    # contains, len: O(1)
 13    # iteration: (nlogu)
 14    # kth element: O(klogu)
 15
 16    def __init__(self, u: int, a: Iterable[int] = []):
 17        self.log = (u - 1).bit_length()
 18        self.size = 1 << self.log
 19        self.u = u
 20        self.data = bytearray(self.size << 1)
 21        self.len = 0
 22        for _a in a:
 23            self.add(_a)
 24
 25    def add(self, k: int) -> bool:
 26        k += self.size
 27        if self.data[k]:
 28            return False
 29        self.len += 1
 30        self.data[k] = 1
 31        while k > 1:
 32            k >>= 1
 33            if self.data[k]:
 34                break
 35            self.data[k] = 1
 36        return True
 37
 38    def discard(self, k: int) -> bool:
 39        k += self.size
 40        if self.data[k] == 0:
 41            return False
 42        self.len -= 1
 43        self.data[k] = 0
 44        while k > 1:
 45            if k & 1:
 46                if self.data[k - 1]:
 47                    break
 48            else:
 49                if self.data[k + 1]:
 50                    break
 51            k >>= 1
 52            self.data[k] = 0
 53        return True
 54
 55    def get_min(self) -> Optional[int]:
 56        if self.data[1] == 0:
 57            return None
 58        k = 1
 59        while k < self.size:
 60            k <<= 1
 61            if self.data[k] == 0:
 62                k |= 1
 63        return k - self.size
 64
 65    def get_max(self) -> Optional[int]:
 66        if self.data[1] == 0:
 67            return None
 68        k = 1
 69        while k < self.size:
 70            k <<= 1
 71            if self.data[k | 1]:
 72                k |= 1
 73        return k - self.size
 74
 75    """Find the largest element < key, or None if it doesn't exist. / O(logN)"""
 76
 77    def lt(self, k: int) -> Optional[int]:
 78        if self.data[1] == 0:
 79            return None
 80        x = k
 81        k += self.size
 82        while k > 1:
 83            if k & 1 and self.data[k - 1]:
 84                k >>= 1
 85                break
 86            k >>= 1
 87        k <<= 1
 88        if self.data[k] == 0:
 89            return None
 90        while k < self.size:
 91            k <<= 1
 92            if self.data[k | 1]:
 93                k |= 1
 94        k -= self.size
 95        return k if k < x else None
 96
 97    """Find the smallest element > key, or None if it doesn't exist. / O(logN)"""
 98
 99    def gt(self, k: int) -> Optional[int]:
100        if self.data[1] == 0:
101            return None
102        x = k
103        k += self.size
104        while k > 1:
105            if k & 1 == 0 and self.data[k + 1]:
106                k >>= 1
107                break
108            k >>= 1
109        k = k << 1 | 1
110        while k < self.size:
111            k <<= 1
112            if self.data[k] == 0:
113                k |= 1
114        k -= self.size
115        return k if k > x and k < self.u else None
116
117    def le(self, k: int) -> Optional[int]:
118        if self.data[k + self.size]:
119            return k
120        return self.lt(k)
121
122    def ge(self, k: int) -> Optional[int]:
123        if self.data[k + self.size]:
124            return k
125        return self.gt(k)
126
127    def debug(self):
128        print(
129            "\n".join(
130                " ".join(map(str, (self.data[(1 << i) + j] for j in range(1 << i))))
131                for i in range(self.log + 1)
132            )
133        )
134
135    def __contains__(self, k: int):
136        return self.data[k + self.size] == 1
137
138    def __getitem__(self, k: int):  # kは先頭か末尾にすることを推奨
139        # O(klogu)
140        if k < 0:
141            k += self.len
142        if k == 0:
143            return self.get_min()
144        if k == self.len - 1:
145            return self.get_max()
146        if k < self.len >> 1:
147            key = self.get_min()
148            for _ in range(k):
149                key = self.gt(key)
150        else:
151            key = self.get_max()
152            for _ in range(self.len - k - 1):
153                key = self.lt(key)
154        return key
155
156    def __len__(self):
157        return self.len
158
159    def __iter__(self):
160        key = self.get_min()
161        while key is not None:
162            yield key
163            key = self.gt(key)
164
165    def __str__(self):
166        return "{" + ", ".join(map(str, self)) + "}"

仕様

class SegkiSet(u: int, a: Iterable[int] = [])[source]

Bases: object

add(k: int) -> bool [source]
debug() [source]
discard(k: int) -> bool [source]
ge(k: int) -> int | None [source]
get_max() -> int | None [source]
get_min() -> int | None [source]
gt(k: int) -> int | None [source]
le(k: int) -> int | None [source]
lt(k: int) -> int | None [source]