wavelet_matrix

ソースコード

from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix

view on github

展開済みコード

  1# from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
  2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
  3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  4#     BitVectorInterface,
  5# )
  6from abc import ABC, abstractmethod
  7
  8
  9class BitVectorInterface(ABC):
 10
 11    @abstractmethod
 12    def access(self, k: int) -> int:
 13        raise NotImplementedError
 14
 15    @abstractmethod
 16    def __getitem__(self, k: int) -> int:
 17        raise NotImplementedError
 18
 19    @abstractmethod
 20    def rank0(self, r: int) -> int:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def rank1(self, r: int) -> int:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def rank(self, r: int, v: int) -> int:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def select0(self, k: int) -> int:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def select1(self, k: int) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def select(self, k: int, v: int) -> int:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def __len__(self) -> int:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def __str__(self) -> str:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def __repr__(self) -> str:
 53        raise NotImplementedError
 54from array import array
 55
 56
 57class BitVector(BitVectorInterface):
 58    """コンパクトな bit vector です。"""
 59
 60    def __init__(self, n: int):
 61        """長さ ``n`` の ``BitVector`` です。
 62
 63        bit を保持するのに ``array[I]`` を使用します。
 64        ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
 65
 66        累積和を保持するのに同様の ``array[I]`` を使用します。
 67        32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
 68        """
 69        assert 0 <= n < 4294967295
 70        self.N = n
 71        self.block_size = (n + 31) >> 5
 72        b = bytes(4 * (self.block_size + 1))
 73        self.bit = array("I", b)
 74        self.acc = array("I", b)
 75
 76    @staticmethod
 77    def _popcount(x: int) -> int:
 78        x = x - ((x >> 1) & 0x55555555)
 79        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 80        x = x + (x >> 4) & 0x0F0F0F0F
 81        x += x >> 8
 82        x += x >> 16
 83        return x & 0x0000007F
 84
 85    def set(self, k: int) -> None:
 86        """``k`` 番目の bit を ``1`` にします。
 87        :math:`O(1)` です。
 88
 89        Args:
 90          k (int): インデックスです。
 91        """
 92        self.bit[k >> 5] |= 1 << (k & 31)
 93
 94    def build(self) -> None:
 95        """構築します。
 96        **これ以降 ``set`` メソッドを使用してはいけません。**
 97        :math:`O(n)` です。
 98        """
 99        acc, bit = self.acc, self.bit
