1# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
2from array import array
3from typing import Iterable, Optional
4
5
6class WordsizeTreeSet:
7 """``[0, u)`` の整数集合を管理する32分木です。
8 空間 :math:`O(u)` であることに注意してください。
9 """
10
11 def __init__(self, u: int, a: Iterable[int] = []) -> None:
12 """:math:`O(u)` です。"""
13 assert u >= 0
14 u += 1 # 念のため
15 self.u = u
16 data = []
17 len_ = 0
18 if a:
19 u >>= 5
20 A = array("I", bytes(4 * (u + 1)))
21 for a_ in a:
22 assert (
23 0 <= a_ < self.u
24 ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
25 if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
26 len_ += 1
27 A[a_ >> 5] |= 1 << (a_ & 31)
28 data.append(A)
29 while u:
30 a = array("I", bytes(4 * ((u >> 5) + 1)))
31 for i in range(u + 1):
32 if A[i]:
33 a[i >> 5] |= 1 << (i & 31)
34 data.append(a)
35 A = a
36 u >>= 5
37 else:
38 while u:
39 u >>= 5
40 data.append(array("I", bytes(4 * (u + 1))))
41 self.data: list[array[int]] = data
42 self.len: int = len_
43 self.len_data: int = len(data)
44
45 def add(self, v: int) -> bool:
46 """整数 ``v`` を個追加します。
47 :math:`O(\\log{u})` です。
48 """
49 assert (
50 0 <= v < self.u
51 ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
52 if self.data[0][v >> 5] >> (v & 31) & 1:
53 return False
54 self.len += 1
55 for a in self.data:
56 a[v >> 5] |= 1 << (v & 31)
57 v >>= 5
58 return True
59
60 def discard(self, v: int) -> bool:
61 """整数 ``v`` を削除します。
62 :math:`O(\\log{u})` です。
63 """
64 assert (
65 0 <= v < self.u
66 ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
67 if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
68 return False
69 self.len -= 1
70 for a in self.data:
71 a[v >> 5] &= ~(1 << (v & 31))
72 v >>= 5
73 if a[v]:
74 break
75 return True
76
77 def remove(self, v: int) -> None:
78 """整数 ``v`` を削除します。
79 :math:`O(\\log{u})` です。
80
81 Note: ``v`` が存在しないとき、例外を投げます。
82 """
83 assert (
84 0 <= v < self.u
85 ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
86 assert self.discard(v), f"ValueError: {v} not in self."
87
88 def ge(self, v: int) -> Optional[int]:
89 """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
90 :math:`O(\\log{u})` です。
91 """
92 assert (
93 0 <= v < self.u
94 ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
95 data = self.data
96 d = 0
97 while True:
98 if d >= self.len_data or v >> 5 >= len(data[d]):
99 return None
100 m = data[d][v >> 5] & ((~0) << (v & 31))
101 if m == 0:
102 d += 1
103 v = (v >> 5) + 1
104 else:
105 v = (v >> 5 << 5) + (m & -m).bit_length() - 1
106 if d == 0:
107 break
108 v <<= 5
109 d -= 1
110 return v
111
112 def gt(self, v: int) -> Optional[int]:
113 """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
114 :math:`O(\\log{u})` です。
115 """
116 assert (
117 0 <= v < self.u
118 ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
119 if v + 1 == self.u:
120 return
121 return self.ge(v + 1)
122
123 def le(self, v: int) -> Optional[int]:
124 """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
125 :math:`O(\\log{u})` です。
126 """
127 assert (
128 0 <= v < self.u
129 ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
130 data = self.data
131 d = 0
132 while True:
133 if v < 0 or d >= self.len_data:
134 return None
135 m = data[d][v >> 5] & ~((~1) << (v & 31))
136 if m == 0:
137 d += 1
138 v = (v >> 5) - 1
139 else:
140 v = (v >> 5 << 5) + m.bit_length() - 1
141 if d == 0:
142 break
143 v <<= 5
144 v += 31
145 d -= 1
146 return v
147
148 def lt(self, v: int) -> Optional[int]:
149 """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
150 :math:`O(\\log{u})` です。
151 """
152 assert (
153 0 <= v < self.u
154 ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
155 if v - 1 == 0:
156 return
157 return self.le(v - 1)
158
159 def get_min(self) -> Optional[int]:
160 """`最小値を返します。存在しないとき、 ``None``を返します。
161 :math:`O(\\log{u})` です。
162 """
163 return self.ge(0)
164
165 def get_max(self) -> Optional[int]:
166 """最大値を返します。存在しないとき、 ``None``を返します。
167 :math:`O(\\log{u})` です。
168 """
169 return self.le(self.u - 1)
170
171 def pop_min(self) -> int:
172 """最小値を削除して返します。
173 :math:`O(\\log{u})` です。
174 """
175 v = self.get_min()
176 assert (
177 v is not None
178 ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
179 self.discard(v)
180 return v
181
182 def pop_max(self) -> int:
183 """最大値を削除して返します。
184 :math:`O(\\log{u})` です。
185 """
186 v = self.get_max()
187 assert (
188 v is not None
189 ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
190 self.discard(v)
191 return v
192
193 def clear(self) -> None:
194 """集合を空にします。
195 :math:`O(n\\log{u})` です。
196 """
197 for e in self:
198 self.discard(e)
199 self.len = 0
200
201 def tolist(self) -> list[int]:
202 """リストにして返します。
203 :math:`O(n\\log{u})` です。
204 """
205 return [x for x in self]
206
207 def __bool__(self):
208 return self.len > 0
209
210 def __len__(self):
211 return self.len
212
213 def __contains__(self, v: int):
214 assert (
215 0 <= v < self.u
216 ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
217 return self.data[0][v >> 5] >> (v & 31) & 1 == 1
218
219 def __iter__(self):
220 self._val = self.ge(0)
221 return self
222
223 def __next__(self):
224 if self._val is None:
225 raise StopIteration
226 pre = self._val
227 self._val = self.gt(pre)
228 return pre
229
230 def __str__(self):
231 return "{" + ", ".join(map(str, self)) + "}"
232
233 def __repr__(self):
234 return f"{self.__class__.__name__}({self.u}, {self})"