binary_trie_multiset

ソースコード

from titan_pylib.data_structures.binary_trie.binary_trie_multiset import BinaryTrieMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.binary_trie.binary_trie_multiset import BinaryTrieMultiset
  2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedMultisetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T, cnt: int) -> None:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T, cnt: int) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def discard_all(self, key: T) -> bool:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def count(self, key: T) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def remove(self, key: T, cnt: int) -> None:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def le(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def lt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def ge(self, key: T) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def gt(self, key: T) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def get_max(self) -> Optional[T]:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def get_min(self) -> Optional[T]:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def pop_max(self) -> T:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def pop_min(self) -> T:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def clear(self) -> None:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def tolist(self) -> list[T]:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __iter__(self) -> Iterator:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __next__(self) -> T:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __contains__(self, key: T) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __len__(self) -> int:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __bool__(self) -> bool:
100        raise NotImplementedError
101
102    @abstractmethod
103    def __str__(self) -> str:
104        raise NotImplementedError
105
106    @abstractmethod
107    def __repr__(self) -> str:
108        raise NotImplementedError
109from typing import Optional, Iterable
110from array import array
111
112
113class BinaryTrieMultiset(OrderedMultisetInterface):
114
115    def __init__(self, u: int, a: Iterable[int] = []) -> None:
116        self.left = array("I", bytes(8))
117        self.right = array("I", bytes(8))
118        self.par = array("I", bytes(8))
119        self.size = array("I", bytes(8))
120        self.end: int = 2
121        self.root: int = 1
122        self.bit: int = (u - 1).bit_length()
123        self.lim: int = 1 << self.bit
124        self.xor: int = 0
125        for e in a:
126            self.add(e)
127
128    def _make_node(self) -> int:
129        if self.end >= len(self.left):
130            self.left.append(0)
131            self.right.append(0)
132            self.par.append(0)
133            self.size.append(0)
134        self.end += 1
135        return self.end - 1
136
137    def _find(self, key: int) -> int:
138        assert (
139            0 <= key < self.lim
140        ), f"ValueError: BinaryTrieMultiset._find({key}), lim={self.lim}"
141        left, right = self.left, self.right
142        key ^= self.xor
143        node = self.root
144        for i in range(self.bit - 1, -1, -1):
145            if key >> i & 1:
146                if not right[node]:
147                    return -1
148                node = right[node]
149            else:
150                if not left[node]:
151                    return -1
152                node = left[node]
153        return node
154
155    def reserve(self, n: int) -> None:
156        assert n >= 0, f"ValueError: BinaryTrieMultiset.reserve({n})"
157        a = array("I", bytes(4 * n))
158        self.left += a
159        self.right += a
160        self.par += a
161        self.size += a
162
163    def add(self, key: int, cnt: int = 1) -> None:
164        assert (
165            0 <= key < self.lim
166        ), f"ValueError: BinaryTrieMultiset.add({key}), lim={self.lim}"
167        left, right, par, size = self.left, self.right, self.par, self.size
168        key ^= self.xor
169        node = self.root
170        for i in range(self.bit - 1, -1, -1):
171            size[node] += cnt
172            if key >> i & 1:
173                if not right[node]:
174                    right[node] = self._make_node()
175                    par[right[node]] = node
176                node = right[node]
177            else:
178                if not left[node]:
179                    left[node] = self._make_node()
180                    par[left[node]] = node
181                node = left[node]
182        size[node] += cnt
183
184    def _remove(self, node: int) -> None:
185        left, right, par, size = self.left, self.right, self.par, self.size
186        cnt = size[node]
187        for _ in range(self.bit):
188            size[node] -= cnt
189            if left[par[node]] == node:
190                node = par[node]
191                left[node] = 0
192                if right[node]:
193                    break
194            else:
195                node = par[node]
196                right[node] = 0
197                if left[node]:
198                    break
199        while node:
200            size[node] -= cnt
201            node = par[node]
202
203    def discard(self, key: int, cnt: int = 1) -> bool:
204        assert (
205            0 <= key < self.lim
206        ), f"ValueError: BinaryTrieMultiset.discard({key}), lim={self.lim}"
207        par, size = self.par, self.size
208        node = self._find(key)
209        if node == -1:
210            return False
211        if size[node] <= cnt:
212            self._remove(node)
213        else:
214            while node:
215                size[node] -= cnt
216                node = par[node]
217        return True
218
219    def discard_all(self, key: int) -> bool:
220        return self.discard(key, self.count(key))
221
222    def remove(self, key: int, cnt: int = 1) -> None:
223        left, right, par, size = self.left, self.right, self.par, self.size
224        _key = key
225        key ^= self.xor
226        node = self.root
227        for i in range(self.bit - 1, -1, -1):
228            if key >> i & 1:
229                if not right[node]:
230                    assert False, f"{_key} is not found."
231                node = right[node]
232            else:
233                if not left[node]:
234                    assert False, f"{_key} is not found."
235                node = left[node]
236        c = self.size[node]
237        if c < cnt:
238            assert False, f"{_key} is not found."
239        elif c == cnt:
240            self._remove(node)
241        else:
242            while node:
243                size[node] -= cnt
244                node = par[node]
245
246    def count(self, key: int) -> int:
247        node = self._find(key)
248        return 0 if node == -1 else self.size[node]
249
250    def pop(self, k: int = -1) -> int:
251        assert (
252            -len(self) <= k < len(self)
253        ), f"IndexError: BinaryTrieMultiset.pop({k}), len={len(self)}"
254        if k < 0:
255            k += len(self)
256        left, right, par, size = self.left, self.right, self.par, self.size
257        node = self.root
258        res = 0
259        for i in range(self.bit - 1, -1, -1):
260            b = self.xor >> i & 1
261            if b:
262                left, right = right, left
263            t = size[left[node]]
264            res <<= 1
265            if not left[node]:
266                node = right[node]
267                res |= 1
268            elif not right[node]:
269                node = left[node]
270            else:
271                t = size[left[node]]
272                if t <= k:
273                    k -= t
274                    res |= 1
275                    node = right[node]
276                else:
277                    node = left[node]
278            if b:
279                left, right = right, left
280        if size[node] == 1:
281            self._remove(node)
282        else:
283            while node:
284                size[node] -= 1
285                node = par[node]
286        return res ^ self.xor
287
288    def pop_min(self) -> int:
289        assert self, f"IndexError: BinaryTrieMultiset.pop_min(), len={len(self)}"
290        return self.pop(0)
291
292    def pop_max(self) -> int:
293        return self.pop()
294
295    def all_xor(self, x: int) -> None:
296        assert (
297            0 <= x < self.lim
298        ), f"ValueError: BinaryTrieMultiset.all_xor({x}), lim={self.lim}"
299        self.xor ^= x
300
301    def get_min(self) -> Optional[int]:
302        if not self:
303            return None
304        left, right = self.left, self.right
305        key = self.xor
306        ans = 0
307        node = self.root
308        for i in range(self.bit - 1, -1, -1):
309            ans <<= 1
310            if key >> i & 1:
311                if right[node]:
312                    node = right[node]
313                    ans |= 1
314                else:
315                    node = left[node]
316            else:
317                if left[node]:
318                    node = left[node]
319                else:
320                    node = right[node]
321                    ans |= 1
322        return ans ^ self.xor
323
324    def get_max(self) -> Optional[int]:
325        if not self:
326            return None
327        left, right = self.left, self.right
328        key = self.xor
329        ans = 0
330        node = self.root
331        for i in range(self.bit - 1, -1, -1):
332            ans <<= 1
333            if key >> i & 1:
334                if left[node]:
335                    node = left[node]
336                else:
337                    node = right[node]
338                    ans |= 1
339            else:
340                if right[node]:
341                    ans |= 1
342                    node = right[node]
343                else:
344                    node = left[node]
345        return ans ^ self.xor
346
347    def index(self, key: int) -> int:
348        assert (
349            0 <= key < self.lim
350        ), f"ValueError: BinaryTrieMultiset.index({key}), lim={self.lim}"
351        left, right, size = self.left, self.right, self.size
352        k = 0
353        node = self.root
354        key ^= self.xor
355        for i in range(self.bit - 1, -1, -1):
356            if key >> i & 1:
357                k += size[left[node]]
358                node = right[node]
359            else:
360                node = left[node]
361            if not node:
362                break
363        return k
364
365    def index_right(self, key: int) -> int:
366        assert (
367            0 <= key < self.lim
368        ), f"ValueError: BinaryTrieMultiset.index_right({key}), lim={self.lim}"
369        left, right, size = self.left, self.right, self.size
370        k = 0
371        node = self.root
372        key ^= self.xor
373        for i in range(self.bit - 1, -1, -1):
374            if key >> i & 1:
375                k += size[left[node]]
376                node = right[node]
377            else:
378                node = left[node]
379            if not node:
380                break
381        else:
382            k += size[node]
383        return k
384
385    def gt(self, key: int) -> Optional[int]:
386        assert (
387            0 <= key < self.lim
388        ), f"ValueError: BinaryTrieMultiset.gt({key}), lim={self.lim}"
389        i = self.index_right(key)
390        return None if i >= self.size[self.root] else self[i]
391
392    def lt(self, key: int) -> Optional[int]:
393        assert (
394            0 <= key < self.lim
395        ), f"ValueError: BinaryTrieMultiset.lt({key}), lim={self.lim}"
396        i = self.index(key) - 1
397        return None if i < 0 else self[i]
398
399    def ge(self, key: int) -> Optional[int]:
400        assert (
401            0 <= key < self.lim
402        ), f"ValueError: BinaryTrieMultiset.ge({key}), lim={self.lim}"
403        if key == 0:
404            return self.get_min() if self else None
405        i = self.index_right(key - 1)
406        return None if i >= self.size[self.root] else self[i]
407
408    def le(self, key: int) -> Optional[int]:
409        assert (
410            0 <= key < self.lim
411        ), f"ValueError: BinaryTrieMultiset.le({key}), lim={self.lim}"
412        i = self.index(key + 1) - 1
413        return None if i < 0 else self[i]
414
415    def tolist(self) -> list[int]:
416        a = []
417        if not self:
418            return a
419        val = self.get_min()
420        while val is not None:
421            for _ in range(self.count(val)):
422                a.append(val)
423            val = self.gt(val)
424        return a
425
426    def clear(self) -> None:
427        self.root = 1
428
429    def __contains__(self, key: int) -> bool:
430        assert (
431            0 <= key < self.lim
432        ), f"ValueError: BinaryTrieMultiset.__contains__({key}), lim={self.lim}"
433        return self._find(key) != -1
434
435    def __getitem__(self, k: int) -> int:
436        assert (
437            -len(self) <= k < len(self)
438        ), f"IndexError: BinaryTrieMultiset[({k}], len={len(self)}"
439        if k < 0:
440            k += len(self)
441        left, right, size = self.left, self.right, self.size
442        node = self.root
443        res = 0
444        for i in range(self.bit - 1, -1, -1):
445            b = self.xor >> i & 1
446            if b:
447                left, right = right, left
448            t = size[left[node]]
449            res <<= 1
450            if not left[node]:
451                node = right[node]
452                res |= 1
453            elif not right[node]:
454                node = left[node]
455            else:
456                t = size[left[node]]
457                if t <= k:
458                    k -= t
459                    res |= 1
460                    node = right[node]
461                else:
462                    node = left[node]
463            if b:
464                left, right = right, left
465        return res
466
467    def __bool__(self):
468        return self.size[self.root] != 0
469
470    def __iter__(self):
471        self.it = 0
472        return self
473
474    def __next__(self):
475        if self.it == len(self):
476            raise StopIteration
477        self.it += 1
478        return self.__getitem__(self.it - 1)
479
480    def __len__(self):
481        return self.size[self.root]
482
483    def __str__(self):
484        return "{" + ", ".join(map(str, self.tolist())) + "}"
485
486    def __repr__(self):
487        return f"BinaryTrieMultiset({(1<<self.bit)-1}, {self.tolist()})"

仕様

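class BinaryTrieMultiset(u: int, a: Iterable[int] = []) [source]

Keys must lie in [0, 2^k) with k = (u - 1).bit_length(), i.e. u rounded up to a power of two; a supplies the initial elements (duplicates are kept).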

Bases: OrderedMultisetInterface

add(key: int, cnt: int = 1) -> None [source]
all_xor(x: int) -> None [source]
clear() -> None [source]
count(key: int) -> int [source]
discard(key: int, cnt: int = 1) -> bool [source]
discard_all(key: int) -> bool [source]
ge(key: int) -> int | None [source]
get_max() -> int | None [source]
get_min() -> int | None [source]
gt(key: int) -> int | None [source]
index(key: int) -> int [source]
index_right(key: int) -> int [source]
le(key: int) -> int | None [source]
lt(key: int) -> int | None [source]
pop(k: int = -1) -> int [source]
pop_max() -> int [source]
pop_min() -> int [source]
remove(key: int, cnt: int = 1) -> None [source]
reserve(n: int) -> None [source]
tolist() -> list[int] [source]