100        for i in range(self.block_size):
101            acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103    def access(self, k: int) -> int:
104        """``k`` 番目の bit を返します。
105        :math:`O(1)` です。
106        """
107        return (self.bit[k >> 5] >> (k & 31)) & 1
108
109    def __getitem__(self, k: int) -> int:
110        return (self.bit[k >> 5] >> (k & 31)) & 1
111
112    def rank0(self, r: int) -> int:
113        """``a[0, r)`` に含まれる ``0`` の個数を返します。
114        :math:`O(1)` です。
115        """
116        return r - (
117            self.acc[r >> 5]
118            + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119        )
120
121    def rank1(self, r: int) -> int:
122        """``a[0, r)`` に含まれる ``1`` の個数を返します。
123        :math:`O(1)` です。
124        """
125        return self.acc[r >> 5] + BitVector._popcount(
126            self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127        )
128
129    def rank(self, r: int, v: int) -> int:
130        """``a[0, r)`` に含まれる ``v`` の個数を返します。
131        :math:`O(1)` です。
132        """
133        return self.rank1(r) if v else self.rank0(r)
134
135    def select0(self, k: int) -> int:
136        """``k`` 番目の ``0`` のインデックスを返します。
137        :math:`O(\\log{n})` です。
138        """
139        if k < 0 or self.rank0(self.N) <= k:
140            return -1
141        l, r = 0, self.block_size + 1
142        while r - l > 1:
143            m = (l + r) >> 1
144            if m * 32 - self.acc[m] > k:
145                r = m
146            else:
147                l = m
148        indx = 32 * l
149        k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150        l, r = indx, indx + 32
151        while r - l > 1:
152            m = (l + r) >> 1
153            if self.rank0(m) > k:
154                r = m
155            else:
156                l = m
157        return l
158
159    def select1(self, k: int) -> int:
160        """``k`` 番目の ``1`` のインデックスを返します。
161        :math:`O(\\log{n})` です。
162        """
163        if k < 0 or self.rank1(self.N) <= k:
164            return -1
165        l, r = 0, self.block_size + 1
166        while r - l > 1:
167            m = (l + r) >> 1
168            if self.acc[m] > k:
169                r = m
170            else:
171                l = m
172        indx = 32 * l
173        k = k - self.acc[l] + self.rank1(indx)
174        l, r = indx, indx + 32
175        while r - l > 1:
176            m = (l + r) >> 1
177            if self.rank1(m) > k:
178                r = m
179            else:
180                l = m
181        return l
182
183    def select(self, k: int, v: int) -> int:
184        """``k`` 番目の ``v`` のインデックスを返します。
185        :math:`O(\\log{n})` です。
186        """
187        return self.select1(k) if v else self.select0(k)
188
189    def __len__(self):
190        return self.N
191
192    def __str__(self):
193        return str([self.access(i) for i in range(self.N)])
194
195    def __repr__(self):
196        return f"{self.__class__.__name__}({self})"
197from typing import Sequence
198from heapq import heappush, heappop
199from array import array
200
201
202class WaveletMatrix:
203    """``WaveletMatrix`` です。
204    静的であることに注意してください。
205
206    以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。
207
208    参考:
209      `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_
210      `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_
211      `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_
212    """
213
214    def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
215        """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。
216        :math:`O(n\\log{\\sigma})` です。
217
218        Args:
219            sigma (int): 扱う整数の上限です。
220            a (Sequence[int], optional): 構築する配列です。
221        """
222        self.sigma: int = sigma
223        self.log: int = (sigma - 1).bit_length()
224        self.mid: array[int] = array("I", bytes(4 * self.log))
225        self.size: int = len(a)
226        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
227        self._build(a)
228
229    def _build(self, a: Sequence[int]) -> None:
230        # 列 a から wm を構築する
231        for bit in range(self.log - 1, -1, -1):
232            # bit目の0/1に応じてvを構築 + aを安定ソート
233            v = self.v[bit]
234            zero, one = [], []
235            for i, e in enumerate(a):
236                if e >> bit & 1:
237                    v.set(i)
238                    one.append(e)
239                else:
240                    zero.append(e)
241            v.build()
242            self.mid[bit] = len(zero)  # 境界をmid[bit]に保持
243            a = zero + one
244
245    def access(self, k: int) -> int:
246        """``k`` 番目の値を返します。
247        :math:`O(\\log{\\sigma})` です。
248
249        Args:
250            k (int): インデックスです。
251        """
252        assert (
253            -self.size <= k < self.size
254        ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}"
255        if k < 0:
256            k += self.size
257        s = 0  # 答え
258        for bit in range(self.log - 1, -1, -1):
259            if self.v[bit].access(k):
260                # k番目が立ってたら、
261                # kまでの1とすべての0が次のk
262                s |= 1 << bit
263                k = self.v[bit].rank1(k) + self.mid[bit]
264            else:
265                # kまでの0が次のk
266                k = self.v[bit].rank0(k)
267        return s
268
269    def __getitem__(self, k: int) -> int:
270        assert (
271            -self.size <= k < self.size
272        ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}"
273        return self.access(k)
274
275    def rank(self, r: int, x: int) -> int:
276        """``a[0, r)`` に含まれる ``x`` の個数を返します。
277        :math:`O(\\log{\\sigma})` です。
278        """
279        assert (
280            0 <= r <= self.size
281        ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}"
282        assert (
283            0 <= x < 1 << self.log
284        ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}"
285        l = 0
286        mid = self.mid
287        for bit in range(self.log - 1, -1, -1):
288            # 位置 r より左に x が何個あるか
289            # x の bit 目で場合分け
290            if x >> bit & 1:
291                # 立ってたら、次のl, rは以下
292                l = self.v[bit].rank1(l) + mid[bit]
293                r = self.v[bit].rank1(r) + mid[bit]
294            else:
295                # そうでなければ次のl, rは以下
296                l = self.v[bit].rank0(l)
297                r = self.v[bit].rank0(r)
298        return r - l
299
300    def select(self, k: int, x: int) -> int:
301        """``k`` 番目の ``v`` のインデックスを返します。
302        :math:`O(\\log{\\sigma})` です。
303        """
304        assert (
305            0 <= k < self.size
306        ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}"
307        assert (
308            0 <= x < 1 << self.log
309        ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}"
310        # x の開始位置 s を探す
311        s = 0
312        for bit in range(self.log - 1, -1, -1):
313            if x >> bit & 1:
314                s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s)
315            else:
316                s = self.v[bit].rank0(s)
317        s += k  # s から k 進んだ位置が、元の列で何番目か調べる
318        for bit in range(self.log):
319            if x >> bit & 1:
320                s = self.v[bit].select1(s - self.v[bit].rank0(self.size))
321            else:
322                s = self.v[bit].select0(s)
323        return s
324
325    def kth_smallest(self, l: int, r: int, k: int) -> int:
326        """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。
327        :math:`O(\\log{\\sigma})` です。
328        """
329        assert (
330            0 <= l <= r <= self.size
331        ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}"
332        assert (
333            0 <= k < r - l
334        ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k"
335        s = 0
336        mid = self.mid
337        for bit in range(self.log - 1, -1, -1):
338            r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l)
339            cnt = r0 - l0  # 区間内の 0 の個数
340            if cnt <= k:  # 0 が k 以下のとき、 k 番目は 1
341                s |= 1 << bit
342                k -= cnt
343                # この 1 が次の bit 列でどこに行くか
344                l = l - l0 + mid[bit]
345                r = r - r0 + mid[bit]
346            else:
347                # この 0 が次の bit 列でどこに行くか
348                l = l0
349                r = r0
350        return s
351
352    quantile = kth_smallest
353
354    def kth_largest(self, l: int, r: int, k: int) -> int:
355        """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。
356        :math:`O(\\log{\\sigma})` です。
357        """
358        assert (
359            0 <= l <= r <= self.size
360        ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}"
361        assert (
362            0 <= k < r - l
363        ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k"
364        return self.kth_smallest(l, r, r - l - k - 1)
365
366    def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]:
367        """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。
368        :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。
369
370        Note:
371            :math:`\\sigma` が大きい場合、計算量に注意です。
372
373        Returns:
374            list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。
375        """
376        assert (
377            0 <= l <= r <= self.size
378        ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}"
379        assert (
380            0 <= k < r - l
381        ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k"
382        # heap[-length, x, l, bit]
383        hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)]
384        ans = []
385        while hq:
386            length, x, l, bit = heappop(hq)
387            length = -length
388            if bit == -1:
389                ans.append((x, length))
390                k -= 1
391                if k == 0:
392                    break
393            else:
394                r = l + length
395                l0 = self.v[bit].rank0(l)
396                r0 = self.v[bit].rank0(r)
397                if l0 < r0:
398                    heappush(hq, (-(r0 - l0), x, l0, bit - 1))
399                l1 = self.v[bit].rank1(l) + self.mid[bit]
400                r1 = self.v[bit].rank1(r) + self.mid[bit]
401                if l1 < r1:
402                    heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1))
403        return ans
404
405    def sum(self, l: int, r: int) -> int:
406        """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。
407        計算量に注意です。
408        """
409        assert False, "Yabai Keisanryo Error"
410        return sum(k * v for k, v in self.topk(l, r, r - l))
411
412    def _range_freq(self, l: int, r: int, x: int) -> int:
413        """a[l, r) で x 未満の要素の数を返す"""
414        ans = 0
415        for bit in range(self.log - 1, -1, -1):
416            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
417            if x >> bit & 1:
418                # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ
419                ans += r0 - l0
420                # 1 が次の bit 列でどこに行くか
421                l += self.mid[bit] - l0
422                r += self.mid[bit] - r0
423            else:
424                # 0 が次の bit 列でどこに行くか
425                l, r = l0, r0
426        return ans
427
428    def range_freq(self, l: int, r: int, x: int, y: int) -> int:
429        """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。
430        :math:`O(\\log{\\sigma})` です。
431        """
432        assert (
433            0 <= l <= r <= self.size
434        ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})"
435        assert 0 <= x <= y < self.sigma, f"ValueError"
436        return self._range_freq(l, r, y) - self._range_freq(l, r, x)
437
438    def prev_value(self, l: int, r: int, x: int) -> int:
439        """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。
440        :math:`O(\\log{\\sigma})` です。
441        """
442        assert (
443            0 <= l <= r <= self.size
444        ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})"
445        return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
446
447    def next_value(self, l: int, r: int, x: int) -> int:
448        """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。
449        :math:`O(\\log{\\sigma})` です。
450        """
451        assert (
452            0 <= l <= r <= self.size
453        ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})"
454        return self.kth_smallest(l, r, self._range_freq(l, r, x))
455
456    def range_count(self, l: int, r: int, x: int) -> int:
457        """``a[l, r)`` に含まれる ``x`` の個数を返します。
458        ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。
459        :math:`O(\\log{\\sigma})` です。
460        """
461        assert (
462            0 <= l <= r <= self.size
463        ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})"
464        return self.rank(r, x) - self.rank(l, x)
465
466    def __len__(self) -> int:
467        return self.size
468
469    def __str__(self) -> str:
470        return (
471            f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})"
472        )
473
474    __repr__ = __str__

