fenwick_tree

ソースコード

from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
  2from typing import Union, Iterable, Optional
  3
  4
  5class FenwickTree:
  6    """FenwickTreeです。"""
  7
  8    def __init__(self, n_or_a: Union[Iterable[int], int]):
  9        """構築します。
 10        :math:`O(n)` です。
 11
 12        Args:
 13          n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
 14                                              `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
 15        """
 16        if isinstance(n_or_a, int):
 17            self._size = n_or_a
 18            self._tree = [0] * (self._size + 1)
 19        else:
 20            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
 21            _size = len(a)
 22            _tree = [0] + a
 23            for i in range(1, _size):
 24                if i + (i & -i) <= _size:
 25                    _tree[i + (i & -i)] += _tree[i]
 26            self._size = _size
 27            self._tree = _tree
 28        self._s = 1 << (self._size - 1).bit_length()
 29
 30    def pref(self, r: int) -> int:
 31        """区間 ``[0, r)`` の総和を返します。
 32        :math:`O(\\log{n})` です。
 33        """
 34        assert (
 35            0 <= r <= self._size
 36        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
 37        ret, _tree = 0, self._tree
 38        while r > 0:
 39            ret += _tree[r]
 40            r &= r - 1
 41        return ret
 42
 43    def suff(self, l: int) -> int:
 44        """区間 ``[l, n)`` の総和を返します。
 45        :math:`O(\\log{n})` です。
 46        """
 47        assert (
 48            0 <= l < self._size
 49        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
 50        return self.pref(self._size) - self.pref(l)
 51
 52    def sum(self, l: int, r: int) -> int:
 53        """区間 ``[l, r)`` の総和を返します。
 54        :math:`O(\\log{n})` です。
 55        """
 56        assert (
 57            0 <= l <= r <= self._size
 58        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
 59        _tree = self._tree
 60        res = 0
 61        while r > l:
 62            res += _tree[r]
 63            r &= r - 1
 64        while l > r:
 65            res -= _tree[l]
 66            l &= l - 1
 67        return res
 68
 69    prod = sum
 70
 71    def __getitem__(self, k: int) -> int:
 72        """位置 ``k`` の要素を返します。
 73        :math:`O(\\log{n})` です。
 74        """
 75        assert (
 76            -self._size <= k < self._size
 77        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
 78        if k < 0:
 79            k += self._size
 80        return self.sum(k, k + 1)
 81
 82    def add(self, k: int, x: int) -> None:
 83        """``k`` 番目の値に ``x`` を加えます。
 84        :math:`O(\\log{n})` です。
 85        """
 86        assert (
 87            0 <= k < self._size
 88        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
 89        k += 1
 90        _tree = self._tree
 91        while k <= self._size:
 92            _tree[k] += x
 93            k += k & -k
 94
 95    def __setitem__(self, k: int, x: int):
 96        """``k`` 番目の値を ``x`` に更新します。
 97        :math:`O(\\log{n})` です。
 98        """
 99        assert (
100            -self._size <= k < self._size
101        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
102        if k < 0:
103            k += self._size
104        pre = self[k]
105        self.add(k, x - pre)
106
107    def bisect_left(self, w: int) -> Optional[int]:
108        i, s, _size, _tree = 0, self._s, self._size, self._tree
109        while s:
110            if i + s <= _size and _tree[i + s] < w:
111                w -= _tree[i + s]
112                i += s
113            s >>= 1
114        return i if w else None
115
116    def bisect_right(self, w: int) -> int:
117        i, s, _size, _tree = 0, self._s, self._size, self._tree
118        while s:
119            if i + s <= _size and _tree[i + s] <= w:
120                w -= _tree[i + s]
121                i += s
122            s >>= 1
123        return i
124
125    def _pop(self, k: int) -> int:
126        assert k >= 0
127        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
128        while s:
129            if i + s <= _size:
130                if acc + _tree[i + s] <= k:
131                    acc += _tree[i + s]
132                    i += s
133                else:
134                    _tree[i + s] -= 1
135            s >>= 1
136        return i
137
138    def tolist(self) -> list[int]:
139        """リストにして返します。
140        :math:`O(n)` です。
141        """
142        sub = [self.pref(i) for i in range(self._size + 1)]
143        return [sub[i + 1] - sub[i] for i in range(self._size)]
144
145    @staticmethod
146    def get_inversion_num(a: list[int], compress: bool = False) -> int:
147        inv = 0
148        if compress:
149            a_ = sorted(set(a))
150            z = {e: i for i, e in enumerate(a_)}
151            fw = FenwickTree(len(a_) + 1)
152            for i, e in enumerate(a):
153                inv += i - fw.pref(z[e] + 1)
154                fw.add(z[e], 1)
155        else:
156            fw = FenwickTree(len(a) + 1)
157            for i, e in enumerate(a):
158                inv += i - fw.pref(e + 1)
159                fw.add(e, 1)
160        return inv
161
162    def __str__(self):
163        return str(self.tolist())
164
165    def __repr__(self):
166        return f"{self.__class__.__name__}({self})"

仕様

class FenwickTree(n_or_a: Iterable[int] | int)[source]

Bases: object

FenwickTreeです。

__getitem__(k: int) int[source]

位置 k の要素を返します。 \(O(\log{n})\) です。

__setitem__(k: int, x: int)[source]

k 番目の値を x に更新します。 \(O(\log{n})\) です。

add(k: int, x: int) None[source]

k 番目の値に x を加えます。 \(O(\log{n})\) です。

bisect_left(w: int) int | None[source]
bisect_right(w: int) int[source]
static get_inversion_num(a: list[int], compress: bool = False) int[source]
pref(r: int) int[source]

区間 [0, r) の総和を返します。 \(O(\log{n})\) です。

prod(l: int, r: int) int

区間 [l, r) の総和を返します。 \(O(\log{n})\) です。

suff(l: int) int[source]

区間 [l, n) の総和を返します。 \(O(\log{n})\) です。

sum(l: int, r: int) int[source]

区間 [l, r) の総和を返します。 \(O(\log{n})\) です。

tolist() list[int][source]

リストにして返します。 \(O(n)\) です。