cumulative_sum_wavelet_matrix

ソースコード

from titan_pylib.data_structures.wavelet_matrix.cumulative_sum_wavelet_matrix import CumulativeSumWaveletMatrix

view on github

展開済みコード

  1# from titan_pylib.data_structures.wavelet_matrix.cumulative_sum_wavelet_matrix import CumulativeSumWaveletMatrix
  2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
  3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  4#     BitVectorInterface,
  5# )
  6from abc import ABC, abstractmethod
  7
  8
  9class BitVectorInterface(ABC):
 10
 11    @abstractmethod
 12    def access(self, k: int) -> int:
 13        raise NotImplementedError
 14
 15    @abstractmethod
 16    def __getitem__(self, k: int) -> int:
 17        raise NotImplementedError
 18
 19    @abstractmethod
 20    def rank0(self, r: int) -> int:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def rank1(self, r: int) -> int:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def rank(self, r: int, v: int) -> int:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def select0(self, k: int) -> int:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def select1(self, k: int) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def select(self, k: int, v: int) -> int:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def __len__(self) -> int:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def __str__(self) -> str:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def __repr__(self) -> str:
 53        raise NotImplementedError
 54from array import array
 55
 56
 57class BitVector(BitVectorInterface):
 58    """コンパクトな bit vector です。"""
 59
 60    def __init__(self, n: int):
 61        """長さ ``n`` の ``BitVector`` です。
 62
 63        bit を保持するのに ``array[I]`` を使用します。
 64        ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
 65
 66        累積和を保持するのに同様の ``array[I]`` を使用します。
 67        32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
 68        """
 69        assert 0 <= n < 4294967295
 70        self.N = n
 71        self.block_size = (n + 31) >> 5
 72        b = bytes(4 * (self.block_size + 1))
 73        self.bit = array("I", b)
 74        self.acc = array("I", b)
 75
 76    @staticmethod
 77    def _popcount(x: int) -> int:
 78        x = x - ((x >> 1) & 0x55555555)
 79        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 80        x = x + (x >> 4) & 0x0F0F0F0F
 81        x += x >> 8
 82        x += x >> 16
 83        return x & 0x0000007F
 84
 85    def set(self, k: int) -> None:
 86        """``k`` 番目の bit を ``1`` にします。
 87        :math:`O(1)` です。
 88
 89        Args:
 90          k (int): インデックスです。
 91        """
 92        self.bit[k >> 5] |= 1 << (k & 31)
 93
 94    def build(self) -> None:
 95        """構築します。
 96        **これ以降 ``set`` メソッドを使用してはいけません。**
 97        :math:`O(n)` です。
 98        """
 99        acc, bit = self.acc, self.bit
