from array import array
from typing import Iterable, Optional

from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
4
5
class BinaryTrieMultiset(OrderedMultisetInterface):
    """Ordered multiset of integers in ``[0, lim)`` backed by a binary trie.

    Every key is a root-to-leaf path of ``bit`` levels, most significant bit
    first. Nodes live in parallel ``array("I")`` buffers indexed by node id:
    id ``0`` is a null sentinel whose ``size`` stays 0, id ``1`` is the root,
    and ``size[v]`` counts the elements (with multiplicity) below ``v``.

    ``all_xor`` is applied lazily: keys are stored as ``key ^ self.xor``, so
    every order-dependent traversal translates each level through the
    corresponding bit of ``self.xor``.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Build a multiset for keys ``0 <= key < lim``.

        Args:
            u: Key bound; ``lim`` is ``1 << (u - 1).bit_length()``, i.e. ``u``
                rounded up to a power of two when ``u >= 1``.
            a: Initial elements, inserted one by one (duplicates kept).
        """
        self.bit: int = (u - 1).bit_length()
        self.lim: int = 1 << self.bit
        self.xor: int = 0
        self._init_nodes()
        for e in a:
            self.add(e)

    def _init_nodes(self) -> None:
        """Reset the node pool to the null sentinel (id 0) and an empty root (id 1)."""
        # array("I", [0]) * n is itemsize-independent, unlike bytes(4 * n).
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end: int = 2  # next unused node id
        self.root: int = 1

    def _make_node(self) -> int:
        """Return a fresh zeroed node id, growing the buffers if none are reserved."""
        if self.end >= len(self.left):
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
        self.end += 1
        return self.end - 1

    def _find(self, key: int) -> int:
        """Return the leaf holding ``key``, or ``-1`` if ``key`` is not present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset._find({key}), lim={self.lim}"
        left, right = self.left, self.right
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            if not node:
                return -1
        # With bit == 0 the root itself is the only leaf and exists even when
        # empty, so a zero count must also be reported as "absent".
        return node if self.size[node] else -1

    def _decrease(self, node: int, cnt: int) -> None:
        """Subtract ``cnt`` from ``size`` on every node from ``node`` up to the root."""
        par, size = self.par, self.size
        while node:
            size[node] -= cnt
            node = par[node]

    def _remove(self, node: int) -> None:
        """Delete every copy held by leaf ``node`` and prune emptied nodes.

        Walks upward detaching each emptied node from its parent, stopping at
        the first ancestor that still has another child; from there to the
        root the counts are only reduced.
        """
        left, right, par, size = self.left, self.right, self.par, self.size
        cnt = size[node]
        for _ in range(self.bit):
            size[node] -= cnt
            p = par[node]
            if left[p] == node:
                left[p] = 0
                sibling = right[p]
            else:
                right[p] = 0
                sibling = left[p]
            node = p
            if sibling:
                break
        self._decrease(node, cnt)

    def _kth_node(self, k: int) -> tuple[int, int]:
        """Return ``(leaf, value)`` of the ``k``-th smallest element (``0 <= k < len``).

        ``value`` is already in logical (xor-applied) space.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # lo / hi: children whose *logical* bit i is 0 / 1.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            res <<= 1
            # A missing child has id 0 and size[0] == 0, so it is skipped.
            t = size[lo]
            if t <= k:
                k -= t
                res |= 1
                node = hi
            else:
                node = lo
        return node, res

    def _rank(self, key: int, inclusive: bool) -> int:
        """Count elements ``< key`` (or ``<= key`` when ``inclusive``)."""
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        k = 0
        for i in range(self.bit - 1, -1, -1):
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            if key >> i & 1:
                # Everything under the logical-0 child is smaller than key.
                k += size[lo]
                node = hi
            else:
                node = lo
            if not node:
                return k
        # Reached key's own leaf (or the root when bit == 0).
        if inclusive:
            k += size[node]
        return k

    def reserve(self, n: int) -> None:
        """Append ``n`` zeroed slots to every node buffer so later inserts need not grow them."""
        assert n >= 0, f"ValueError: BinaryTrieMultiset.reserve({n})"
        a = array("I", [0]) * n
        self.left += a
        self.right += a
        self.par += a
        self.size += a

    def add(self, key: int, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.add({key}), lim={self.lim}"
        assert cnt >= 0, f"ValueError: BinaryTrieMultiset.add({key}, {cnt})"
        if cnt == 0:
            # An empty path would later be walked by get_min/get_max.
            return
        left, right, par, size = self.left, self.right, self.par, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            size[node] += cnt
            if key >> i & 1:
                if not right[node]:
                    right[node] = self._make_node()
                    par[right[node]] = node
                node = right[node]
            else:
                if not left[node]:
                    left[node] = self._make_node()
                    par[left[node]] = node
                node = left[node]
        size[node] += cnt

    def discard(self, key: int, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``; return False if it was absent."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.discard({key}), lim={self.lim}"
        assert cnt >= 0, f"ValueError: BinaryTrieMultiset.discard({key}, {cnt})"
        node = self._find(key)
        if node == -1:
            return False
        if self.size[node] <= cnt:
            self._remove(node)
        else:
            self._decrease(node, cnt)
        return True

    def discard_all(self, key: int) -> bool:
        """Remove every copy of ``key``; return False if it was absent."""
        return self.discard(key, self.count(key))

    def remove(self, key: int, cnt: int = 1) -> None:
        """Remove exactly ``cnt`` copies of ``key``; assert that enough are present."""
        node = self._find(key)
        assert node != -1 and self.size[node] >= cnt, f"{key} is not found."
        if self.size[node] == cnt:
            self._remove(node)
        else:
            self._decrease(node, cnt)

    def count(self, key: int) -> int:
        """Return the multiplicity of ``key``."""
        node = self._find(key)
        return 0 if node == -1 else self.size[node]

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieMultiset.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth_node(k)
        if self.size[node] == 1:
            self._remove(node)
        else:
            self._decrease(node, 1)
        # res is already the logical value; xoring again would undo all_xor.
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest element."""
        assert self, f"IndexError: BinaryTrieMultiset.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest element."""
        assert self, f"IndexError: BinaryTrieMultiset.pop_max(), len={len(self)}"
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every element ``e`` by ``e ^ x`` (lazily, in O(1))."""
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieMultiset.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or None when empty."""
        return self._kth_node(0)[1] if self else None

    def get_max(self) -> Optional[int]:
        """Return the largest element, or None when empty."""
        return self._kth_node(len(self) - 1)[1] if self else None

    def index(self, key: int) -> int:
        """Return the number of elements strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index({key}), lim={self.lim}"
        return self._rank(key, False)

    def index_right(self, key: int) -> int:
        """Return the number of elements less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index_right({key}), lim={self.lim}"
        return self._rank(key, True)

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest element ``> key``, or None."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest element ``< key``, or None."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest element ``>= key``, or None."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest element ``<= key``, or None."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.le({key}), lim={self.lim}"
        # index_right avoids key + 1, which may equal lim.
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all elements in ascending order, each repeated by its multiplicity."""
        a: list[int] = []
        val = self.get_min()
        while val is not None:
            a.extend([val] * self.count(val))
            val = self.gt(val)
        return a

    def clear(self) -> None:
        """Remove every element."""
        self._init_nodes()
        # With no elements left the pending xor affects nothing; reset it.
        self.xor = 0

    def __contains__(self, key: int) -> bool:
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.__contains__({key}), lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieMultiset[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth_node(k)[1]

    def __bool__(self):
        return self.size[self.root] != 0

    def __iter__(self):
        # NOTE: the cursor lives on the instance, so nested loops over the
        # same multiset share (and disturb) one iteration position.
        self.it = 0
        return self

    def __next__(self):
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self[self.it - 1]

    def __len__(self):
        return self.size[self.root]

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        # Passing lim (not lim - 1) makes the constructor rebuild the same bit width.
        return f"BinaryTrieMultiset({self.lim}, {self.tolist()})"