fenwick_tree_set

ソースコード

from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
 10from typing import Union, Iterable, Optional
 11
 12
 13class FenwickTree:
 14    """FenwickTreeです。"""
 15
 16    def __init__(self, n_or_a: Union[Iterable[int], int]):
 17        """構築します。
 18        :math:`O(n)` です。
 19
 20        Args:
 21          n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
 22                                              `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
 23        """
 24        if isinstance(n_or_a, int):
 25            self._size = n_or_a
 26            self._tree = [0] * (self._size + 1)
 27        else:
 28            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
 29            _size = len(a)
 30            _tree = [0] + a
 31            for i in range(1, _size):
 32                if i + (i & -i) <= _size:
 33                    _tree[i + (i & -i)] += _tree[i]
 34            self._size = _size
 35            self._tree = _tree
 36        self._s = 1 << (self._size - 1).bit_length()
 37
 38    def pref(self, r: int) -> int:
 39        """区間 ``[0, r)`` の総和を返します。
 40        :math:`O(\\log{n})` です。
 41        """
 42        assert (
 43            0 <= r <= self._size
 44        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
 45        ret, _tree = 0, self._tree
 46        while r > 0:
 47            ret += _tree[r]
 48            r &= r - 1
 49        return ret
 50
 51    def suff(self, l: int) -> int:
 52        """区間 ``[l, n)`` の総和を返します。
 53        :math:`O(\\log{n})` です。
 54        """
 55        assert (
 56            0 <= l < self._size
 57        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
 58        return self.pref(self._size) - self.pref(l)
 59
 60    def sum(self, l: int, r: int) -> int:
 61        """区間 ``[l, r)`` の総和を返します。
 62        :math:`O(\\log{n})` です。
 63        """
 64        assert (
 65            0 <= l <= r <= self._size
 66        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
 67        _tree = self._tree
 68        res = 0
 69        while r > l:
 70            res += _tree[r]
 71            r &= r - 1
 72        while l > r:
 73            res -= _tree[l]
 74            l &= l - 1
 75        return res
 76
 77    prod = sum
 78
 79    def __getitem__(self, k: int) -> int:
 80        """位置 ``k`` の要素を返します。
 81        :math:`O(\\log{n})` です。
 82        """
 83        assert (
 84            -self._size <= k < self._size
 85        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
 86        if k < 0:
 87            k += self._size
 88        return self.sum(k, k + 1)
 89
 90    def add(self, k: int, x: int) -> None:
 91        """``k`` 番目の値に ``x`` を加えます。
 92        :math:`O(\\log{n})` です。
 93        """
 94        assert (
 95            0 <= k < self._size
 96        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
 97        k += 1
 98        _tree = self._tree
 99        while k <= self._size:
100            _tree[k] += x
101            k += k & -k
102
103    def __setitem__(self, k: int, x: int):
104        """``k`` 番目の値を ``x`` に更新します。
105        :math:`O(\\log{n})` です。
106        """
107        assert (
108            -self._size <= k < self._size
109        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
110        if k < 0:
111            k += self._size
112        pre = self[k]
113        self.add(k, x - pre)
114
115    def bisect_left(self, w: int) -> Optional[int]:
116        i, s, _size, _tree = 0, self._s, self._size, self._tree
117        while s:
118            if i + s <= _size and _tree[i + s] < w:
119                w -= _tree[i + s]
120                i += s
121            s >>= 1
122        return i if w else None
123
124    def bisect_right(self, w: int) -> int:
125        i, s, _size, _tree = 0, self._s, self._size, self._tree
126        while s:
127            if i + s <= _size and _tree[i + s] <= w:
128                w -= _tree[i + s]
129                i += s
130            s >>= 1
131        return i
132
133    def _pop(self, k: int) -> int:
134        assert k >= 0
135        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
136        while s:
137            if i + s <= _size:
138                if acc + _tree[i + s] <= k:
139                    acc += _tree[i + s]
140                    i += s
141                else:
142                    _tree[i + s] -= 1
143            s >>= 1
144        return i
145
146    def tolist(self) -> list[int]:
147        """リストにして返します。
148        :math:`O(n)` です。
149        """
150        sub = [self.pref(i) for i in range(self._size + 1)]
151        return [sub[i + 1] - sub[i] for i in range(self._size)]
152
153    @staticmethod
154    def get_inversion_num(a: list[int], compress: bool = False) -> int:
155        inv = 0
156        if compress:
157            a_ = sorted(set(a))
158            z = {e: i for i, e in enumerate(a_)}
159            fw = FenwickTree(len(a_) + 1)
160            for i, e in enumerate(a):
161                inv += i - fw.pref(z[e] + 1)
162                fw.add(z[e], 1)
163        else:
164            fw = FenwickTree(len(a) + 1)
165            for i, e in enumerate(a):
166                inv += i - fw.pref(e + 1)
167                fw.add(e, 1)
168        return inv
169
170    def __str__(self):
171        return str(self.tolist())
172
173    def __repr__(self):
174        return f"{self.__class__.__name__}({self})"
from typing import Generic, Iterable, Optional, TypeVar, Union

# Element type of FenwickTreeSet: anything orderable with ``<``, because the
# universe of keys is sorted when a set is constructed.
T = TypeVar("T", bound=SupportsLessThan)
178
179
class FenwickTreeSet(Generic[T]):
    """Ordered set over a universe of keys fixed at construction time.

    Every key that can ever be stored is declared up front (``_used``) and
    mapped to a dense index ``0..size-1``. Membership is kept in ``_cnt``
    (O(1) lookup) and also as 0/1 counts in a ``FenwickTree``. This makes
    rank (``index``), select (``self[k]``) and neighbour queries
    (``le``/``lt``/``ge``/``gt``) cost :math:`O(\\log{size})`.
    """

    def __init__(
        self,
        _used: Union[int, Iterable[T]],
        _a: Iterable[T] = [],
        compress=True,
        _multi=False,
    ) -> None:
        """Builds the set.

        Args:
          _used: If an ``int`` ``n``, the universe of keys is ``0, 1, ..., n-1``.
            Otherwise it is the set of distinct values of the iterable.
          _a: Initial elements; each must belong to the universe. The argument
            is only copied with ``list()`` and never mutated, so the ``[]``
            default is not shared mutable state.
          compress: If True, keys are mapped to indices through a dict. If
            False, the sorted universe list itself is indexed by the key. That
            is only correct when the universe is exactly ``0..size-1`` (e.g.
            ``_used`` is an ``int``). In that mode a negative key is NOT
            rejected: it wraps around like a Python list index.
          _multi: If True, every duplicate in ``_a`` is counted. The other
            methods of this class treat any non-zero count as plain membership
            (presumably a multiset subclass overrides them — TODO confirm).

        Raises:
          KeyError: with ``compress=True``, if an element of ``_a`` is not in
            the universe.
        """
        self._len = 0  # number of stored elements
        if isinstance(_used, int):
            self._to_origin = list(range(_used))
        elif isinstance(_used, set):
            self._to_origin = sorted(_used)
        else:
            self._to_origin = sorted(set(_used))
        # key -> dense index. With compress=False this is the universe list
        # itself, which acts as the identity map only for 0..size-1.
        self._to_zaatsu: Union[dict[T, int], list[T]] = (
            {key: i for i, key in enumerate(self._to_origin)}
            if compress
            else self._to_origin
        )
        self._size = len(self._to_origin)
        self._cnt = [0] * self._size  # count stored at each dense index
        _a = list(_a)
        if _a:
            a_ = [0] * self._size
            if _multi:
                self._len = len(_a)
                for v in _a:
                    i = self._to_zaatsu[v]
                    a_[i] += 1
                    self._cnt[i] += 1
            else:
                for v in _a:
                    i = self._to_zaatsu[v]
                    if self._cnt[i] == 0:
                        self._len += 1
                        a_[i] = 1
                        self._cnt[i] = 1
            self._fw = FenwickTree(a_)
        else:
            self._fw = FenwickTree(self._size)

    def add(self, key: T) -> bool:
        """Inserts ``key``; returns False if it was already present."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return False
        self._len += 1
        self._cnt[i] = 1
        self._fw.add(i, 1)
        return True

    def remove(self, key: T) -> None:
        """Removes ``key``.

        Raises:
          KeyError: if ``key`` is not present.
        """
        if not self.discard(key):
            raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Removes ``key`` if present; returns whether it was removed."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            self._len -= 1
            self._cnt[i] = 0
            self._fw.add(i, -1)
            return True
        return False

    def le(self, key: T) -> Optional[T]:
        """Returns the largest element ``<= key``, or None if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # pref(i) = number of elements < key; minus one = rank of predecessor.
        pref = self._fw.pref(i) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def lt(self, key: T) -> Optional[T]:
        """Returns the largest element ``< key``, or None if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key]) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def ge(self, key: T) -> Optional[T]:
        """Returns the smallest element ``>= key``, or None if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # Number of elements <= key equals the rank of the successor.
        pref = self._fw.pref(i + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def gt(self, key: T) -> Optional[T]:
        """Returns the smallest element ``> key``, or None if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key] + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def index(self, key: T) -> int:
        """Returns the number of elements ``< key``."""
        return self._fw.pref(self._to_zaatsu[key])

    def index_right(self, key: T) -> int:
        """Returns the number of elements ``<= key``."""
        return self._fw.pref(self._to_zaatsu[key] + 1)

    def pop(self, k: int = -1) -> T:
        """Removes and returns the ``k``-th smallest element.

        Negative ``k`` counts from the largest element.
        """
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet.pop({k}), Index out of range."
        if k < 0:
            k += self._len
        self._len -= 1
        x = self._fw._pop(k)  # locate and decrement in one descent
        self._cnt[x] = 0
        return self._to_origin[x]

    def pop_min(self) -> T:
        """Removes and returns the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Removes and returns the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def get_min(self) -> Optional[T]:
        """Returns the smallest element, or None if the set is empty."""
        if not self:
            return None
        return self[0]

    def get_max(self) -> Optional[T]:
        """Returns the largest element, or None if the set is empty."""
        if not self:
            return None
        return self[-1]

    def __getitem__(self, k):
        """Returns the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet[{k}], Index out of range."
        if k < 0:
            k += self._len
        return self._to_origin[self._fw.bisect_right(k)]

    def __iter__(self):
        """Yields the elements in ascending order.

        Each call returns an independent generator. Nested loops over the same
        set, ``zip(s, s)``, or ``str(s)`` inside a loop over ``s`` therefore do
        not reset each other's position, as a cursor stored on ``self`` would.
        """
        _to_origin, _fw = self._to_origin, self._fw
        k = 0
        while k < self._len:
            yield _to_origin[_fw.bisect_right(k)]
            k += 1

    def __reversed__(self):
        """Yields the elements in descending order."""
        _to_origin = self._to_origin
        for i in range(self._len):
            yield _to_origin[self._fw.bisect_right(self._len - i - 1)]

    def __len__(self):
        return self._len

    def __contains__(self, key: T):
        return self._cnt[self._to_zaatsu[key]] > 0

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"

仕様

class FenwickTreeSet(_used: int | Iterable[T], _a: Iterable[T] = [], compress=True, _multi=False) [source]

Bases: Generic[T]

add(key: T) → bool [source]
discard(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_right(key: T) → int [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T) → None [source]