wordsize_tree_set

ソースコード

from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
  2from array import array
  3from typing import Iterable, Optional
  4
  5
  6class WordsizeTreeSet:
  7    """``[0, u)`` の整数集合を管理する32分木です。
  8    空間 :math:`O(u)` であることに注意してください。
  9    """
 10
 11    def __init__(self, u: int, a: Iterable[int] = []) -> None:
 12        """:math:`O(u)` です。"""
 13        assert u >= 0
 14        u += 1  # 念のため
 15        self.u = u
 16        data = []
 17        len_ = 0
 18        if a:
 19            u >>= 5
 20            A = array("I", bytes(4 * (u + 1)))
 21            for a_ in a:
 22                assert (
 23                    0 <= a_ < self.u
 24                ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
 25                if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
 26                    len_ += 1
 27                    A[a_ >> 5] |= 1 << (a_ & 31)
 28            data.append(A)
 29            while u:
 30                a = array("I", bytes(4 * ((u >> 5) + 1)))
 31                for i in range(u + 1):
 32                    if A[i]:
 33                        a[i >> 5] |= 1 << (i & 31)
 34                data.append(a)
 35                A = a
 36                u >>= 5
 37        else:
 38            while u:
 39                u >>= 5
 40                data.append(array("I", bytes(4 * (u + 1))))
 41        self.data: list[array[int]] = data
 42        self.len: int = len_
 43        self.len_data: int = len(data)
 44
 45    def add(self, v: int) -> bool:
 46        """整数 ``v`` を個追加します。
 47        :math:`O(\\log{u})` です。
 48        """
 49        assert (
 50            0 <= v < self.u
 51        ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
 52        if self.data[0][v >> 5] >> (v & 31) & 1:
 53            return False
 54        self.len += 1
 55        for a in self.data:
 56            a[v >> 5] |= 1 << (v & 31)
 57            v >>= 5
 58        return True
 59
 60    def discard(self, v: int) -> bool:
 61        """整数 ``v`` を削除します。
 62        :math:`O(\\log{u})` です。
 63        """
 64        assert (
 65            0 <= v < self.u
 66        ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
 67        if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
 68            return False
 69        self.len -= 1
 70        for a in self.data:
 71            a[v >> 5] &= ~(1 << (v & 31))
 72            v >>= 5
 73            if a[v]:
 74                break
 75        return True
 76
 77    def remove(self, v: int) -> None:
 78        """整数 ``v`` を削除します。
 79        :math:`O(\\log{u})` です。
 80
 81        Note: ``v`` が存在しないとき、例外を投げます。
 82        """
 83        assert (
 84            0 <= v < self.u
 85        ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
 86        assert self.discard(v), f"ValueError: {v} not in self."
 87
 88    def ge(self, v: int) -> Optional[int]:
 89        """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
 90        :math:`O(\\log{u})` です。
 91        """
 92        assert (
 93            0 <= v < self.u
 94        ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
 95        data = self.data
 96        d = 0
 97        while True:
 98            if d >= self.len_data or v >> 5 >= len(data[d]):
 99                return None
100            m = data[d][v >> 5] & ((~0) << (v & 31))
101            if m == 0:
102                d += 1
103                v = (v >> 5) + 1
104            else:
105                v = (v >> 5 << 5) + (m & -m).bit_length() - 1
106                if d == 0:
107                    break
108                v <<= 5
109                d -= 1
110        return v
111
112    def gt(self, v: int) -> Optional[int]:
113        """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
114        :math:`O(\\log{u})` です。
115        """
116        assert (
117            0 <= v < self.u
118        ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
119        if v + 1 == self.u:
120            return
121        return self.ge(v + 1)
122
123    def le(self, v: int) -> Optional[int]:
124        """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
125        :math:`O(\\log{u})` です。
126        """
127        assert (
128            0 <= v < self.u
129        ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
130        data = self.data
131        d = 0
132        while True:
133            if v < 0 or d >= self.len_data:
134                return None
135            m = data[d][v >> 5] & ~((~1) << (v & 31))
136            if m == 0:
137                d += 1
138                v = (v >> 5) - 1
139            else:
140                v = (v >> 5 << 5) + m.bit_length() - 1
141                if d == 0:
142                    break
143                v <<= 5
144                v += 31
145                d -= 1
146        return v
147
148    def lt(self, v: int) -> Optional[int]:
149        """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
150        :math:`O(\\log{u})` です。
151        """
152        assert (
153            0 <= v < self.u
154        ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
155        if v - 1 == 0:
156            return
157        return self.le(v - 1)
158
159    def get_min(self) -> Optional[int]:
160        """`最小値を返します。存在しないとき、 ``None``を返します。
161        :math:`O(\\log{u})` です。
162        """
163        return self.ge(0)
164
165    def get_max(self) -> Optional[int]:
166        """最大値を返します。存在しないとき、 ``None``を返します。
167        :math:`O(\\log{u})` です。
168        """
169        return self.le(self.u - 1)
170
171    def pop_min(self) -> int:
172        """最小値を削除して返します。
173        :math:`O(\\log{u})` です。
174        """
175        v = self.get_min()
176        assert (
177            v is not None
178        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
179        self.discard(v)
180        return v
181
182    def pop_max(self) -> int:
183        """最大値を削除して返します。
184        :math:`O(\\log{u})` です。
185        """
186        v = self.get_max()
187        assert (
188            v is not None
189        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
190        self.discard(v)
191        return v
192
193    def clear(self) -> None:
194        """集合を空にします。
195        :math:`O(n\\log{u})` です。
196        """
197        for e in self:
198            self.discard(e)
199        self.len = 0
200
201    def tolist(self) -> list[int]:
202        """リストにして返します。
203        :math:`O(n\\log{u})` です。
204        """
205        return [x for x in self]
206
207    def __bool__(self):
208        return self.len > 0
209
210    def __len__(self):
211        return self.len
212
213    def __contains__(self, v: int):
214        assert (
215            0 <= v < self.u
216        ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
217        return self.data[0][v >> 5] >> (v & 31) & 1 == 1
218
219    def __iter__(self):
220        self._val = self.ge(0)
221        return self
222
223    def __next__(self):
224        if self._val is None:
225            raise StopIteration
226        pre = self._val
227        self._val = self.gt(pre)
228        return pre
229
230    def __str__(self):
231        return "{" + ", ".join(map(str, self)) + "}"
232
233    def __repr__(self):
234        return f"{self.__class__.__name__}({self.u}, {self})"

仕様

class WordsizeTreeSet(u: int, a: Iterable[int] = [])[source]

Bases: object

[0, u) の整数集合を管理する32分木です。 空間 \(O(u)\) であることに注意してください。

add(v: int) bool[source]

整数 v を個追加します。 \(O(\log{u})\) です。

clear() None[source]

集合を空にします。 \(O(n\log{u})\) です。

discard(v: int) bool[source]

整数 v を削除します。 \(O(\log{u})\) です。

ge(v: int) int | None[source]

v 以上で最小の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

get_max() int | None[source]

最大値を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

get_min() int | None[source]

最小値を返します。存在しないとき、 ``None``を返します。 :math:`O(log{u}) です。

gt(v: int) int | None[source]

v より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

le(v: int) int | None[source]

v 以下で最大の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

lt(v: int) int | None[source]

v より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

pop_max() int[source]

最大値を削除して返します。 \(O(\log{u})\) です。

pop_min() int[source]

最小値を削除して返します。 \(O(\log{u})\) です。

remove(v: int) None[source]

整数 v を削除します。 \(O(\log{u})\) です。

Note: v が存在しないとき、例外を投げます。

tolist() list[int][source]

リストにして返します。 \(O(n\log{u})\) です。