binary_trie_set

ソースコード

from titan_pylib.data_structures.binary_trie.binary_trie_set import BinaryTrieSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.binary_trie.binary_trie_set import BinaryTrieSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101from typing import Optional, Iterable
102from array import array
103
104
class BinaryTrieSet(OrderedSetInterface):
    """Ordered set of integers in ``[0, lim)`` stored in a binary trie.

    ``lim`` is the smallest power of two that is ``>= u``. Nodes live in
    parallel ``array`` objects indexed by node id:

    * ``left`` / ``right`` -- child ids for bit 0 / bit 1 (0 means "no child")
    * ``par``              -- parent id (0 for the root)
    * ``size``             -- number of stored keys in the node's subtree
    * ``valid``            -- 1 iff ``size`` of the node is non-zero (kept as a
      mirror of ``size`` for backward compatibility)

    Node 0 is a null sentinel whose entries always stay 0; node 1 is the root.
    Links are never unlinked on removal, so whether a subtree holds keys is
    always decided from ``size``, not from the mere existence of a link.

    ``xor`` is a lazy mask: a key ``k`` is stored at the path ``k ^ xor``, and
    ``all_xor`` only updates the mask. Ordered queries therefore swap the
    roles of the two children at every bit where the mask is set.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Build a set that can hold keys in ``[0, u)`` and insert ``a``.

        Runs in O(n log u) for ``n`` initial elements.
        """
        # Two slots per array: the null sentinel (0) and the root (1).
        # All arrays must stay the same length so that node ids line up.
        self.left = array("I", (0, 0))
        self.right = array("I", (0, 0))
        self.par = array("I", (0, 0))
        self.size = array("I", (0, 0))
        self.valid = array("B", (0, 0))
        self.end = 2  # id that the next allocated node will get
        self.root = 1
        self.bit = (u - 1).bit_length()
        self.lim = 1 << self.bit
        self.xor = 0
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Allocate a zero-initialised node and return its id."""
        node = self.end
        if node >= len(self.left):
            # No reserved slot left: grow every array by one entry.
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
            self.valid.append(0)
        self.end = node + 1
        return node

    def _find(self, key: int) -> int:
        """Return the leaf id that stores ``key``, or -1 if it is absent."""
        left, right, size = self.left, self.right, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            # size is 0 both for the null node and for an emptied subtree.
            if not size[node]:
                return -1
        # Reached only with a live leaf, except when bit == 0 (leaf == root).
        return node if size[node] else -1

    def _kth(self, k: int) -> tuple[int, int]:
        """Return ``(leaf id, key)`` of the ``k``-th smallest key (0-indexed).

        The caller must guarantee ``0 <= k < len(self)``.
        """
        left, right, size = self.left, self.right, self.size
        mask = self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # "lo" is the child holding the logically smaller keys.
            if mask >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            res <<= 1
            cnt = size[lo]
            if cnt <= k:
                k -= cnt
                res |= 1
                node = hi
            else:
                node = lo
        # res was assembled in logical order, so no further xor is needed.
        return node, res

    def _rank(self, key: int) -> tuple[int, int]:
        """Return ``(number of stored keys < key, node reached for key)``.

        The returned node is the leaf of ``key`` when its path exists and 0
        otherwise; ``size`` of that node tells whether ``key`` is stored.
        """
        left, right, size = self.left, self.right, self.size
        mask = self.xor
        node = self.root
        less = 0
        for i in range(self.bit - 1, -1, -1):
            if mask >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            if key >> i & 1:
                less += size[lo]
                node = hi
            else:
                node = lo
            if not node:
                break
        return less, node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        Runs in O(n).
        """
        assert n >= 0, f"ValueError: BinaryTrieSet.reserve({n})"
        # Multiply a one-element array so the result has exactly n entries
        # whatever the platform's item size for "I" is.
        pad = array("I", [0]) * n
        self.left += pad
        self.right += pad
        self.par += pad
        self.size += pad
        self.valid += array("B", [0]) * n

    def add(self, key: int) -> bool:
        """Insert ``key``; return ``True`` if it was not already present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.add({key}), lim={self.lim}"
        left, right, par = self.left, self.right, self.par
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            children = right if key >> i & 1 else left
            nxt = children[node]
            if not nxt:
                nxt = self._make_node()
                children[node] = nxt
                par[nxt] = node
            node = nxt
        size, valid = self.size, self.valid
        if size[node]:
            return False
        # Count the new key in the leaf and in every ancestor up to the root
        # (par[root] == 0 ends the walk).
        while node:
            size[node] += 1
            valid[node] = 1
            node = par[node]
        return True

    def _remove(self, node: int) -> None:
        """Remove the key stored at leaf ``node`` (which must hold a key)."""
        par, size, valid = self.par, self.size, self.valid
        while node:
            size[node] -= 1
            if not size[node]:
                valid[node] = 0
            node = par[node]

    # Historical (misspelled) name, kept so existing callers keep working.
    _rmeove = _remove

    def discard(self, key: int) -> bool:
        """Remove ``key`` if present; return whether it was present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        self._remove(node)
        return True

    def remove(self, key: int) -> None:
        """Remove ``key``; raise ``KeyError`` if it is absent."""
        if self.discard(key):
            return
        raise KeyError(key)

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest key (negative ``k`` allowed)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth(k)
        self._remove(node)
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest key."""
        assert self, f"IndexError: BinaryTrieSet.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest key."""
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every stored key ``k`` by ``k ^ x``.

        Runs in O(1): only the lazy mask is updated.
        """
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieSet.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest key, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(0)[1]

    def get_max(self) -> Optional[int]:
        """Return the largest key, or ``None`` if the set is empty."""
        if not self:
            return None
        return self._kth(len(self) - 1)[1]

    def index(self, key: int) -> int:
        """Return the number of stored keys strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index({key}), lim={self.lim}"
        return self._rank(key)[0]

    def index_right(self, key: int) -> int:
        """Return the number of stored keys less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index_right({key}), lim={self.lim}"
        less, node = self._rank(key)
        # size[node] is 1 exactly when key itself is stored.
        return less + self.size[node]

    def clear(self) -> None:
        """Remove every key while keeping the allocated capacity."""
        n = len(self.left)
        self.left = array("I", [0]) * n
        self.right = array("I", [0]) * n
        self.par = array("I", [0]) * n
        self.size = array("I", [0]) * n
        self.valid = array("B", [0]) * n
        self.end = 2
        self.root = 1

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest key ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self._kth(i)[1]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest key ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self._kth(i)[1]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest key ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self._kth(i)[1]

    def le(self, key: int) -> Optional[int]:
        """Return the largest key ``<= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.le({key}), lim={self.lim}"
        # index_right(key) avoids querying key + 1, which is out of range
        # for key == lim - 1.
        i = self.index_right(key) - 1
        return None if i < 0 else self._kth(i)[1]

    def tolist(self) -> list[int]:
        """Return all keys in ascending order."""
        kth = self._kth
        return [kth(i)[1] for i in range(len(self))]

    def __contains__(self, key: int) -> bool:
        assert (
            0 <= key < self.lim
        ), f"ValueError: {key} in BinaryTrieSet, lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th smallest key (negative ``k`` allowed)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth(k)[1]

    def __bool__(self) -> bool:
        return self.size[self.root] != 0

    def __iter__(self):
        # The iteration cursor lives on the instance, so only one iteration
        # over a given set can be in progress at a time.
        self.it = 0
        return self

    def __next__(self) -> int:
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self.__getitem__(self.it - 1)

    def __len__(self) -> int:
        return self.size[self.root]

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return f"BinaryTrieSet({(1<<self.bit)-1}, {self})"

仕様

class BinaryTrieSet(u: int, a: Iterable[int] = [])[source]

Bases: OrderedSetInterface

add(key: int) bool[source]
all_xor(x: int) None[source]

すべての要素に x で xor をかけます。

\(O(1)\) です。

clear() None[source]
discard(key: int) bool[source]
ge(key: int) int | None[source]
get_max() int | None[source]
get_min() int | None[source]
gt(key: int) int | None[source]
index(key: int) int[source]
index_right(key: int) int[source]
le(key: int) int | None[source]
lt(key: int) int | None[source]
pop(k: int = -1) int[source]
pop_max() int[source]
pop_min() int[source]
remove(key: int) None[source]
reserve(n: int) None[source]

n 要素分のメモリを確保します。

\(O(n)\) です。

tolist() list[int][source]