1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
2from typing import Sequence
3from heapq import heappush, heappop
4from array import array
5
6
[docs]
7class WaveletMatrix:
8 """``WaveletMatrix`` です。
9 静的であることに注意してください。
10
11 以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。
12
13 参考:
14 `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_
15 `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_
16 `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_
17 """
18
19 def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
20 """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。
21 :math:`O(n\\log{\\sigma})` です。
22
23 Args:
24 sigma (int): 扱う整数の上限です。
25 a (Sequence[int], optional): 構築する配列です。
26 """
27 self.sigma: int = sigma
28 self.log: int = (sigma - 1).bit_length()
29 self.mid: array[int] = array("I", bytes(4 * self.log))
30 self.size: int = len(a)
31 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
32 self._build(a)
33
34 def _build(self, a: Sequence[int]) -> None:
35 # 列 a から wm を構築する
36 for bit in range(self.log - 1, -1, -1):
37 # bit目の0/1に応じてvを構築 + aを安定ソート
38 v = self.v[bit]
39 zero, one = [], []
40 for i, e in enumerate(a):
41 if e >> bit & 1:
42 v.set(i)
43 one.append(e)
44 else:
45 zero.append(e)
46 v.build()
47 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
48 a = zero + one
49
[docs]
50 def access(self, k: int) -> int:
51 """``k`` 番目の値を返します。
52 :math:`O(\\log{\\sigma})` です。
53
54 Args:
55 k (int): インデックスです。
56 """
57 assert (
58 -self.size <= k < self.size
59 ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}"
60 if k < 0:
61 k += self.size
62 s = 0 # 答え
63 for bit in range(self.log - 1, -1, -1):
64 if self.v[bit].access(k):
65 # k番目が立ってたら、
66 # kまでの1とすべての0が次のk
67 s |= 1 << bit
68 k = self.v[bit].rank1(k) + self.mid[bit]
69 else:
70 # kまでの0が次のk
71 k = self.v[bit].rank0(k)
72 return s
73
74 def __getitem__(self, k: int) -> int:
75 assert (
76 -self.size <= k < self.size
77 ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}"
78 return self.access(k)
79
[docs]
80 def rank(self, r: int, x: int) -> int:
81 """``a[0, r)`` に含まれる ``x`` の個数を返します。
82 :math:`O(\\log{\\sigma})` です。
83 """
84 assert (
85 0 <= r <= self.size
86 ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}"
87 assert (
88 0 <= x < 1 << self.log
89 ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}"
90 l = 0
91 mid = self.mid
92 for bit in range(self.log - 1, -1, -1):
93 # 位置 r より左に x が何個あるか
94 # x の bit 目で場合分け
95 if x >> bit & 1:
96 # 立ってたら、次のl, rは以下
97 l = self.v[bit].rank1(l) + mid[bit]
98 r = self.v[bit].rank1(r) + mid[bit]
99 else:
100 # そうでなければ次のl, rは以下
101 l = self.v[bit].rank0(l)
102 r = self.v[bit].rank0(r)
103 return r - l
104
[docs]
105 def select(self, k: int, x: int) -> int:
106 """``k`` 番目の ``v`` のインデックスを返します。
107 :math:`O(\\log{\\sigma})` です。
108 """
109 assert (
110 0 <= k < self.size
111 ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}"
112 assert (
113 0 <= x < 1 << self.log
114 ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}"
115 # x の開始位置 s を探す
116 s = 0
117 for bit in range(self.log - 1, -1, -1):
118 if x >> bit & 1:
119 s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s)
120 else:
121 s = self.v[bit].rank0(s)
122 s += k # s から k 進んだ位置が、元の列で何番目か調べる
123 for bit in range(self.log):
124 if x >> bit & 1:
125 s = self.v[bit].select1(s - self.v[bit].rank0(self.size))
126 else:
127 s = self.v[bit].select0(s)
128 return s
129
[docs]
130 def kth_smallest(self, l: int, r: int, k: int) -> int:
131 """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。
132 :math:`O(\\log{\\sigma})` です。
133 """
134 assert (
135 0 <= l <= r <= self.size
136 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}"
137 assert (
138 0 <= k < r - l
139 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k"
140 s = 0
141 mid = self.mid
142 for bit in range(self.log - 1, -1, -1):
143 r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l)
144 cnt = r0 - l0 # 区間内の 0 の個数
145 if cnt <= k: # 0 が k 以下のとき、 k 番目は 1
146 s |= 1 << bit
147 k -= cnt
148 # この 1 が次の bit 列でどこに行くか
149 l = l - l0 + mid[bit]
150 r = r - r0 + mid[bit]
151 else:
152 # この 0 が次の bit 列でどこに行くか
153 l = l0
154 r = r0
155 return s
156
157 quantile = kth_smallest
158
[docs]
159 def kth_largest(self, l: int, r: int, k: int) -> int:
160 """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。
161 :math:`O(\\log{\\sigma})` です。
162 """
163 assert (
164 0 <= l <= r <= self.size
165 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}"
166 assert (
167 0 <= k < r - l
168 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k"
169 return self.kth_smallest(l, r, r - l - k - 1)
170
[docs]
171 def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]:
172 """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。
173 :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。
174
175 Note:
176 :math:`\\sigma` が大きい場合、計算量に注意です。
177
178 Returns:
179 list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。
180 """
181 assert (
182 0 <= l <= r <= self.size
183 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}"
184 assert (
185 0 <= k < r - l
186 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k"
187 # heap[-length, x, l, bit]
188 hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)]
189 ans = []
190 while hq:
191 length, x, l, bit = heappop(hq)
192 length = -length
193 if bit == -1:
194 ans.append((x, length))
195 k -= 1
196 if k == 0:
197 break
198 else:
199 r = l + length
200 l0 = self.v[bit].rank0(l)
201 r0 = self.v[bit].rank0(r)
202 if l0 < r0:
203 heappush(hq, (-(r0 - l0), x, l0, bit - 1))
204 l1 = self.v[bit].rank1(l) + self.mid[bit]
205 r1 = self.v[bit].rank1(r) + self.mid[bit]
206 if l1 < r1:
207 heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1))
208 return ans
209
[docs]
210 def sum(self, l: int, r: int) -> int:
211 """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。
212 計算量に注意です。
213 """
214 assert False, "Yabai Keisanryo Error"
215 return sum(k * v for k, v in self.topk(l, r, r - l))
216
217 def _range_freq(self, l: int, r: int, x: int) -> int:
218 """a[l, r) で x 未満の要素の数を返す"""
219 ans = 0
220 for bit in range(self.log - 1, -1, -1):
221 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
222 if x >> bit & 1:
223 # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ
224 ans += r0 - l0
225 # 1 が次の bit 列でどこに行くか
226 l += self.mid[bit] - l0
227 r += self.mid[bit] - r0
228 else:
229 # 0 が次の bit 列でどこに行くか
230 l, r = l0, r0
231 return ans
232
[docs]
233 def range_freq(self, l: int, r: int, x: int, y: int) -> int:
234 """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。
235 :math:`O(\\log{\\sigma})` です。
236 """
237 assert (
238 0 <= l <= r <= self.size
239 ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})"
240 assert 0 <= x <= y < self.sigma, f"ValueError"
241 return self._range_freq(l, r, y) - self._range_freq(l, r, x)
242
[docs]
243 def prev_value(self, l: int, r: int, x: int) -> int:
244 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。
245 :math:`O(\\log{\\sigma})` です。
246 """
247 assert (
248 0 <= l <= r <= self.size
249 ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})"
250 return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
251
[docs]
252 def next_value(self, l: int, r: int, x: int) -> int:
253 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。
254 :math:`O(\\log{\\sigma})` です。
255 """
256 assert (
257 0 <= l <= r <= self.size
258 ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})"
259 return self.kth_smallest(l, r, self._range_freq(l, r, x))
260
[docs]
261 def range_count(self, l: int, r: int, x: int) -> int:
262 """``a[l, r)`` に含まれる ``x`` の個数を返します。
263 ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。
264 :math:`O(\\log{\\sigma})` です。
265 """
266 assert (
267 0 <= l <= r <= self.size
268 ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})"
269 return self.rank(r, x) - self.rank(l, x)
270
271 def __len__(self) -> int:
272 return self.size
273
274 def __str__(self) -> str:
275 return (
276 f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})"
277 )
278
279 __repr__ = __str__