仕様

class WaveletMatrix(sigma: int, a: Sequence[int] = [])[source]

Bases: object

WaveletMatrix です。 静的であることに注意してください。

以下の仕様の計算量には嘘があるかもしれません。import 元の BitVector の計算量も参考にしてください。

参考:

https://miti-7.hatenablog.com/entry/2018/04/28/152259 https://www.slideshare.net/pfi/ss-15916040 デwiki

access(k: int) int[source]

k 番目の値を返します。 \(O(\log{\sigma})\) です。

Parameters:

k (int) – インデックスです。

kth_largest(l: int, r: int, k: int) int[source]

a[l, r) の中で k 番目に 大きい値 を返します。 \(O(\log{\sigma})\) です。

kth_smallest(l: int, r: int, k: int) int[source]

a[l, r) の中で k 番目に 小さい 値を返します。 \(O(\log{\sigma})\) です。

next_value(l: int, r: int, x: int) int[source]

a[l, r) で、x 以上 y 未満であるような要素のうち最小の要素を返します。 \(O(\log{\sigma})\) です。

prev_value(l: int, r: int, x: int) int[source]

a[l, r) で、x 以上 y 未満であるような要素のうち最大の要素を返します。 \(O(\log{\sigma})\) です。

quantile(l: int, r: int, k: int) int