100        for i in range(self.block_size):
101            acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103    def access(self, k: int) -> int:
104        """``k`` 番目の bit を返します。
105        :math:`O(1)` です。
106        """
107        return (self.bit[k >> 5] >> (k & 31)) & 1
108
109    def __getitem__(self, k: int) -> int:
110        return (self.bit[k >> 5] >> (k & 31)) & 1
111
112    def rank0(self, r: int) -> int:
113        """``a[0, r)`` に含まれる ``0`` の個数を返します。
114        :math:`O(1)` です。
115        """
116        return r - (
117            self.acc[r >> 5]
118            + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119        )
120
121    def rank1(self, r: int) -> int:
122        """``a[0, r)`` に含まれる ``1`` の個数を返します。
123        :math:`O(1)` です。
124        """
125        return self.acc[r >> 5] + BitVector._popcount(
126            self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127        )
128
129    def rank(self, r: int, v: int) -> int:
130        """``a[0, r)`` に含まれる ``v`` の個数を返します。
131        :math:`O(1)` です。
132        """
133        return self.rank1(r) if v else self.rank0(r)
134
135    def select0(self, k: int) -> int:
136        """``k`` 番目の ``0`` のインデックスを返します。
137        :math:`O(\\log{n})` です。
138        """
139        if k < 0 or self.rank0(self.N) <= k:
140            return -1
141        l, r = 0, self.block_size + 1
142        while r - l > 1:
143            m = (l + r) >> 1
144            if m * 32 - self.acc[m] > k:
145                r = m
146            else:
147                l = m
148        indx = 32 * l
149        k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150        l, r = indx, indx + 32
151        while r - l > 1:
152            m = (l + r) >> 1
153            if self.rank0(m) > k:
154                r = m
155            else:
156                l = m
157        return l
158
159    def select1(self, k: int) -> int:
160        """``k`` 番目の ``1`` のインデックスを返します。
161        :math:`O(\\log{n})` です。
162        """
163        if k < 0 or self.rank1(self.N) <= k:
164            return -1
165        l, r = 0, self.block_size + 1
166        while r - l > 1:
167            m = (l + r) >> 1
168            if self.acc[m] > k:
169                r = m
170            else:
171                l = m
172        indx = 32 * l
173        k = k - self.acc[l] + self.rank1(indx)
174        l, r = indx, indx + 32
175        while r - l > 1:
176            m = (l + r) >> 1
177            if self.rank1(m) > k:
178                r = m
179            else:
180                l = m
181        return l
182
183    def select(self, k: int, v: int) -> int:
184        """``k`` 番目の ``v`` のインデックスを返します。
185        :math:`O(\\log{n})` です。
186        """
187        return self.select1(k) if v else self.select0(k)
188
189    def __len__(self):
190        return self.N
191
192    def __str__(self):
193        return str([self.access(i) for i in range(self.N)])
194
195    def __repr__(self):
196        return f"{self.__class__.__name__}({self})"
197# from titan_pylib.data_structures.cumulative_sum.cumulative_sum import CumulativeSum
198from typing import Iterable
199
200
201class CumulativeSum:
202    """1次元累積和です。"""
203
204    def __init__(self, a: Iterable[int], e: int = 0):
205        """
206        :math:`O(n)` です。
207
208        Args:
209          a (Iterable[int]): ``CumulativeSum`` を構築する配列です。
210          e (int): 単位元です。デフォルトは ``0`` です。
211        """
212        a = list(a)
213        n = len(a)
214        acc = [e] * (n + 1)
215        for i in range(n):
216            acc[i + 1] = acc[i] + a[i]
217        self.n = n
218        self.acc = acc
219        self.a = a
220
221    def pref(self, r: int) -> int:
222        """区間 ``[0, r)`` の演算結果を返します。
223        :math:`O(1)` です。
224
225        Args:
226          r (int): インデックスです。
227        """
228        assert (
229            0 <= r <= self.n
230        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self.n}"
231        return self.acc[r]
232
233    def all_sum(self) -> int:
234        """区間 `[0, n)` の演算結果を返します。
235        :math:`O(1)` です。
236
237        Args:
238          l (int): インデックスです。
239          r (int): インデックスです。
240        """
241        return self.acc[-1]
242
243    def sum(self, l: int, r: int) -> int:
244        """区間 `[l, r)` の演算結果を返します。
245        :math:`O(1)` です。
246
247        Args:
248          l (int): インデックスです。
249          r (int): インデックスです。
250        """
251        assert (
252            0 <= l <= r <= self.n
253        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self.n}"
254        return self.acc[r] - self.acc[l]
255
256    prod = sum
257    all_prod = all_sum
258
259    def __getitem__(self, k: int) -> int:
260        assert (
261            -self.n <= k < self.n
262        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
263        return self.a[k]
264
265    def __len__(self) -> int:
266        return len(self.a)
267
268    def __str__(self) -> str:
269        return str(self.acc)
270
271    __repr__ = __str__
272from array import array
273from typing import Sequence, Iterable
274from bisect import bisect_left
275
276
277class CumulativeSumWaveletMatrix:
278
279    def __init__(self, sigma: int, pos: Iterable[tuple[int, int, int]] = []) -> None:
280        """
281        Args:
282          sigma (int): yの最大値
283          pos (list[tuple[int, int, int]], optional): list[(x, y, w)]
284        """
285        self.sigma: int = sigma
286        self.log: int = (sigma - 1).bit_length()
287        self.mid: array[int] = array("I", bytes(4 * self.log))
288        self.xy: list[tuple[int, int]] = self._sort_unique([(x, y) for x, y, _ in pos])
289        self.y: list[int] = self._sort_unique([y for _, y in self.xy])
290        self.size: int = len(self.xy)
291        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
292        self._build([bisect_left(self.y, y) for _, y in self.xy])
293        ws = [[0] * self.size for _ in range(self.log)]
294        for x, y, w in pos:
295            k = bisect_left(self.xy, (x, y))
296            i_y = bisect_left(self.y, y)
297            for bit in range(self.log - 1, -1, -1):
298                if i_y >> bit & 1:
299                    k = self.v[bit].rank1(k) + self.mid[bit]
300                else:
301                    k = self.v[bit].rank0(k)
302                ws[bit][k] += w
303        self.bit: list[CumulativeSum] = [CumulativeSum(a) for a in ws]
304
305    def _build(self, a: Sequence[int]) -> None:
306        # 列 a から wm を構築する
307        for bit in range(self.log - 1, -1, -1):
308            # bit目の0/1に応じてvを構築 + aを安定ソート
309            v = self.v[bit]
310            zero, one = [], []
311            for i, e in enumerate(a):
312                if e >> bit & 1:
313                    v.set(i)
314                    one.append(e)
315                else:
316                    zero.append(e)
317            v.build()
318            self.mid[bit] = len(zero)  # 境界をmid[bit]に保持
319            a = zero + one
320
321    def _sort_unique(self, a: list) -> list:
322        if not a:
323            return a
324        a.sort()
325        b = [a[0]]
326        for e in a:
327            if b[-1] == e:
328                continue
329            b.append(e)
330        return b
331
332    def _sum(self, l: int, r: int, x: int) -> int:
333        ans = 0
334        for bit in range(self.log - 1, -1, -1):
335            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
336            if x >> bit & 1:
337                l += self.mid[bit] - l0
338                r += self.mid[bit] - r0
339                ans += self.bit[bit].sum(l0, r0)
340            else:
341                l, r = l0, r0
342        return ans
343
344    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
345        """sum([w1, w2) x [h1, h2))"""
346        assert 0 <= w1 <= w2
347        assert 0 <= h1 <= h2
348        l = bisect_left(self.xy, (w1, 0))
349        r = bisect_left(self.xy, (w2, 0))
350        return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
351            l, r, bisect_left(self.y, h1)
352        )

仕様

class CumulativeSumWaveletMatrix(sigma: int, pos: Iterable[tuple[int, int, int]] = [])[source]

Bases: object

sum([w1, w2) x [h1, h2))[source]