Source code for titan_pylib.data_structures.wavelet_matrix.wavelet_matrix

  1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
  2from typing import Sequence
  3from heapq import heappush, heappop
  4from array import array
  5
  6
[docs] 7class WaveletMatrix: 8 """``WaveletMatrix`` です。 9 静的であることに注意してください。 10 11 以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。 12 13 参考: 14 `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_ 15 `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_ 16 `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_ 17 """ 18 19 def __init__(self, sigma: int, a: Sequence[int] = []) -> None: 20 """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。 21 :math:`O(n\\log{\\sigma})` です。 22 23 Args: 24 sigma (int): 扱う整数の上限です。 25 a (Sequence[int], optional): 構築する配列です。 26 """ 27 self.sigma: int = sigma 28 self.log: int = (sigma - 1).bit_length() 29 self.mid: array[int] = array("I", bytes(4 * self.log)) 30 self.size: int = len(a) 31 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)] 32 self._build(a) 33 34 def _build(self, a: Sequence[int]) -> None: 35 # 列 a から wm を構築する 36 for bit in range(self.log - 1, -1, -1): 37 # bit目の0/1に応じてvを構築 + aを安定ソート 38 v = self.v[bit] 39 zero, one = [], [] 40 for i, e in enumerate(a): 41 if e >> bit & 1: 42 v.set(i) 43 one.append(e) 44 else: 45 zero.append(e) 46 v.build() 47 self.mid[bit] = len(zero) # 境界をmid[bit]に保持 48 a = zero + one 49
[docs] 50 def access(self, k: int) -> int: 51 """``k`` 番目の値を返します。 52 :math:`O(\\log{\\sigma})` です。 53 54 Args: 55 k (int): インデックスです。 56 """ 57 assert ( 58 -self.size <= k < self.size 59 ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}" 60 if k < 0: 61 k += self.size 62 s = 0 # 答え 63 for bit in range(self.log - 1, -1, -1): 64 if self.v[bit].access(k): 65 # k番目が立ってたら、 66 # kまでの1とすべての0が次のk 67 s |= 1 << bit 68 k = self.v[bit].rank1(k) + self.mid[bit] 69 else: 70 # kまでの0が次のk 71 k = self.v[bit].rank0(k) 72 return s
73 74 def __getitem__(self, k: int) -> int: 75 assert ( 76 -self.size <= k < self.size 77 ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}" 78 return self.access(k) 79
[docs] 80 def rank(self, r: int, x: int) -> int: 81 """``a[0, r)`` に含まれる ``x`` の個数を返します。 82 :math:`O(\\log{\\sigma})` です。 83 """ 84 assert ( 85 0 <= r <= self.size 86 ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}" 87 assert ( 88 0 <= x < 1 << self.log 89 ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}" 90 l = 0 91 mid = self.mid 92 for bit in range(self.log - 1, -1, -1): 93 # 位置 r より左に x が何個あるか 94 # x の bit 目で場合分け 95 if x >> bit & 1: 96 # 立ってたら、次のl, rは以下 97 l = self.v[bit].rank1(l) + mid[bit] 98 r = self.v[bit].rank1(r) + mid[bit] 99 else: 100 # そうでなければ次のl, rは以下 101 l = self.v[bit].rank0(l) 102 r = self.v[bit].rank0(r) 103 return r - l
104
[docs] 105 def select(self, k: int, x: int) -> int: 106 """``k`` 番目の ``v`` のインデックスを返します。 107 :math:`O(\\log{\\sigma})` です。 108 """ 109 assert ( 110 0 <= k < self.size 111 ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}" 112 assert ( 113 0 <= x < 1 << self.log 114 ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}" 115 # x の開始位置 s を探す 116 s = 0 117 for bit in range(self.log - 1, -1, -1): 118 if x >> bit & 1: 119 s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s) 120 else: 121 s = self.v[bit].rank0(s) 122 s += k # s から k 進んだ位置が、元の列で何番目か調べる 123 for bit in range(self.log): 124 if x >> bit & 1: 125 s = self.v[bit].select1(s - self.v[bit].rank0(self.size)) 126 else: 127 s = self.v[bit].select0(s) 128 return s
129
[docs] 130 def kth_smallest(self, l: int, r: int, k: int) -> int: 131 """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。 132 :math:`O(\\log{\\sigma})` です。 133 """ 134 assert ( 135 0 <= l <= r <= self.size 136 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}" 137 assert ( 138 0 <= k < r - l 139 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k" 140 s = 0 141 mid = self.mid 142 for bit in range(self.log - 1, -1, -1): 143 r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l) 144 cnt = r0 - l0 # 区間内の 0 の個数 145 if cnt <= k: # 0 が k 以下のとき、 k 番目は 1 146 s |= 1 << bit 147 k -= cnt 148 # この 1 が次の bit 列でどこに行くか 149 l = l - l0 + mid[bit] 150 r = r - r0 + mid[bit] 151 else: 152 # この 0 が次の bit 列でどこに行くか 153 l = l0 154 r = r0 155 return s
156 157 quantile = kth_smallest 158
[docs] 159 def kth_largest(self, l: int, r: int, k: int) -> int: 160 """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。 161 :math:`O(\\log{\\sigma})` です。 162 """ 163 assert ( 164 0 <= l <= r <= self.size 165 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}" 166 assert ( 167 0 <= k < r - l 168 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k" 169 return self.kth_smallest(l, r, r - l - k - 1)
170
[docs] 171 def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]: 172 """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。 173 :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。 174 175 Note: 176 :math:`\\sigma` が大きい場合、計算量に注意です。 177 178 Returns: 179 list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。 180 """ 181 assert ( 182 0 <= l <= r <= self.size 183 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}" 184 assert ( 185 0 <= k < r - l 186 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k" 187 # heap[-length, x, l, bit] 188 hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)] 189 ans = [] 190 while hq: 191 length, x, l, bit = heappop(hq) 192 length = -length 193 if bit == -1: 194 ans.append((x, length)) 195 k -= 1 196 if k == 0: 197 break 198 else: 199 r = l + length 200 l0 = self.v[bit].rank0(l) 201 r0 = self.v[bit].rank0(r) 202 if l0 < r0: 203 heappush(hq, (-(r0 - l0), x, l0, bit - 1)) 204 l1 = self.v[bit].rank1(l) + self.mid[bit] 205 r1 = self.v[bit].rank1(r) + self.mid[bit] 206 if l1 < r1: 207 heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1)) 208 return ans
209
[docs] 210 def sum(self, l: int, r: int) -> int: 211 """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。 212 計算量に注意です。 213 """ 214 assert False, "Yabai Keisanryo Error" 215 return sum(k * v for k, v in self.topk(l, r, r - l))
216 217 def _range_freq(self, l: int, r: int, x: int) -> int: 218 """a[l, r) で x 未満の要素の数を返す""" 219 ans = 0 220 for bit in range(self.log - 1, -1, -1): 221 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r) 222 if x >> bit & 1: 223 # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ 224 ans += r0 - l0 225 # 1 が次の bit 列でどこに行くか 226 l += self.mid[bit] - l0 227 r += self.mid[bit] - r0 228 else: 229 # 0 が次の bit 列でどこに行くか 230 l, r = l0, r0 231 return ans 232
[docs] 233 def range_freq(self, l: int, r: int, x: int, y: int) -> int: 234 """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。 235 :math:`O(\\log{\\sigma})` です。 236 """ 237 assert ( 238 0 <= l <= r <= self.size 239 ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})" 240 assert 0 <= x <= y < self.sigma, f"ValueError" 241 return self._range_freq(l, r, y) - self._range_freq(l, r, x)
242
[docs] 243 def prev_value(self, l: int, r: int, x: int) -> int: 244 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。 245 :math:`O(\\log{\\sigma})` です。 246 """ 247 assert ( 248 0 <= l <= r <= self.size 249 ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})" 250 return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
251
[docs] 252 def next_value(self, l: int, r: int, x: int) -> int: 253 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。 254 :math:`O(\\log{\\sigma})` です。 255 """ 256 assert ( 257 0 <= l <= r <= self.size 258 ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})" 259 return self.kth_smallest(l, r, self._range_freq(l, r, x))
260
[docs] 261 def range_count(self, l: int, r: int, x: int) -> int: 262 """``a[l, r)`` に含まれる ``x`` の個数を返します。 263 ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。 264 :math:`O(\\log{\\sigma})` です。 265 """ 266 assert ( 267 0 <= l <= r <= self.size 268 ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})" 269 return self.rank(r, x) - self.rank(l, x)
270 271 def __len__(self) -> int: 272 return self.size 273 274 def __str__(self) -> str: 275 return ( 276 f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})" 277 ) 278 279 __repr__ = __str__