binary_trie_set¶
ソースコード¶
from titan_pylib.data_structures.binary_trie.binary_trie_set import BinaryTrieSet
展開済みコード¶
1# from titan_pylib.data_structures.binary_trie.binary_trie_set import BinaryTrieSet
2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type for any object that implements the ``<`` operator."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` compares strictly less than ``other``."""
10from abc import ABC, abstractmethod
11from typing import Iterable, Optional, Iterator, TypeVar, Generic
12
# Element type of an ordered set; bound to types that define ``<``.
T = TypeVar(
    "T",
    bound=SupportsLessThan,
)
14
15
class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface for an ordered set of comparable elements.

    Every method is abstract; a concrete subclass must override all of them.
    The base implementations only raise ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the set from the elements of ``a``."""
        raise NotImplementedError

    # --- mutation -------------------------------------------------------

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; the ``bool`` result reports whether it was new."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Delete ``key`` if present; the ``bool`` result reports success."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Delete ``key``; implementations signal an error when it is absent."""
        raise NotImplementedError

    # --- neighbour queries ----------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        raise NotImplementedError

    # --- extremes -------------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when there is none."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when there is none."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest element."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        raise NotImplementedError

    # --- bulk operations ------------------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the elements as a list."""
        raise NotImplementedError

    # --- container protocol ---------------------------------------------

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator over the elements."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the current iteration."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is an element."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return the truth value of the set."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable representation."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError
101from typing import Optional, Iterable
102from array import array
103
104
class BinaryTrieSet(OrderedSetInterface):
    """Ordered set of integers in ``[0, lim)`` stored in a binary trie.

    ``lim`` is ``u`` rounded up to a power of two and every key occupies
    ``bit = log2(lim)`` levels of the trie.  Nodes live in parallel arrays
    indexed by node id: id ``0`` is the null node (its size is always 0) and
    id ``1`` is the root.

    ``all_xor`` is lazy: the trie stores ``key ^ self.xor`` and every query
    translates between stored and logical bits on the fly.

    Deleted keys are never unlinked; a subtree is empty exactly when its
    ``size`` entry is 0, so all queries decide by ``size`` and never by the
    mere existence of a child link.  ``valid[node]`` mirrors ``size[node] > 0``
    and is kept only for backward compatibility.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Build the set over the universe ``[0, u)`` from the keys in ``a``.

        Runs in :math:`O(n\\log{u})`.

        Args:
            u: Exclusive upper bound of the keys (rounded up to a power of two).
            a: Initial keys; duplicates are ignored.
        """
        # Two slots per array: the null node (0) and the root (1).  All five
        # arrays must stay the same length so node ids index them uniformly.
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.valid = array("B", [0, 0])
        self.end = 2  # next unused node id
        self.root = 1
        self.bit = (u - 1).bit_length()
        self.lim = 1 << self.bit
        self.xor = 0
        self.it = 0  # cursor used by __iter__/__next__
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Allocate a fresh node id, growing the arrays if nothing is reserved."""
        end = self.end
        if end >= len(self.left):
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
            self.valid.append(0)
        self.end = end + 1
        return end

    def _find(self, key: int) -> int:
        """Return the leaf id holding logical ``key``, or ``-1`` if absent."""
        left, right = self.left, self.right
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            if not node:
                return -1
        # The leaf may still be linked after a deletion; size tells the truth.
        return node if self.size[node] else -1

    def _kth(self, k: int) -> tuple[int, int]:
        """Return ``(leaf id, logical value)`` of the ``k``-th smallest key.

        Requires ``0 <= k < len(self)``.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # ``lo`` holds logical bit 0, ``hi`` logical bit 1; a set mask bit
            # swaps which physical child that is.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            t = size[lo]
            if t <= k:
                k -= t
                res |= 1 << i
                node = hi
            else:
                node = lo
        return node, res

    def _rank(self, key: int) -> tuple[int, int]:
        """Return ``(number of keys < key, node reached)``.

        The node is the leaf of ``key`` when its whole path exists, else ``0``.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        k = 0
        for i in range(self.bit - 1, -1, -1):
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            if key >> i & 1:
                k += size[lo]
                node = hi
            else:
                node = lo
            if not node:
                break
        return k, node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional trie nodes.

        Runs in :math:`O(n)`.  One key needs at most ``bit`` new nodes.
        """
        assert n >= 0, f"ValueError: BinaryTrieSet.reserve({n})"
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.par += zeros
        self.size += zeros
        self.valid += array("B", [0]) * n

    def add(self, key: int) -> bool:
        """Insert ``key``; return ``True`` if it was not already present.

        Runs in :math:`O(\\log{u})`.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.add({key}), lim={self.lim}"
        left, right, par, size, valid = (
            self.left,
            self.right,
            self.par,
            self.size,
            self.valid,
        )
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            if key >> i & 1:
                if not right[node]:
                    right[node] = self._make_node()
                    par[right[node]] = node
                node = right[node]
            else:
                if not left[node]:
                    left[node] = self._make_node()
                    par[left[node]] = node
                node = left[node]
        if size[node]:
            return False
        size[node] = 1
        valid[node] = 1
        for _ in range(self.bit):
            node = par[node]
            size[node] += 1
            valid[node] = 1
        return True

    def _remove(self, node: int) -> None:
        """Delete the key stored at leaf ``node`` (which must be present).

        Decrements the size of every node from the leaf up to the root; links
        are kept so the nodes can be reused by a later ``add``.
        """
        par, size, valid = self.par, self.size, self.valid
        while node:
            size[node] -= 1
            if not size[node]:
                valid[node] = 0
            node = par[node]

    # Historical (misspelled) name, kept so existing callers keep working.
    _rmeove = _remove

    def discard(self, key: int) -> bool:
        """Delete ``key`` if present; return whether it was present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        self._remove(node)
        return True

    def remove(self, key: int) -> None:
        """Delete ``key``.

        Raises:
            KeyError: If ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest key (negative ``k`` allowed)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth(k)
        self._remove(node)
        # ``res`` is already the logical value; it must not be xor-ed again.
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest key."""
        assert self, f"IndexError: BinaryTrieSet.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest key."""
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every key ``v`` by ``v ^ x``.

        Runs in :math:`O(1)`: only the lazy mask is updated.
        """
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieSet.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest key, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(0)[1]

    def get_max(self) -> Optional[int]:
        """Return the largest key, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(len(self) - 1)[1]

    def index(self, key: int) -> int:
        """Return the number of keys strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index({key}), lim={self.lim}"
        return self._rank(key)[0]

    def index_right(self, key: int) -> int:
        """Return the number of keys less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index_right({key}), lim={self.lim}"
        k, node = self._rank(key)
        # size is 1 for a present leaf, 0 for a deleted leaf or the null node.
        return k + self.size[node]

    def clear(self) -> None:
        """Remove every key while keeping the allocated node capacity."""
        n = len(self.left)
        self.left = array("I", [0]) * n
        self.right = array("I", [0]) * n
        self.par = array("I", [0]) * n
        self.size = array("I", [0]) * n
        self.valid = array("B", [0]) * n
        self.end = 2
        self.root = 1
        self.xor = 0

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest key ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest key ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest key ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest key ``<= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.le({key}), lim={self.lim}"
        # index_right(key) avoids probing key + 1, which is out of range for
        # key == lim - 1.
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all keys in ascending order."""
        return [self[i] for i in range(len(self))]

    def __contains__(self, key: int):
        """Return whether ``key`` is in the set."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: {key} in BinaryTrieSet, lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int):
        """Return the ``k``-th smallest key (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth(k)[1]

    def __bool__(self):
        """Return ``True`` when the set is non-empty."""
        return self.size[self.root] != 0

    def __iter__(self):
        """Reset the shared cursor and return ``self`` as the iterator.

        The cursor lives on the instance, so nested iteration over the same
        set is not supported.
        """
        self.it = 0
        return self

    def __next__(self):
        """Return the next key in ascending order."""
        # ``>=`` so that shrinking the set mid-iteration ends the loop cleanly.
        if self.it >= len(self):
            raise StopIteration
        self.it += 1
        return self.__getitem__(self.it - 1)

    def __len__(self):
        """Return the number of keys."""
        return self.size[self.root]

    def __str__(self):
        """Return the keys formatted like a set literal, in ascending order."""
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        """Return a representation whose first argument rebuilds the same universe."""
        return f"BinaryTrieSet({self.lim}, {self})"
仕様¶
- class BinaryTrieSet(u: int, a: Iterable[int] = [])[source]¶
Bases: OrderedSetInterface