Source code for titan_pylib.data_structures.wavelet_matrix.dynamic_wavelet_matrix

  1from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
  2from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
  3from typing import Sequence
  4from array import array
  5
  6
class DynamicWaveletMatrix(WaveletMatrix):
    """Dynamic wavelet matrix.

    In addition to the operations of the (static) wavelet matrix, this
    supports ``insert / pop / set`` and similar updates.

    - Each level's ``BitVector`` is a balanced binary tree
      (``AVLTreeBitVector``), so every operation pays an extra
      balanced-tree log factor, which makes this structure fairly heavy.

    Construction is :math:`O(n\\log{(\\sigma)})` (assuming
    ``AVLTreeBitVector`` builds from a bit array in linear time).
    """

    def __init__(self, sigma: int, a: Sequence[int] = ()) -> None:
        """Build the matrix from the initial sequence ``a``.

        :param sigma: Value bound; the number of levels is
            ``(sigma - 1).bit_length()`` and every value must satisfy
            ``0 <= x < 1 << log`` (the same bound ``insert``/``set`` check).
        :param a: Initial sequence (defaults to empty). It is iterated once
            per level, so it must be a re-iterable sequence.
        """
        self.sigma: int = sigma
        self.log: int = (sigma - 1).bit_length()
        # Placeholders only: ``_build`` overwrites every level, so sharing a
        # single empty instance here is harmless.
        self.v: list[AVLTreeBitVector] = [AVLTreeBitVector()] * self.log
        # mid[bit] = number of 0 bits at level ``bit`` (boundary of the ones).
        self.mid: array[int] = array("I", bytes(4 * self.log))
        self.size: int = len(a)
        lim = 1 << self.log
        # Without this check, out-of-range (or negative) values would be
        # silently truncated to their low ``log`` bits.
        assert all(
            0 <= e < lim for e in a
        ), f"ValueError: {self.__class__.__name__}.__init__(), values must be in [0, {lim})"
        self._build(a)

    def _build(self, a: Sequence[int]) -> None:
        """Fill every level's bit vector and zero count from ``a``.

        Levels are processed from the most significant bit down. After each
        level the sequence is stably partitioned (zeros first, then ones);
        that reordered sequence is what the next lower level encodes.
        """
        bits = array("B", bytes(len(a)))
        for bit in range(self.log - 1, -1, -1):
            # Record bit ``bit`` of every element and stably split by it.
            zero: list[int] = []
            one: list[int] = []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    bits[i] = 1
                    one.append(e)
                else:
                    bits[i] = 0
                    zero.append(e)
            self.mid[bit] = len(zero)  # boundary between zeros and ones
            self.v[bit] = AVLTreeBitVector(bits)
            a = zero + one

    def reserve(self, n: int) -> None:
        """Reserve memory for ``n`` elements on every level.

        :math:`O(n)`.
        """
        assert n >= 0, f"ValueError: {self.__class__.__name__}.reserve({n})"
        for v in self.v:
            v.reserve(n)

    def insert(self, k: int, x: int) -> None:
        """Insert ``x`` at position ``k``.

        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k <= self.size
        ), f"IndexError: {self.__class__.__name__}.insert({k}, {x}), n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}.insert({k}, {x}), LIM={1<<self.log}"
        mid = self.mid
        for bit in range(self.log - 1, -1, -1):
            v = self.v[bit]
            # ``_insert_and_rank1(k, b)`` fuses ``v.insert(k, b)`` with a rank
            # query and returns the number of 1s in [0, k), i.e. rank1(k).
            if x >> bit & 1:
                s = v._insert_and_rank1(k, 1)
                # A 1 moves into the "ones" half, which starts at mid[bit].
                k = s + mid[bit]
            else:
                s = v._insert_and_rank1(k, 0)
                # rank0(k) = k - rank1(k); the zeros half grows by one.
                k -= s
                mid[bit] += 1
        self.size += 1

    def pop(self, k: int) -> int:
        """Remove the element at position ``k`` and return its value.

        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}.pop({k}), n={self.size}"
        mid = self.mid
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            v = self.v[bit]
            # ``_access_pop_and_rank1(k)`` reads and removes the bit at ``k``
            # and returns ``(rank1(k) << 1) | removed_bit``.
            sb = v._access_pop_and_rank1(k)
            s = sb >> 1
            if sb & 1:
                ans |= 1 << bit
                k = s + mid[bit]
            else:
                # A 0 left the zeros half; follow rank0(k) = k - rank1(k).
                mid[bit] -= 1
                k -= s
        self.size -= 1
        return ans

    def set(self, k: int, x: int) -> None:
        """Overwrite the element at position ``k`` with ``x``.

        Implemented as ``pop(k)`` followed by ``insert(k, x)``.
        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}.set({k}, {x}), n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}.set({k}, {x}), LIM={1<<self.log}"
        self.pop(k)
        self.insert(k, x)

    def __setitem__(self, k: int, x: int) -> None:
        """``self[k] = x``; equivalent to ``set(k, x)``."""
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}[{k}] = {x}, LIM={1<<self.log}"
        self.set(k, x)

    def __str__(self) -> str:
        """Render as ``DynamicWaveletMatrix([...])`` listing every element."""
        return f"{self.__class__.__name__}({[self[i] for i in range(self.size)]})"