a[l, r) の中で k 番目に 小さい 値を返します。 \(O(\log{\sigma})\) です。

range_count(l: int, r: int, x: int) int[source]

a[l, r) に含まれる x の個数を返します。 wm.rank(r, x) - wm.rank(l, x) と等価です。 \(O(\log{\sigma})\) です。

range_freq(l: int, r: int, x: int, y: int) int[source]

a[l, r) に含まれる、 x 以上 y 未満である要素の個数を返します。 \(O(\log{\sigma})\) です。

rank(r: int, x: int) int[source]

a[0, r) に含まれる x の個数を返します。 \(O(\log{\sigma})\) です。

select(k: int, x: int) int[source]

k 番目の v のインデックスを返します。 \(O(\log{\sigma})\) です。

sum(l: int, r: int) int[source]

topk メソッドを用いて a[l, r) の総和を返します。 計算量に注意です。

topk(l: int, r: int, k: int) list[tuple[int, int]][source]

a[l, r) の中で、要素を出現回数が多い順にその頻度とともに k 個返します。 \(O(\min(r-l, \sigam) \log(\sigam))\) です。

Note

\(\sigma\) が大きい場合、計算量に注意です。

Returns:

(要素, 頻度) を要素とする配列です。

Return type:

list[tuple[int, int]]