fenwick_tree_wavelet_matrix

ソースコード

from titan_pylib.data_structures.wavelet_matrix.fenwick_tree_wavelet_matrix import FenwickTreeWaveletMatrix

view on github

展開済みコード

  1# from titan_pylib.data_structures.wavelet_matrix.fenwick_tree_wavelet_matrix import FenwickTreeWaveletMatrix
  2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
  3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  4#     BitVectorInterface,
  5# )
  6from abc import ABC, abstractmethod
  7
  8
  9class BitVectorInterface(ABC):
 10
 11    @abstractmethod
 12    def access(self, k: int) -> int:
 13        raise NotImplementedError
 14
 15    @abstractmethod
 16    def __getitem__(self, k: int) -> int:
 17        raise NotImplementedError
 18
 19    @abstractmethod
 20    def rank0(self, r: int) -> int:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def rank1(self, r: int) -> int:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def rank(self, r: int, v: int) -> int:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def select0(self, k: int) -> int:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def select1(self, k: int) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def select(self, k: int, v: int) -> int:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def __len__(self) -> int:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def __str__(self) -> str:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def __repr__(self) -> str:
 53        raise NotImplementedError
 54from array import array
 55
 56
 57class BitVector(BitVectorInterface):
 58    """コンパクトな bit vector です。"""
 59
 60    def __init__(self, n: int):
 61        """長さ ``n`` の ``BitVector`` です。
 62
 63        bit を保持するのに ``array[I]`` を使用します。
 64        ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
 65
 66        累積和を保持するのに同様の ``array[I]`` を使用します。
 67        32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
 68        """
 69        assert 0 <= n < 4294967295
 70        self.N = n
 71        self.block_size = (n + 31) >> 5
 72        b = bytes(4 * (self.block_size + 1))
 73        self.bit = array("I", b)
 74        self.acc = array("I", b)
 75
 76    @staticmethod
 77    def _popcount(x: int) -> int:
 78        x = x - ((x >> 1) & 0x55555555)
 79        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 80        x = x + (x >> 4) & 0x0F0F0F0F
 81        x += x >> 8
 82        x += x >> 16
 83        return x & 0x0000007F
 84
 85    def set(self, k: int) -> None:
 86        """``k`` 番目の bit を ``1`` にします。
 87        :math:`O(1)` です。
 88
 89        Args:
 90          k (int): インデックスです。
 91        """
 92        self.bit[k >> 5] |= 1 << (k & 31)
 93
 94    def build(self) -> None:
 95        """構築します。
 96        **これ以降 ``set`` メソッドを使用してはいけません。**
 97        :math:`O(n)` です。
 98        """
 99        acc, bit = self.acc, self.bit
