binary_trie_multiset
ソースコード (source code)
from titan_pylib.data_structures.binary_trie.binary_trie_multiset import BinaryTrieMultiset
展開済みコード (expanded code: the module with its library dependencies inlined)
1# from titan_pylib.data_structures.binary_trie.binary_trie_multiset import BinaryTrieMultiset
2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type: any object that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""


from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, TypeVar

# Element type of the ordered containers below; bounded so elements support ``<``.
T = TypeVar("T", bound=SupportsLessThan)
14
15
class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract base for ordered multisets of mutually comparable keys.

    Every method is abstract, so concrete subclasses must override all of
    them. The default bodies only raise ``NotImplementedError``. Count
    parameters default to 1 so that ``ms.add(x)`` / ``ms.discard(x)`` /
    ``ms.remove(x)`` type-check against this interface the same way they
    run against concrete implementations.
    """

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the multiset from the elements of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``; return whether ``key`` was present."""
        raise NotImplementedError

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether ``key`` was present."""
        raise NotImplementedError

    @abstractmethod
    def count(self, key: T) -> int:
        """Return the number of stored copies of ``key``."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T, cnt: int = 1) -> None:
        """Remove exactly ``cnt`` copies of ``key``; fail if fewer are stored."""
        raise NotImplementedError

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if there is none."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if there is none."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if there is none."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if there is none."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return all elements, repeated by multiplicity, in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element when the multiset acts as its own iterator."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether at least one copy of ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements, counting multiplicity."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the multiset is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError
109from typing import Optional, Iterable
110from array import array
111
112
class BinaryTrieMultiset(OrderedMultisetInterface):
    """Multiset of integers in ``[0, lim)`` stored in a binary trie.

    ``lim = 1 << (u - 1).bit_length()``. Keys are inserted most-significant
    bit first over ``bit`` levels. Every node records how many elements live
    in its subtree, so rank and k-th queries each take O(bit).

    ``all_xor(x)`` XORs every stored element with ``x`` lazily. The trie
    holds ``value ^ xor``, with ``xor`` taken at insertion time, and every
    query translates through the current ``self.xor``.

    Node ``0`` is a null sentinel whose size stays 0. Node ``1`` is the root.
    Nodes live in the parallel ``array("I")`` columns ``left`` / ``right`` /
    ``par`` / ``size``. Nodes cut off by deletions are not recycled until
    ``clear()`` is called.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Create a trie for keys ``0 <= key < 1 << (u - 1).bit_length()``.

        Args:
            u: key bound; every key must satisfy ``0 <= key < self.lim``.
            a: initial elements, each added once.
        """
        # Slot 0 is the null sentinel, slot 1 the root.
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end: int = 2  # next unused node slot
        self.root: int = 1
        self.bit: int = (u - 1).bit_length()
        self.lim: int = 1 << self.bit
        self.xor: int = 0  # lazy mask applied by all_xor()
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Allocate a zeroed node slot and return its index."""
        node = self.end
        if node >= len(self.left):
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
        else:
            # The slot comes from reserve() or predates clear(); wipe stale data.
            self.left[node] = 0
            self.right[node] = 0
            self.par[node] = 0
            self.size[node] = 0
        self.end += 1
        return node

    def _sub_path(self, node: int, cnt: int) -> None:
        """Subtract ``cnt`` from the size of ``node`` and of every ancestor."""
        par, size = self.par, self.size
        while node:
            size[node] -= cnt
            node = par[node]

    def _find(self, key: int) -> int:
        """Return the leaf node holding ``key``, or -1 if ``key`` is absent."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset._find({key}), lim={self.lim}"
        left, right = self.left, self.right
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            if not node:
                return -1
        # When bit == 0 the root itself is the only leaf and may be empty.
        return node if self.size[node] else -1

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` extra node slots to avoid repeated appends."""
        assert n >= 0, f"ValueError: BinaryTrieMultiset.reserve({n})"
        a = array("I", [0]) * n
        self.left += a
        self.right += a
        self.par += a
        self.size += a

    def add(self, key: int, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``.

        Adding zero copies is a no-op. A zero-count leaf would leave a
        reachable empty path that breaks the k-th element queries.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.add({key}), lim={self.lim}"
        assert cnt >= 0, f"ValueError: BinaryTrieMultiset.add({key}, cnt={cnt})"
        if cnt == 0:
            return
        left, right, par, size = self.left, self.right, self.par, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            size[node] += cnt
            child = right if key >> i & 1 else left
            if not child[node]:
                new = self._make_node()
                child[node] = new
                par[new] = node
            node = child[node]
        size[node] += cnt

    def _remove(self, node: int) -> None:
        """Delete every copy stored at leaf ``node`` and prune empty branches."""
        left, right, par, size = self.left, self.right, self.par, self.size
        cnt = size[node]
        # Climb while the branch would become empty, detaching it from its parent.
        for _ in range(self.bit):
            size[node] -= cnt
            p = par[node]
            if left[p] == node:
                left[p] = 0
                sibling = right[p]
            else:
                right[p] = 0
                sibling = left[p]
            node = p
            if sibling:
                break
        # The remaining ancestors keep other elements; just shrink them.
        self._sub_path(node, cnt)

    def discard(self, key: int, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``; return whether it was present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        if self.size[node] <= cnt:
            self._remove(node)
        else:
            self._sub_path(node, cnt)
        return True

    def discard_all(self, key: int) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: int, cnt: int = 1) -> None:
        """Remove exactly ``cnt`` copies of ``key``.

        Raises:
            AssertionError: if ``key`` is out of range, or fewer than ``cnt``
                copies are stored. The "not found" case is raised explicitly
                rather than via ``assert``, so ``python -O`` cannot skip it.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.remove({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1 or self.size[node] < cnt:
            raise AssertionError(f"{key} is not found.")
        if self.size[node] == cnt:
            self._remove(node)
        else:
            self._sub_path(node, cnt)

    def count(self, key: int) -> int:
        """Return the number of stored copies of ``key``."""
        node = self._find(key)
        return 0 if node == -1 else self.size[node]

    def _kth(self, k: int) -> tuple[int, int]:
        """Return ``(leaf, value)`` for the ``k``-th smallest element (0-indexed).

        ``value`` is the logical (xor-applied) value. Requires
        ``0 <= k < len(self)``.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # lo / hi: the children whose logical bit i is 0 / 1.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            res <<= 1
            t = size[lo]  # size[0] == 0, so a missing child counts as empty
            if t <= k:
                k -= t
                res |= 1
                node = hi
            else:
                node = lo
        return node, res

    def pop(self, k: int = -1) -> int:
        """Remove one copy of the ``k``-th smallest element and return it."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieMultiset.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth(k)
        if self.size[node] == 1:
            self._remove(node)
        else:
            self._sub_path(node, 1)
        # res is already the logical value; XOR-ing again would undo all_xor.
        return res

    def pop_min(self) -> int:
        """Remove one copy of the smallest element and return it."""
        assert self, f"IndexError: BinaryTrieMultiset.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove one copy of the largest element and return it."""
        assert self, f"IndexError: BinaryTrieMultiset.pop_max(), len={len(self)}"
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every stored value ``v`` with ``v ^ x`` in O(1)."""
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieMultiset.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or ``None`` when empty."""
        return self._kth(0)[1] if self else None

    def get_max(self) -> Optional[int]:
        """Return the largest element, or ``None`` when empty."""
        return self._kth(len(self) - 1)[1] if self else None

    def _rank(self, key: int, inclusive: bool) -> int:
        """Count elements ``< key``, or ``<= key`` when ``inclusive`` is true.

        ``key`` is a logical value. At each level the subtree whose logical
        bit is 0 is ``left`` unless the matching bit of ``self.xor`` is set.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        k = 0
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            if key >> i & 1:
                k += size[lo]  # everything on the logical-0 side is smaller
                node = hi
            else:
                node = lo
            if not node:
                return k
        return k + size[node] if inclusive else k

    def index(self, key: int) -> int:
        """Return the number of elements strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index({key}), lim={self.lim}"
        return self._rank(key, False)

    def index_right(self, key: int) -> int:
        """Return the number of elements less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index_right({key}), lim={self.lim}"
        return self._rank(key, True)

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest element ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest element ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest element ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest element ``<= key``, or ``None``.

        Uses ``index_right(key)`` rather than ``index(key + 1)`` so that
        ``key == lim - 1`` does not trip the range check.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.le({key}), lim={self.lim}"
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all elements, repeated by multiplicity, in ascending order."""
        res: list[int] = []
        val = self.get_min()
        while val is not None:
            res.extend([val] * self.count(val))
            val = self.gt(val)
        return res

    def clear(self) -> None:
        """Remove every element. The xor mask and reserved capacity are kept."""
        self.root = 1
        self.end = 2
        self.left[1] = self.right[1] = self.par[1] = self.size[1] = 0

    def __contains__(self, key: int) -> bool:
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.__contains__({key}), lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieMultiset[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth(k)[1]

    def __bool__(self) -> bool:
        return self.size[self.root] != 0

    def __iter__(self):
        # Iteration state lives on the instance (self.it), so nested iteration
        # over the same object is not supported.
        self.it = 0
        return self

    def __next__(self) -> int:
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self.__getitem__(self.it - 1)

    def __len__(self) -> int:
        return self.size[self.root]

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        # Passing lim as u reproduces the same bit width for every bit >= 0.
        return f"BinaryTrieMultiset({self.lim}, {self.tolist()})"
仕様 (specification)
- class BinaryTrieMultiset(u: int, a: Iterable[int] = [])
  Bases: OrderedMultisetInterface