1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
2from typing import Optional, Iterable
3from array import array
4
5
[docs]
class BinaryTrieSet(OrderedSetInterface):
    """Ordered set of non-negative integers backed by a binary trie.

    Keys live in ``[0, lim)`` with ``lim = 1 << (u - 1).bit_length()``.
    Every stored key is a root-to-leaf path of ``bit`` edges, and
    ``size[v]`` is the number of stored keys in the subtree of node ``v``.
    A node whose ``size`` is 0 is treated as absent; removed branches stay
    linked and are reused by later insertions.

    ``all_xor`` is lazy: the trie stores ``value ^ xor`` and every query
    translates between stored and logical values.

    Node ``0`` is a null sentinel whose fields are always 0; node ``1`` is
    the root.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Build a set over ``[0, lim)`` and insert every key of ``a``.

        :math:`O(n\\log{u})`.
        """
        # Slot 0 is the null sentinel, slot 1 is the root.
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end = 2  # next unused node slot
        self.root = 1
        self.bit = (u - 1).bit_length()
        self.lim = 1 << self.bit
        self.xor = 0
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Return the index of a fresh node with every field zeroed."""
        end = self.end
        if end < len(self.left):
            # Slot already allocated (by reserve() or before a clear()):
            # wipe anything a previous use left behind.
            self.left[end] = self.right[end] = 0
            self.par[end] = self.size[end] = 0
        else:
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
        self.end = end + 1
        return end

    def _find(self, key: int) -> int:
        """Return the leaf holding logical ``key``, or ``-1`` if it is absent."""
        left, right, size = self.left, self.right, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            if not size[node]:
                return -1
            node = right[node] if key >> i & 1 else left[node]
        # The final check also covers bit == 0, where the root is the leaf.
        return node if size[node] else -1

    def _kth(self, k: int) -> tuple[int, int]:
        """Return ``(leaf, value)`` of the ``k``-th smallest element.

        Requires ``0 <= k < len(self)``.
        """
        left, right, size, xor = self.left, self.right, self.size, self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # With the xor bit set, the stored-1 child holds logical-0 keys.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            t = size[lo]
            if t <= k:
                k -= t
                node = hi
                res |= 1 << i
            else:
                node = lo
        return node, res

    def _rank(self, key: int) -> tuple[int, int]:
        """Return ``(#elements < key, 1 if key is present else 0)``."""
        left, right, size = self.left, self.right, self.size
        stored = key ^ self.xor
        node = self.root
        less = 0
        for i in range(self.bit - 1, -1, -1):
            # Follow the stored path of ``key``; ``other`` is its sibling.
            if stored >> i & 1:
                node, other = right[node], left[node]
            else:
                node, other = left[node], right[node]
            # When the logical key bit is 1, the sibling holds logical-0
            # keys, all of which are smaller than ``key``.
            if key >> i & 1:
                less += size[other]
            if not size[node]:
                return less, 0
        return less, size[node]

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` additional node slots.

        :math:`O(n)`.
        """
        assert n >= 0, f"ValueError: BinaryTrieSet.reserve({n})"
        a = array("I", [0]) * n
        self.left += a
        self.right += a
        self.par += a
        self.size += a

    def add(self, key: int) -> bool:
        """Insert ``key``; return ``True`` if it was not already present.

        :math:`O(\\log{u})`.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.add({key}), lim={self.lim}"
        left, right, par, size = self.left, self.right, self.par, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            if key >> i & 1:
                if not right[node]:
                    child = self._make_node()
                    right[node] = child
                    par[child] = node
                node = right[node]
            else:
                if not left[node]:
                    child = self._make_node()
                    left[node] = child
                    par[child] = node
                node = left[node]
        if size[node]:
            return False
        # Increment the leaf and every ancestor (par[root] == 0 stops it).
        while node:
            size[node] += 1
            node = par[node]
        return True

    def _remove(self, node: int) -> None:
        """Remove the element whose leaf is ``node``.

        This only decrements sizes from the leaf up to the root; nodes that
        drop to size 0 stay linked and are reused by ``add``.
        """
        size, par = self.size, self.par
        while node:
            size[node] -= 1
            node = par[node]

    # Backward-compatible alias for the old, misspelled private name.
    _rmeove = _remove

    def discard(self, key: int) -> bool:
        """Remove ``key`` if present; return whether it was removed.

        :math:`O(\\log{u})`.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        self._remove(node)
        return True

    def remove(self, key: int) -> None:
        """Remove ``key``; raise ``KeyError`` if it is absent."""
        if self.discard(key):
            return
        raise KeyError(key)

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest element (negative ``k`` counts from the end).

        :math:`O(\\log{u})`.
        """
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth(k)
        self._remove(node)
        # ``res`` is already the logical value; no further xor is needed.
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest element."""
        assert self, f"IndexError: BinaryTrieSet.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest element."""
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every element ``e`` with ``e ^ x``.

        :math:`O(1)`.
        """
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieSet.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(0)[1]

    def get_max(self) -> Optional[int]:
        """Return the largest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(len(self) - 1)[1]

    def index(self, key: int) -> int:
        """Return the number of elements strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index({key}), lim={self.lim}"
        return self._rank(key)[0]

    def index_right(self, key: int) -> int:
        """Return the number of elements less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index_right({key}), lim={self.lim}"
        less, present = self._rank(key)
        return less + present

    def clear(self) -> None:
        """Remove every element and reset the lazy xor; allocated slots are kept for reuse."""
        self.root = 1
        self.end = 2
        self.left[1] = self.right[1] = self.size[1] = 0
        self.xor = 0

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest element ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest element ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest element ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest element ``<= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.le({key}), lim={self.lim}"
        # index_right avoids querying key + 1, which may equal lim.
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all elements in ascending order."""
        return [self._kth(i)[1] for i in range(len(self))]

    def __contains__(self, key: int) -> bool:
        assert (
            0 <= key < self.lim
        ), f"ValueError: {key} in BinaryTrieSet, lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth(k)[1]

    def __bool__(self) -> bool:
        return self.size[self.root] != 0

    def __iter__(self):
        # NOTE: the cursor lives on the instance (self.it), so nested loops
        # over the same set interfere with each other.
        self.it = 0
        return self

    def __next__(self) -> int:
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self[self.it - 1]

    def __len__(self) -> int:
        return self.size[self.root]

    def __str__(self) -> str:
        # Uses tolist() rather than iter(self), so printing during an
        # iteration does not reset the shared cursor.
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        # Passing lim as ``u`` reproduces the same ``bit`` on reconstruction.
        return f"BinaryTrieSet({self.lim}, {self})"