100        for i in range(self.block_size):
101            acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103    def access(self, k: int) -> int:
104        """``k`` 番目の bit を返します。
105        :math:`O(1)` です。
106        """
107        return (self.bit[k >> 5] >> (k & 31)) & 1
108
109    def __getitem__(self, k: int) -> int:
110        return (self.bit[k >> 5] >> (k & 31)) & 1
111
112    def rank0(self, r: int) -> int:
113        """``a[0, r)`` に含まれる ``0`` の個数を返します。
114        :math:`O(1)` です。
115        """
116        return r - (
117            self.acc[r >> 5]
118            + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119        )
120
121    def rank1(self, r: int) -> int:
122        """``a[0, r)`` に含まれる ``1`` の個数を返します。
123        :math:`O(1)` です。
124        """
125        return self.acc[r >> 5] + BitVector._popcount(
126            self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127        )
128
129    def rank(self, r: int, v: int) -> int:
130        """``a[0, r)`` に含まれる ``v`` の個数を返します。
131        :math:`O(1)` です。
132        """
133        return self.rank1(r) if v else self.rank0(r)
134
135    def select0(self, k: int) -> int:
136        """``k`` 番目の ``0`` のインデックスを返します。
137        :math:`O(\\log{n})` です。
138        """
139        if k < 0 or self.rank0(self.N) <= k:
140            return -1
141        l, r = 0, self.block_size + 1
142        while r - l > 1:
143            m = (l + r) >> 1
144            if m * 32 - self.acc[m] > k:
145                r = m
146            else:
147                l = m
148        indx = 32 * l
149        k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150        l, r = indx, indx + 32
151        while r - l > 1:
152            m = (l + r) >> 1
153            if self.rank0(m) > k:
154                r = m
155            else:
156                l = m
157        return l
158
159    def select1(self, k: int) -> int:
160        """``k`` 番目の ``1`` のインデックスを返します。
161        :math:`O(\\log{n})` です。
162        """
163        if k < 0 or self.rank1(self.N) <= k:
164            return -1
165        l, r = 0, self.block_size + 1
166        while r - l > 1:
167            m = (l + r) >> 1
168            if self.acc[m] > k:
169                r = m
170            else:
171                l = m
172        indx = 32 * l
173        k = k - self.acc[l] + self.rank1(indx)
174        l, r = indx, indx + 32
175        while r - l > 1:
176            m = (l + r) >> 1
177            if self.rank1(m) > k:
178                r = m
179            else:
180                l = m
181        return l
182
183    def select(self, k: int, v: int) -> int:
184        """``k`` 番目の ``v`` のインデックスを返します。
185        :math:`O(\\log{n})` です。
186        """
187        return self.select1(k) if v else self.select0(k)
188
189    def __len__(self):
190        return self.N
191
192    def __str__(self):
193        return str([self.access(i) for i in range(self.N)])
194
195    def __repr__(self):
196        return f"{self.__class__.__name__}({self})"
197# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
198from typing import Union, Iterable, Optional
199
200
201class FenwickTree:
202    """FenwickTreeです。"""
203
204    def __init__(self, n_or_a: Union[Iterable[int], int]):
205        """構築します。
206        :math:`O(n)` です。
207
208        Args:
209          n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
210                                              `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
211        """
212        if isinstance(n_or_a, int):
213            self._size = n_or_a
214            self._tree = [0] * (self._size + 1)
215        else:
216            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
217            _size = len(a)
218            _tree = [0] + a
219            for i in range(1, _size):
220                if i + (i & -i) <= _size:
221                    _tree[i + (i & -i)] += _tree[i]
222            self._size = _size
223            self._tree = _tree
224        self._s = 1 << (self._size - 1).bit_length()
225
226    def pref(self, r: int) -> int:
227        """区間 ``[0, r)`` の総和を返します。
228        :math:`O(\\log{n})` です。
229        """
230        assert (
231            0 <= r <= self._size
232        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
233        ret, _tree = 0, self._tree
234        while r > 0:
235            ret += _tree[r]
236            r &= r - 1
237        return ret
238
239    def suff(self, l: int) -> int:
240        """区間 ``[l, n)`` の総和を返します。
241        :math:`O(\\log{n})` です。
242        """
243        assert (
244            0 <= l < self._size
245        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
246        return self.pref(self._size) - self.pref(l)
247
248    def sum(self, l: int, r: int) -> int:
249        """区間 ``[l, r)`` の総和を返します。
250        :math:`O(\\log{n})` です。
251        """
252        assert (
253            0 <= l <= r <= self._size
254        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
255        _tree = self._tree
256        res = 0
257        while r > l:
258            res += _tree[r]
259            r &= r - 1
260        while l > r:
261            res -= _tree[l]
262            l &= l - 1
263        return res
264
265    prod = sum
266
267    def __getitem__(self, k: int) -> int:
268        """位置 ``k`` の要素を返します。
269        :math:`O(\\log{n})` です。
270        """
271        assert (
272            -self._size <= k < self._size
273        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
274        if k < 0:
275            k += self._size
276        return self.sum(k, k + 1)
277
278    def add(self, k: int, x: int) -> None:
279        """``k`` 番目の値に ``x`` を加えます。
280        :math:`O(\\log{n})` です。
281        """
282        assert (
283            0 <= k < self._size
284        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
285        k += 1
286        _tree = self._tree
287        while k <= self._size:
288            _tree[k] += x
289            k += k & -k
290
291    def __setitem__(self, k: int, x: int):
292        """``k`` 番目の値を ``x`` に更新します。
293        :math:`O(\\log{n})` です。
294        """
295        assert (
296            -self._size <= k < self._size
297        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
298        if k < 0:
299            k += self._size
300        pre = self[k]
301        self.add(k, x - pre)
302
303    def bisect_left(self, w: int) -> Optional[int]:
304        i, s, _size, _tree = 0, self._s, self._size, self._tree
305        while s:
306            if i + s <= _size and _tree[i + s] < w:
307                w -= _tree[i + s]
308                i += s
309            s >>= 1
310        return i if w else None
311
312    def bisect_right(self, w: int) -> int:
313        i, s, _size, _tree = 0, self._s, self._size, self._tree
314        while s:
315            if i + s <= _size and _tree[i + s] <= w:
316                w -= _tree[i + s]
317                i += s
318            s >>= 1
319        return i
320
321    def _pop(self, k: int) -> int:
322        assert k >= 0
323        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
324        while s:
325            if i + s <= _size:
326                if acc + _tree[i + s] <= k:
327                    acc += _tree[i + s]
328                    i += s
329                else:
330                    _tree[i + s] -= 1
331            s >>= 1
332        return i
333
334    def tolist(self) -> list[int]:
335        """リストにして返します。
336        :math:`O(n)` です。
337        """
338        sub = [self.pref(i) for i in range(self._size + 1)]
339        return [sub[i + 1] - sub[i] for i in range(self._size)]
340
341    @staticmethod
342    def get_inversion_num(a: list[int], compress: bool = False) -> int:
343        inv = 0
344        if compress:
345            a_ = sorted(set(a))
346            z = {e: i for i, e in enumerate(a_)}
347            fw = FenwickTree(len(a_) + 1)
348            for i, e in enumerate(a):
349                inv += i - fw.pref(z[e] + 1)
350                fw.add(z[e], 1)
351        else:
352            fw = FenwickTree(len(a) + 1)
353            for i, e in enumerate(a):
354                inv += i - fw.pref(e + 1)
355                fw.add(e, 1)
356        return inv
357
358    def __str__(self):
359        return str(self.tolist())
360
361    def __repr__(self):
362        return f"{self.__class__.__name__}({self})"
363from array import array
364from typing import Sequence
365from bisect import bisect_left
366
367
368class FenwickTreeWaveletMatrix:
369
370    def __init__(self, sigma: int, pos: list[tuple[int, int, int]] = []):
371        self.sigma: int = sigma
372        self.log: int = (sigma - 1).bit_length()
373        self.mid: array[int] = array("I", bytes(4 * self.log))
374        self.xy: list[tuple[int, int]] = self._sort_unique([(x, y) for x, y, _ in pos])
375        self.y: list[int] = self._sort_unique([y for _, y, _ in pos])
376        self.size: int = len(self.xy)
377        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
378        self._build([bisect_left(self.y, y) for _, y in self.xy])
379        ws = [[0] * self.size for _ in range(self.log)]
380        for x, y, w in pos:
381            k = bisect_left(self.xy, (x, y))
382            i_y = bisect_left(self.y, y)
383            for bit in range(self.log - 1, -1, -1):
384                if i_y >> bit & 1:
385                    k = self.v[bit].rank1(k) + self.mid[bit]
386                else:
387                    k = self.v[bit].rank0(k)
388                ws[bit][k] += w
389        self.bit: list[FenwickTree] = [FenwickTree(a) for a in ws]
390
391    def _build(self, a: Sequence[int]) -> None:
392        # 列 a から wm を構築する
393        for bit in range(self.log - 1, -1, -1):
394            # bit目の0/1に応じてvを構築 + aを安定ソート
395            v = self.v[bit]
396            zero, one = [], []
397            for i, e in enumerate(a):
398                if e >> bit & 1:
399                    v.set(i)
400                    one.append(e)
401                else:
402                    zero.append(e)
403            v.build()
404            self.mid[bit] = len(zero)  # 境界をmid[bit]に保持
405            a = zero + one
406
407    def _sort_unique(self, a: list) -> list:
408        if not a:
409            return a
410        a.sort()
411        b = [a[0]]
412        for e in a:
413            if b[-1] == e:
414                continue
415            b.append(e)
416        return b
417
418    def add_point(self, x: int, y: int, w: int) -> None:
419        k = bisect_left(self.xy, (x, y))
420        i_y = bisect_left(self.y, y)
421        for bit in range(self.log - 1, -1, -1):
422            if i_y >> bit & 1:
423                k = self.v[bit].rank1(k) + self.mid[bit]
424            else:
425                k = self.v[bit].rank0(k)
426            self.bit[bit].add(k, w)
427
428    def _sum(self, l: int, r: int, x: int) -> int:
429        ans = 0
430        for bit in range(self.log - 1, -1, -1):
431            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
432            if x >> bit & 1:
433                l += self.mid[bit] - l0
434                r += self.mid[bit] - r0
435                ans += self.bit[bit].sum(l0, r0)
436            else:
437                l, r = l0, r0
438        return ans
439
440    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
441        # sum([w1, w2) x [h1, h2))
442        l = bisect_left(self.xy, (w1, 0))
443        r = bisect_left(self.xy, (w2, 0))
444        return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
445            l, r, bisect_left(self.y, h1)
446        )

仕様

class FenwickTreeWaveletMatrix(sigma: int, pos: list[tuple[int, int, int]] = [])[source]

Bases: object

add_point(x: int, y: int, w: int) None[source]
sum(w1: int, w2: int, h1: int, h2: int) int[source]