dynamic_wavelet_matrix

ソースコード

from titan_pylib.data_structures.wavelet_matrix.dynamic_wavelet_matrix import DynamicWaveletMatrix

view on github

展開済みコード

   1# from titan_pylib.data_structures.wavelet_matrix.dynamic_wavelet_matrix import DynamicWaveletMatrix
   2# from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
   3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
   4#     BitVectorInterface,
   5# )
   6from abc import ABC, abstractmethod
   7
   8
   9class BitVectorInterface(ABC):
  10
  11    @abstractmethod
  12    def access(self, k: int) -> int:
  13        raise NotImplementedError
  14
  15    @abstractmethod
  16    def __getitem__(self, k: int) -> int:
  17        raise NotImplementedError
  18
  19    @abstractmethod
  20    def rank0(self, r: int) -> int:
  21        raise NotImplementedError
  22
  23    @abstractmethod
  24    def rank1(self, r: int) -> int:
  25        raise NotImplementedError
  26
  27    @abstractmethod
  28    def rank(self, r: int, v: int) -> int:
  29        raise NotImplementedError
  30
  31    @abstractmethod
  32    def select0(self, k: int) -> int:
  33        raise NotImplementedError
  34
  35    @abstractmethod
  36    def select1(self, k: int) -> int:
  37        raise NotImplementedError
  38
  39    @abstractmethod
  40    def select(self, k: int, v: int) -> int:
  41        raise NotImplementedError
  42
  43    @abstractmethod
  44    def __len__(self) -> int:
  45        raise NotImplementedError
  46
  47    @abstractmethod
  48    def __str__(self) -> str:
  49        raise NotImplementedError
  50
  51    @abstractmethod
  52    def __repr__(self) -> str:
  53        raise NotImplementedError
  54from array import array
  55from typing import Iterable, Final, Sequence
  56
  57titan_pylib_AVLTreeBitVector_W: Final[int] = 31
  58
  59
  60class AVLTreeBitVector(BitVectorInterface):
  61    """AVL木で書かれたビットベクトルです。簡潔でもなんでもありません。
  62
  63    bit列を管理するわけですが、各節点は 1~32 bit を持つようにしています。
  64    これにより、最大 32 倍高速化が行えます。(16~32bitとするといいんだろうけど)
  65    """
  66
  67    @staticmethod
  68    def _popcount(x: int) -> int:
  69        x = x - ((x >> 1) & 0x55555555)
  70        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
  71        x = x + (x >> 4) & 0x0F0F0F0F
  72        x += x >> 8
  73        x += x >> 16
  74        return x & 0x0000007F
  75
  76    def __init__(self, a: Iterable[int] = []):
  77        """
  78        :math:`O(n)` です。
  79
  80        Args:
  81          a (Iterable[int], optional): 構築元の配列です。
  82        """
  83        self.root = 0
  84        self.bit_len = array("B", bytes(1))
  85        self.key = array("I", bytes(4))
  86        self.size = array("I", bytes(4))
  87        self.total = array("I", bytes(4))
  88        self.left = array("I", bytes(4))
  89        self.right = array("I", bytes(4))
  90        self.balance = array("b", bytes(1))
  91        self.end = 1
  92        if a:
  93            self._build(a)
  94
  95    def reserve(self, n: int) -> None:
  96        """``n`` 要素分のメモリを確保します。
  97        :math:`O(n)` です。
  98        """
  99        n = n // titan_pylib_AVLTreeBitVector_W + 1
 100        a = array("I", bytes(4 * n))
 101        self.bit_len += array("B", bytes(n))
 102        self.key += a
 103        self.size += a
 104        self.total += a
 105        self.left += a
 106        self.right += a
 107        self.balance += array("b", bytes(n))
 108
 109    def _build(self, a: Iterable[int]) -> None:
 110        key, bit_len, left, right, size, balance, total = (
 111            self.key,
 112            self.bit_len,
 113            self.left,
 114            self.right,
 115            self.size,
 116            self.balance,
 117            self.total,
 118        )
 119        _popcount = AVLTreeBitVector._popcount
 120
 121        def rec(lr: int) -> int:
 122            l, r = lr >> bit, lr & msk
 123            mid = (l + r) >> 1
 124            hl, hr = 0, 0
 125            if l != mid:
 126                le = rec(l << bit | mid)
 127                left[mid], hl = le >> bit, le & msk
 128                size[mid] += size[left[mid]]
 129                total[mid] += total[left[mid]]
 130            if mid + 1 != r:
 131                ri = rec((mid + 1) << bit | r)
 132                right[mid], hr = ri >> bit, ri & msk
 133                size[mid] += size[right[mid]]
 134                total[mid] += total[right[mid]]
 135            balance[mid] = hl - hr
 136            return mid << bit | (max(hl, hr) + 1)
 137
 138        if not isinstance(a, Sequence):
 139            a = list(a)
 140        n = len(a)
 141        bit = n.bit_length() + 2
 142        msk = (1 << bit) - 1
 143        end = self.end
 144        self.reserve(n)
 145        i = 0
 146        indx = end
 147        for i in range(0, n, titan_pylib_AVLTreeBitVector_W):
 148            j = 0
 149            v = 0
 150            while j < titan_pylib_AVLTreeBitVector_W and i + j < n:
 151                v <<= 1
 152                v |= a[i + j]
 153                j += 1
 154            key[indx] = v
 155            bit_len[indx] = j
 156            size[indx] = j
 157            total[indx] = _popcount(v)
 158            indx += 1
 159        self.end = indx
 160        self.root = rec(end << bit | self.end) >> bit
 161
 162    def _rotate_L(self, node: int) -> int:
 163        left, right, size, balance, total = (
 164            self.left,
 165            self.right,
 166            self.size,
 167            self.balance,
 168            self.total,
 169        )
 170        u = left[node]
 171        size[u] = size[node]
 172        total[u] = total[node]
 173        size[node] -= size[left[u]] + self.bit_len[u]
 174        total[node] -= total[left[u]] + AVLTreeBitVector._popcount(self.key[u])
 175        left[node] = right[u]
 176        right[u] = node
 177        if balance[u] == 1:
 178            balance[u] = 0
 179            balance[node] = 0
 180        else:
 181            balance[u] = -1
 182            balance[node] = 1
 183        return u
 184
 185    def _rotate_R(self, node: int) -> int:
 186        left, right, size, balance, total = (
 187            self.left,
 188            self.right,
 189            self.size,
 190            self.balance,
 191            self.total,
 192        )
 193        u = right[node]
 194        size[u] = size[node]
 195        total[u] = total[node]
 196        size[node] -= size[right[u]] + self.bit_len[u]
 197        total[node] -= total[right[u]] + AVLTreeBitVector._popcount(self.key[u])
 198        right[node] = left[u]
 199        left[u] = node
 200        if balance[u] == -1:
 201            balance[u] = 0
 202            balance[node] = 0
 203        else:
 204            balance[u] = 1
 205            balance[node] = -1
 206        return u
 207
 208    def _update_balance(self, node: int) -> None:
 209        balance = self.balance
 210        if balance[node] == 1:
 211            balance[self.right[node]] = -1
 212            balance[self.left[node]] = 0
 213        elif balance[node] == -1:
 214            balance[self.right[node]] = 0
 215            balance[self.left[node]] = 1
 216        else:
 217            balance[self.right[node]] = 0
 218            balance[self.left[node]] = 0
 219        balance[node] = 0
 220
 221    def _rotate_LR(self, node: int) -> int:
 222        left, right, size, total = self.left, self.right, self.size, self.total
 223        B = left[node]
 224        E = right[B]
 225        size[E] = size[node]
 226        size[node] -= size[B] - size[right[E]]
 227        size[B] -= size[right[E]] + self.bit_len[E]
 228        total[E] = total[node]
 229        total[node] -= total[B] - total[right[E]]
 230        total[B] -= total[right[E]] + AVLTreeBitVector._popcount(self.key[E])
 231        right[B] = left[E]
 232        left[E] = B
 233        left[node] = right[E]
 234        right[E] = node
 235        self._update_balance(E)
 236        return E
 237
 238    def _rotate_RL(self, node: int) -> int:
 239        left, right, size, total = self.left, self.right, self.size, self.total
 240        C = right[node]
 241        D = left[C]
 242        size[D] = size[node]
 243        size[node] -= size[C] - size[left[D]]
 244        size[C] -= size[left[D]] + self.bit_len[D]
 245        total[D] = total[node]
 246        total[node] -= total[C] - total[left[D]]
 247        total[C] -= total[left[D]] + AVLTreeBitVector._popcount(self.key[D])
 248        left[C] = right[D]
 249        right[D] = C
 250        right[node] = left[D]
 251        left[D] = node
 252        self._update_balance(D)
 253        return D
 254
 255    def _pref(self, r: int) -> int:
 256        left, right, bit_len, size, key, total = (
 257            self.left,
 258            self.right,
 259            self.bit_len,
 260            self.size,
 261            self.key,
 262            self.total,
 263        )
 264        node = self.root
 265        s = 0
 266        while r > 0:
 267            t = size[left[node]] + bit_len[node]
 268            if t - bit_len[node] < r <= t:
 269                r -= size[left[node]]
 270                s += total[left[node]] + AVLTreeBitVector._popcount(
 271                    key[node] >> (bit_len[node] - r)
 272                )
 273                break
 274            if t > r:
 275                node = left[node]
 276            else:
 277                s += total[left[node]] + AVLTreeBitVector._popcount(key[node])
 278                node = right[node]
 279                r -= t
 280        return s
 281
 282    def _make_node(self, key: int, bit_len: int) -> int:
 283        end = self.end
 284        if end >= len(self.key):
 285            self.key.append(key)
 286            self.bit_len.append(bit_len)
 287            self.size.append(bit_len)
 288            self.total.append(AVLTreeBitVector._popcount(key))
 289            self.left.append(0)
 290            self.right.append(0)
 291            self.balance.append(0)
 292        else:
 293            self.key[end] = key
 294            self.bit_len[end] = bit_len
 295            self.size[end] = bit_len
 296            self.total[end] = AVLTreeBitVector._popcount(key)
 297        self.end += 1
 298        return end
 299
 300    def insert(self, k: int, key: int) -> None:
 301        """``k`` 番目に ``v`` を挿入します。
 302        :math:`O(\\log{n})` です。
 303
 304        Args:
 305          k (int): 挿入位置のインデックスです。
 306          key (int): 挿入する値です。 ``0`` または ``1`` である必要があります。
 307        """
 308        if self.root == 0:
 309            self.root = self._make_node(key, 1)
 310            return
 311        left, right, size, bit_len, balance, keys, total = (
 312            self.left,
 313            self.right,
 314            self.size,
 315            self.bit_len,
 316            self.balance,
 317            self.key,
 318            self.total,
 319        )
 320        node = self.root
 321        path = []
 322        d = 0
 323        while node:
 324            t = size[left[node]] + bit_len[node]
 325            if t - bit_len[node] <= k <= t:
 326                break
 327            d <<= 1
 328            size[node] += 1
 329            total[node] += key
 330            path.append(node)
 331            node = left[node] if t > k else right[node]
 332            if t > k:
 333                d |= 1
 334            else:
 335                k -= t
 336        k -= size[left[node]]
 337        if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
 338            v = keys[node]
 339            bl = bit_len[node] - k
 340            keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
 341            bit_len[node] += 1
 342            size[node] += 1
 343            total[node] += key
 344            return
 345        path.append(node)
 346        size[node] += 1
 347        total[node] += key
 348        v = keys[node]
 349        bl = titan_pylib_AVLTreeBitVector_W - k
 350        v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
 351        left_key = v >> titan_pylib_AVLTreeBitVector_W
 352        left_key_popcount = left_key & 1
 353        keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
 354        node = left[node]
 355        d <<= 1
 356        d |= 1
 357        if not node:
 358            if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
 359                bit_len[path[-1]] += 1
 360                keys[path[-1]] = (keys[path[-1]] << 1) | left_key
 361                return
 362            else:
 363                left[path[-1]] = self._make_node(left_key, 1)
 364        else:
 365            path.append(node)
 366            size[node] += 1
 367            total[node] += left_key_popcount
 368            d <<= 1
 369            while right[node]:
 370                node = right[node]
 371                path.append(node)
 372                size[node] += 1
 373                total[node] += left_key_popcount
 374                d <<= 1
 375            if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
 376                bit_len[node] += 1
 377                keys[node] = (keys[node] << 1) | left_key
 378                return
 379            else:
 380                right[node] = self._make_node(left_key, 1)
 381        new_node = 0
 382        while path:
 383            node = path.pop()
 384            balance[node] += 1 if d & 1 else -1
 385            d >>= 1
 386            if balance[node] == 0:
 387                break
 388            if balance[node] == 2:
 389                new_node = (
 390                    self._rotate_LR(node)
 391                    if balance[left[node]] == -1
 392                    else self._rotate_L(node)
 393                )
 394                break
 395            elif balance[node] == -2:
 396                new_node = (
 397                    self._rotate_RL(node)
 398                    if balance[right[node]] == 1
 399                    else self._rotate_R(node)
 400                )
 401                break
 402        if new_node:
 403            if path:
 404                if d & 1:
 405                    left[path[-1]] = new_node
 406                else:
 407                    right[path[-1]] = new_node
 408            else:
 409                self.root = new_node
 410
 411    def _pop_under(self, path: list[int], d: int, node: int, res: int) -> None:
 412        left, right, size, bit_len, balance, keys, total = (
 413            self.left,
 414            self.right,
 415            self.size,
 416            self.bit_len,
 417            self.balance,
 418            self.key,
 419            self.total,
 420        )
 421        fd, lmax_total, lmax_bit_len = 0, 0, 0
 422        if left[node] and right[node]:
 423            path.append(node)
 424            d <<= 1
 425            d |= 1
 426            lmax = left[node]
 427            while right[lmax]:
 428                path.append(lmax)
 429                d <<= 1
 430                fd <<= 1
 431                fd |= 1
 432                lmax = right[lmax]
 433            lmax_total = AVLTreeBitVector._popcount(keys[lmax])
 434            lmax_bit_len = bit_len[lmax]
 435            keys[node] = keys[lmax]
 436            bit_len[node] = lmax_bit_len
 437            node = lmax
 438        cnode = right[node] if left[node] == 0 else left[node]
 439        if path:
 440            if d & 1:
 441                left[path[-1]] = cnode
 442            else:
 443                right[path[-1]] = cnode
 444        else:
 445            self.root = cnode
 446            return
 447        while path:
 448            new_node = 0
 449            node = path.pop()
 450            balance[node] -= 1 if d & 1 else -1
 451            size[node] -= lmax_bit_len if fd & 1 else 1
 452            total[node] -= lmax_total if fd & 1 else res
 453            d >>= 1
 454            fd >>= 1
 455            if balance[node] == 2:
 456                new_node = (
 457                    self._rotate_LR(node)
 458                    if balance[left[node]] < 0
 459                    else self._rotate_L(node)
 460                )
 461            elif balance[node] == -2:
 462                new_node = (
 463                    self._rotate_RL(node)
 464                    if balance[right[node]] > 0
 465                    else self._rotate_R(node)
 466                )
 467            elif balance[node] != 0:
 468                break
 469            if new_node:
 470                if not path:
 471                    self.root = new_node
 472                    return
 473                if d & 1:
 474                    left[path[-1]] = new_node
 475                else:
 476                    right[path[-1]] = new_node
 477                if balance[new_node] != 0:
 478                    break
 479        while path:
 480            node = path.pop()
 481            size[node] -= lmax_bit_len if fd & 1 else 1
 482            total[node] -= lmax_total if fd & 1 else res
 483            fd >>= 1
 484
 485    def pop(self, k: int) -> int:
 486        """``k`` 番目の要素を削除し、その値を返します。
 487        :math:`O(\\log{n})` です。
 488
 489        Args:
 490          k (int): 削除位置のインデックスです。
 491        """
 492        assert 0 <= k < len(self)
 493        left, right, size = self.left, self.right, self.size
 494        bit_len, keys, total = self.bit_len, self.key, self.total
 495        node = self.root
 496        d = 0
 497        path = []
 498        while node:
 499            t = size[left[node]] + bit_len[node]
 500            if t - bit_len[node] <= k < t:
 501                break
 502            path.append(node)
 503            node = left[node] if t > k else right[node]
 504            d <<= 1
 505            if t > k:
 506                d |= 1
 507            else:
 508                k -= t
 509        k -= size[left[node]]
 510        v = keys[node]
 511        res = v >> (bit_len[node] - k - 1) & 1
 512        if bit_len[node] == 1:
 513            self._pop_under(path, d, node, res)
 514            return res
 515        keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
 516            v & ((1 << (bit_len[node] - k - 1)) - 1)
 517        )
 518        bit_len[node] -= 1
 519        size[node] -= 1
 520        total[node] -= res
 521        for p in path:
 522            size[p] -= 1
 523            total[p] -= res
 524        return res
 525
 526    def set(self, k: int, v: int) -> None:
 527        """``k`` 番目の値を ``v`` に更新します。
 528        :math:`O(\\log{n})` です。
 529
 530        Args:
 531          k (int): 更新位置のインデックスです。
 532          key (int): 更新する値です。 ``0`` または ``1`` である必要があります。
 533        """
 534        self.__setitem__(k, v)
 535
 536    def tolist(self) -> list[int]:
 537        """リストにして返します。
 538        :math:`O(n)` です。
 539        """
 540        left, right, key, bit_len = self.left, self.right, self.key, self.bit_len
 541        a = []
 542        if not self.root:
 543            return a
 544
 545        def rec(node):
 546            if left[node]:
 547                rec(left[node])
 548            for i in range(bit_len[node] - 1, -1, -1):
 549                a.append(key[node] >> i & 1)
 550            if right[node]:
 551                rec(right[node])
 552
 553        rec(self.root)
 554        return a
 555
 556    def _debug_acc(self) -> None:
 557        """デバッグ用のメソッドです。
 558        key,totalをチェックします。
 559        """
 560        left, right = self.left, self.right
 561        key = self.key
 562
 563        def rec(node):
 564            acc = self._popcount(key[node])
 565            if left[node]:
 566                acc += rec(left[node])
 567            if right[node]:
 568                acc += rec(right[node])
 569            if acc != self.total[node]:
 570                # self.debug()
 571                assert False, "acc Error"
 572            return acc
 573
 574        rec(self.root)
 575        print("debug_acc ok.")
 576
 577    def access(self, k: int) -> int:
 578        """``k`` 番目の値を返します。
 579        :math:`O(\\log{n})` です。
 580
 581        Args:
 582          k (int): 取得位置のインデックスです。
 583        """
 584        return self.__getitem__(k)
 585
 586    def rank0(self, r: int) -> int:
 587        """``a[0, r)`` に含まれる ``0`` の個数を返します。
 588        :math:`O(\\log{n})` です。
 589        """
 590        return r - self._pref(r)
 591
 592    def rank1(self, r: int) -> int:
 593        """``a[0, r)`` に含まれる ``1`` の個数を返します。
 594        :math:`O(\\log{n})` です。
 595        """
 596        return self._pref(r)
 597
 598    def rank(self, r: int, v: int) -> int:
 599        """``a[0, r)`` に含まれる ``v`` の個数を返します。
 600        :math:`O(\\log{n})` です。
 601        """
 602        return self.rank1(r) if v else self.rank0(r)
 603
 604    def select0(self, k: int) -> int:
 605        """``k`` 番目の ``0`` のインデックスを返します。
 606        :math:`O(\\log{n}^2)` です。
 607        """
 608        if k < 0 or self.rank0(len(self)) <= k:
 609            return -1
 610        l, r = 0, len(self)
 611        while r - l > 1:
 612            m = (l + r) >> 1
 613            if m - self._pref(m) > k:
 614                r = m
 615            else:
 616                l = m
 617        return l
 618
 619    def select1(self, k: int) -> int:
 620        """``k`` 番目の ``1`` のインデックスを返します。
 621        :math:`O(\\log{n}^2)` です。
 622        """
 623        if k < 0 or self.rank1(len(self)) <= k:
 624            return -1
 625        l, r = 0, len(self)
 626        while r - l > 1:
 627            m = (l + r) >> 1
 628            if self._pref(m) > k:
 629                r = m
 630            else:
 631                l = m
 632        return l
 633
 634    def select(self, k: int, v: int) -> int:
 635        """``k`` 番目の ``v`` のインデックスを返します。
 636        :math:`O(\\log{n}^2)` です。
 637        """
 638        return self.select1(k) if v else self.select0(k)
 639
 640    def _insert_and_rank1(self, k: int, key: int) -> int:
 641        if self.root == 0:
 642            self.root = self._make_node(key, 1)
 643            return 0
 644        left, right, size, bit_len, balance, keys, total = (
 645            self.left,
 646            self.right,
 647            self.size,
 648            self.bit_len,
 649            self.balance,
 650            self.key,
 651            self.total,
 652        )
 653        node = self.root
 654        s = 0
 655        path = []
 656        d = 0
 657        while node:
 658            t = size[left[node]] + bit_len[node]
 659            if t - bit_len[node] <= k <= t:
 660                break
 661            if t <= k:
 662                s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
 663            d <<= 1
 664            size[node] += 1
 665            total[node] += key
 666            path.append(node)
 667            node = left[node] if t > k else right[node]
 668            if t > k:
 669                d |= 1
 670            else:
 671                k -= t
 672        k -= size[left[node]]
 673        s += total[left[node]] + AVLTreeBitVector._popcount(
 674            keys[node] >> (bit_len[node] - k)
 675        )
 676        if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
 677            v = keys[node]
 678            bl = bit_len[node] - k
 679            keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
 680            bit_len[node] += 1
 681            size[node] += 1
 682            total[node] += key
 683            return s
 684        path.append(node)
 685        size[node] += 1
 686        total[node] += key
 687        v = keys[node]
 688        bl = titan_pylib_AVLTreeBitVector_W - k
 689        v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
 690        left_key = v >> titan_pylib_AVLTreeBitVector_W
 691        left_key_popcount = left_key & 1
 692        keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
 693        node = left[node]
 694        d <<= 1
 695        d |= 1
 696        if not node:
 697            if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
 698                bit_len[path[-1]] += 1
 699                keys[path[-1]] = (keys[path[-1]] << 1) | left_key
 700                return s
 701            else:
 702                left[path[-1]] = self._make_node(left_key, 1)
 703        else:
 704            path.append(node)
 705            size[node] += 1
 706            total[node] += left_key_popcount
 707            d <<= 1
 708            while right[node]:
 709                node = right[node]
 710                path.append(node)
 711                size[node] += 1
 712                total[node] += left_key_popcount
 713                d <<= 1
 714            if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
 715                bit_len[node] += 1
 716                keys[node] = (keys[node] << 1) | left_key
 717                return s
 718            else:
 719                right[node] = self._make_node(left_key, 1)
 720        new_node = 0
 721        while path:
 722            node = path.pop()
 723            balance[node] += 1 if d & 1 else -1
 724            d >>= 1
 725            if balance[node] == 0:
 726                break
 727            if balance[node] == 2:
 728                new_node = (
 729                    self._rotate_LR(node)
 730                    if balance[left[node]] == -1
 731                    else self._rotate_L(node)
 732                )
 733                break
 734            elif balance[node] == -2:
 735                new_node = (
 736                    self._rotate_RL(node)
 737                    if balance[right[node]] == 1
 738                    else self._rotate_R(node)
 739                )
 740                break
 741        if new_node:
 742            if path:
 743                if d & 1:
 744                    left[path[-1]] = new_node
 745                else:
 746                    right[path[-1]] = new_node
 747            else:
 748                self.root = new_node
 749        return s
 750
 751    def _access_pop_and_rank1(self, k: int) -> int:
 752        assert 0 <= k < len(self)
 753        left, right, size = self.left, self.right, self.size
 754        bit_len, keys, total = self.bit_len, self.key, self.total
 755        s = 0
 756        node = self.root
 757        d = 0
 758        path = []
 759        while node:
 760            t = size[left[node]] + bit_len[node]
 761            if t - bit_len[node] <= k < t:
 762                break
 763            if t <= k:
 764                s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
 765            path.append(node)
 766            node = left[node] if t > k else right[node]
 767            d <<= 1
 768            if t > k:
 769                d |= 1
 770            else:
 771                k -= t
 772        k -= size[left[node]]
 773        s += total[left[node]] + AVLTreeBitVector._popcount(
 774            keys[node] >> (bit_len[node] - k)
 775        )
 776        v = keys[node]
 777        res = v >> (bit_len[node] - k - 1) & 1
 778        if bit_len[node] == 1:
 779            self._pop_under(path, d, node, res)
 780            return s << 1 | res
 781        keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
 782            v & ((1 << (bit_len[node] - k - 1)) - 1)
 783        )
 784        bit_len[node] -= 1
 785        size[node] -= 1
 786        total[node] -= res
 787        for p in path:
 788            size[p] -= 1
 789            total[p] -= res
 790        return s << 1 | res
 791
 792    def __getitem__(self, k: int) -> int:
 793        """``k`` 番目の要素を返します。
 794        :math:`O(\\log{n})` です。
 795        """
 796        assert 0 <= k < len(self)
 797        left, right, bit_len, size, key = (
 798            self.left,
 799            self.right,
 800            self.bit_len,
 801            self.size,
 802            self.key,
 803        )
 804        node = self.root
 805        while True:
 806            t = size[left[node]] + bit_len[node]
 807            if t - bit_len[node] <= k < t:
 808                k -= size[left[node]]
 809                return key[node] >> (bit_len[node] - k - 1) & 1
 810            if t > k:
 811                node = left[node]
 812            else:
 813                node = right[node]
 814                k -= t
 815
 816    def __setitem__(self, k: int, v: int) -> None:
 817        """``k`` 番目の要素を ``v`` に更新します。
 818        :math:`O(\\log{n})` です。
 819        """
 820        left, right, bit_len, size, key, total = (
 821            self.left,
 822            self.right,
 823            self.bit_len,
 824            self.size,
 825            self.key,
 826            self.total,
 827        )
 828        assert v == 0 or v == 1, "ValueError"
 829        node = self.root
 830        path = []
 831        while True:
 832            t = size[left[node]] + bit_len[node]
 833            path.append(node)
 834            if t - bit_len[node] <= k < t:
 835                k -= size[left[node]]
 836                if v:
 837                    key[node] |= 1 << k
 838                else:
 839                    key[node] &= ~(1 << k)
 840                break
 841            elif t > k:
 842                node = left[node]
 843            else:
 844                node = right[node]
 845                k -= t
 846        while path:
 847            node = path.pop()
 848            total[node] = (
 849                AVLTreeBitVector._popcount(key[node])
 850                + total[left[node]]
 851                + total[right[node]]
 852            )
 853
 854    def __str__(self):
 855        return str(self.tolist())
 856
 857    def __len__(self):
 858        return self.size[self.root]
 859
 860    def __repr__(self):
 861        return f"{self.__class__.__name__}({self})"
 862# from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
 863# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
 864# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
 865#     BitVectorInterface,
 866# )
 867from array import array
 868
 869
 870class BitVector(BitVectorInterface):
 871    """コンパクトな bit vector です。"""
 872
 873    def __init__(self, n: int):
 874        """長さ ``n`` の ``BitVector`` です。
 875
 876        bit を保持するのに ``array[I]`` を使用します。
 877        ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
 878
 879        累積和を保持するのに同様の ``array[I]`` を使用します。
 880        32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
 881        """
 882        assert 0 <= n < 4294967295
 883        self.N = n
 884        self.block_size = (n + 31) >> 5
 885        b = bytes(4 * (self.block_size + 1))
 886        self.bit = array("I", b)
 887        self.acc = array("I", b)
 888
 889    @staticmethod
 890    def _popcount(x: int) -> int:
 891        x = x - ((x >> 1) & 0x55555555)
 892        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 893        x = x + (x >> 4) & 0x0F0F0F0F
 894        x += x >> 8
 895        x += x >> 16
 896        return x & 0x0000007F
 897
 898    def set(self, k: int) -> None:
 899        """``k`` 番目の bit を ``1`` にします。
 900        :math:`O(1)` です。
 901
 902        Args:
 903          k (int): インデックスです。
 904        """
 905        self.bit[k >> 5] |= 1 << (k & 31)
 906
 907    def build(self) -> None:
 908        """構築します。
 909        **これ以降 ``set`` メソッドを使用してはいけません。**
 910        :math:`O(n)` です。
 911        """
 912        acc, bit = self.acc, self.bit
 913        for i in range(self.block_size):
 914            acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
 915
 916    def access(self, k: int) -> int:
 917        """``k`` 番目の bit を返します。
 918        :math:`O(1)` です。
 919        """
 920        return (self.bit[k >> 5] >> (k & 31)) & 1
 921
 922    def __getitem__(self, k: int) -> int:
 923        return (self.bit[k >> 5] >> (k & 31)) & 1
 924
 925    def rank0(self, r: int) -> int:
 926        """``a[0, r)`` に含まれる ``0`` の個数を返します。
 927        :math:`O(1)` です。
 928        """
 929        return r - (
 930            self.acc[r >> 5]
 931            + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
 932        )
 933
 934    def rank1(self, r: int) -> int:
 935        """``a[0, r)`` に含まれる ``1`` の個数を返します。
 936        :math:`O(1)` です。
 937        """
 938        return self.acc[r >> 5] + BitVector._popcount(
 939            self.bit[r >> 5] & ((1 << (r & 31)) - 1)
 940        )
 941
 942    def rank(self, r: int, v: int) -> int:
 943        """``a[0, r)`` に含まれる ``v`` の個数を返します。
 944        :math:`O(1)` です。
 945        """
 946        return self.rank1(r) if v else self.rank0(r)
 947
 948    def select0(self, k: int) -> int:
 949        """``k`` 番目の ``0`` のインデックスを返します。
 950        :math:`O(\\log{n})` です。
 951        """
 952        if k < 0 or self.rank0(self.N) <= k:
 953            return -1
 954        l, r = 0, self.block_size + 1
 955        while r - l > 1:
 956            m = (l + r) >> 1
 957            if m * 32 - self.acc[m] > k:
 958                r = m
 959            else:
 960                l = m
 961        indx = 32 * l
 962        k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
 963        l, r = indx, indx + 32
 964        while r - l > 1:
 965            m = (l + r) >> 1
 966            if self.rank0(m) > k:
 967                r = m
 968            else:
 969                l = m
 970        return l
 971
 972    def select1(self, k: int) -> int:
 973        """``k`` 番目の ``1`` のインデックスを返します。
 974        :math:`O(\\log{n})` です。
 975        """
 976        if k < 0 or self.rank1(self.N) <= k:
 977            return -1
 978        l, r = 0, self.block_size + 1
 979        while r - l > 1:
 980            m = (l + r) >> 1
 981            if self.acc[m] > k:
 982                r = m
 983            else:
 984                l = m
 985        indx = 32 * l
 986        k = k - self.acc[l] + self.rank1(indx)
 987        l, r = indx, indx + 32
 988        while r - l > 1:
 989            m = (l + r) >> 1
 990            if self.rank1(m) > k:
 991                r = m
 992            else:
 993                l = m
 994        return l
 995
 996    def select(self, k: int, v: int) -> int:
 997        """``k`` 番目の ``v`` のインデックスを返します。
 998        :math:`O(\\log{n})` です。
 999        """
1000        return self.select1(k) if v else self.select0(k)
1001
1002    def __len__(self):
1003        return self.N
1004
1005    def __str__(self):
1006        return str([self.access(i) for i in range(self.N)])
1007
1008    def __repr__(self):
1009        return f"{self.__class__.__name__}({self})"
1010from typing import Sequence
1011from heapq import heappush, heappop
1012from array import array
1013
1014
1015class WaveletMatrix:
1016    """``WaveletMatrix`` です。
1017    静的であることに注意してください。
1018
1019    以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。
1020
1021    参考:
1022      `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_
1023      `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_
1024      `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_
1025    """
1026
1027    def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
1028        """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。
1029        :math:`O(n\\log{\\sigma})` です。
1030
1031        Args:
1032            sigma (int): 扱う整数の上限です。
1033            a (Sequence[int], optional): 構築する配列です。
1034        """
1035        self.sigma: int = sigma
1036        self.log: int = (sigma - 1).bit_length()
1037        self.mid: array[int] = array("I", bytes(4 * self.log))
1038        self.size: int = len(a)
1039        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
1040        self._build(a)
1041
1042    def _build(self, a: Sequence[int]) -> None:
1043        # 列 a から wm を構築する
1044        for bit in range(self.log - 1, -1, -1):
1045            # bit目の0/1に応じてvを構築 + aを安定ソート
1046            v = self.v[bit]
1047            zero, one = [], []
1048            for i, e in enumerate(a):
1049                if e >> bit & 1:
1050                    v.set(i)
1051                    one.append(e)
1052                else:
1053                    zero.append(e)
1054            v.build()
1055            self.mid[bit] = len(zero)  # 境界をmid[bit]に保持
1056            a = zero + one
1057
1058    def access(self, k: int) -> int:
1059        """``k`` 番目の値を返します。
1060        :math:`O(\\log{\\sigma})` です。
1061
1062        Args:
1063            k (int): インデックスです。
1064        """
1065        assert (
1066            -self.size <= k < self.size
1067        ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}"
1068        if k < 0:
1069            k += self.size
1070        s = 0  # 答え
1071        for bit in range(self.log - 1, -1, -1):
1072            if self.v[bit].access(k):
1073                # k番目が立ってたら、
1074                # kまでの1とすべての0が次のk
1075                s |= 1 << bit
1076                k = self.v[bit].rank1(k) + self.mid[bit]
1077            else:
1078                # kまでの0が次のk
1079                k = self.v[bit].rank0(k)
1080        return s
1081
1082    def __getitem__(self, k: int) -> int:
1083        assert (
1084            -self.size <= k < self.size
1085        ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}"
1086        return self.access(k)
1087
1088    def rank(self, r: int, x: int) -> int:
1089        """``a[0, r)`` に含まれる ``x`` の個数を返します。
1090        :math:`O(\\log{\\sigma})` です。
1091        """
1092        assert (
1093            0 <= r <= self.size
1094        ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}"
1095        assert (
1096            0 <= x < 1 << self.log
1097        ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}"
1098        l = 0
1099        mid = self.mid
1100        for bit in range(self.log - 1, -1, -1):
1101            # 位置 r より左に x が何個あるか
1102            # x の bit 目で場合分け
1103            if x >> bit & 1:
1104                # 立ってたら、次のl, rは以下
1105                l = self.v[bit].rank1(l) + mid[bit]
1106                r = self.v[bit].rank1(r) + mid[bit]
1107            else:
1108                # そうでなければ次のl, rは以下
1109                l = self.v[bit].rank0(l)
1110                r = self.v[bit].rank0(r)
1111        return r - l
1112
1113    def select(self, k: int, x: int) -> int:
1114        """``k`` 番目の ``v`` のインデックスを返します。
1115        :math:`O(\\log{\\sigma})` です。
1116        """
1117        assert (
1118            0 <= k < self.size
1119        ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}"
1120        assert (
1121            0 <= x < 1 << self.log
1122        ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}"
1123        # x の開始位置 s を探す
1124        s = 0
1125        for bit in range(self.log - 1, -1, -1):
1126            if x >> bit & 1:
1127                s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s)
1128            else:
1129                s = self.v[bit].rank0(s)
1130        s += k  # s から k 進んだ位置が、元の列で何番目か調べる
1131        for bit in range(self.log):
1132            if x >> bit & 1:
1133                s = self.v[bit].select1(s - self.v[bit].rank0(self.size))
1134            else:
1135                s = self.v[bit].select0(s)
1136        return s
1137
1138    def kth_smallest(self, l: int, r: int, k: int) -> int:
1139        """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。
1140        :math:`O(\\log{\\sigma})` です。
1141        """
1142        assert (
1143            0 <= l <= r <= self.size
1144        ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}"
1145        assert (
1146            0 <= k < r - l
1147        ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k"
1148        s = 0
1149        mid = self.mid
1150        for bit in range(self.log - 1, -1, -1):
1151            r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l)
1152            cnt = r0 - l0  # 区間内の 0 の個数
1153            if cnt <= k:  # 0 が k 以下のとき、 k 番目は 1
1154                s |= 1 << bit
1155                k -= cnt
1156                # この 1 が次の bit 列でどこに行くか
1157                l = l - l0 + mid[bit]
1158                r = r - r0 + mid[bit]
1159            else:
1160                # この 0 が次の bit 列でどこに行くか
1161                l = l0
1162                r = r0
1163        return s
1164
1165    quantile = kth_smallest
1166
1167    def kth_largest(self, l: int, r: int, k: int) -> int:
1168        """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。
1169        :math:`O(\\log{\\sigma})` です。
1170        """
1171        assert (
1172            0 <= l <= r <= self.size
1173        ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}"
1174        assert (
1175            0 <= k < r - l
1176        ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k"
1177        return self.kth_smallest(l, r, r - l - k - 1)
1178
1179    def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]:
1180        """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。
1181        :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。
1182
1183        Note:
1184            :math:`\\sigma` が大きい場合、計算量に注意です。
1185
1186        Returns:
1187            list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。
1188        """
1189        assert (
1190            0 <= l <= r <= self.size
1191        ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}"
1192        assert (
1193            0 <= k < r - l
1194        ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k"
1195        # heap[-length, x, l, bit]
1196        hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)]
1197        ans = []
1198        while hq:
1199            length, x, l, bit = heappop(hq)
1200            length = -length
1201            if bit == -1:
1202                ans.append((x, length))
1203                k -= 1
1204                if k == 0:
1205                    break
1206            else:
1207                r = l + length
1208                l0 = self.v[bit].rank0(l)
1209                r0 = self.v[bit].rank0(r)
1210                if l0 < r0:
1211                    heappush(hq, (-(r0 - l0), x, l0, bit - 1))
1212                l1 = self.v[bit].rank1(l) + self.mid[bit]
1213                r1 = self.v[bit].rank1(r) + self.mid[bit]
1214                if l1 < r1:
1215                    heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1))
1216        return ans
1217
1218    def sum(self, l: int, r: int) -> int:
1219        """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。
1220        計算量に注意です。
1221        """
1222        assert False, "Yabai Keisanryo Error"
1223        return sum(k * v for k, v in self.topk(l, r, r - l))
1224
1225    def _range_freq(self, l: int, r: int, x: int) -> int:
1226        """a[l, r) で x 未満の要素の数を返す"""
1227        ans = 0
1228        for bit in range(self.log - 1, -1, -1):
1229            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
1230            if x >> bit & 1:
1231                # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ
1232                ans += r0 - l0
1233                # 1 が次の bit 列でどこに行くか
1234                l += self.mid[bit] - l0
1235                r += self.mid[bit] - r0
1236            else:
1237                # 0 が次の bit 列でどこに行くか
1238                l, r = l0, r0
1239        return ans
1240
1241    def range_freq(self, l: int, r: int, x: int, y: int) -> int:
1242        """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。
1243        :math:`O(\\log{\\sigma})` です。
1244        """
1245        assert (
1246            0 <= l <= r <= self.size
1247        ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})"
1248        assert 0 <= x <= y < self.sigma, f"ValueError"
1249        return self._range_freq(l, r, y) - self._range_freq(l, r, x)
1250
1251    def prev_value(self, l: int, r: int, x: int) -> int:
1252        """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。
1253        :math:`O(\\log{\\sigma})` です。
1254        """
1255        assert (
1256            0 <= l <= r <= self.size
1257        ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})"
1258        return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
1259
1260    def next_value(self, l: int, r: int, x: int) -> int:
1261        """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。
1262        :math:`O(\\log{\\sigma})` です。
1263        """
1264        assert (
1265            0 <= l <= r <= self.size
1266        ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})"
1267        return self.kth_smallest(l, r, self._range_freq(l, r, x))
1268
1269    def range_count(self, l: int, r: int, x: int) -> int:
1270        """``a[l, r)`` に含まれる ``x`` の個数を返します。
1271        ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。
1272        :math:`O(\\log{\\sigma})` です。
1273        """
1274        assert (
1275            0 <= l <= r <= self.size
1276        ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})"
1277        return self.rank(r, x) - self.rank(l, x)
1278
1279    def __len__(self) -> int:
1280        return self.size
1281
1282    def __str__(self) -> str:
1283        return (
1284            f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})"
1285        )
1286
1287    __repr__ = __str__
1288from typing import Sequence
1289from array import array
1290
1291
1292class DynamicWaveletMatrix(WaveletMatrix):
1293    """動的ウェーブレット行列です。
1294
1295    (静的)ウェーブレット行列でできる操作に加えて ``insert / pop / set`` 等ができます。
1296      - ``BitVector`` を平衡二分木にしています(``AVLTreeBitVector``)。あらゆる操作に平衡二分木の log がつきます。これヤバくね
1297
1298    :math:`O(n\\log{(\\sigma)})` です。
1299    """
1300
1301    def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
1302        self.sigma: int = sigma
1303        self.log: int = (sigma - 1).bit_length()
1304        self.v: list[AVLTreeBitVector] = [AVLTreeBitVector()] * self.log
1305        self.mid: array[int] = array("I", bytes(4 * self.log))
1306        self.size: int = len(a)
1307        self._build(a)
1308
1309    def _build(self, a: Sequence[int]) -> None:
1310        v = array("B", bytes(self.size))
1311        for bit in range(self.log - 1, -1, -1):
1312            # bit目の0/1に応じてvを構築 + aを安定ソート
1313            zero, one = [], []
1314            for i, e in enumerate(a):
1315                if e >> bit & 1:
1316                    v[i] = 1
1317                    one.append(e)
1318                else:
1319                    v[i] = 0
1320                    zero.append(e)
1321            self.mid[bit] = len(zero)  # 境界をmid[bit]に保持
1322            self.v[bit] = AVLTreeBitVector(v)
1323            a = zero + one
1324
1325    def reserve(self, n: int) -> None:
1326        """``n`` 要素分のメモリを確保します。
1327        :math:`O(n)` です。
1328        """
1329        assert n >= 0, f"ValueError: {self.__class__.__name__}.reserve({n})"
1330        for v in self.v:
1331            v.reserve(n)
1332
1333    def insert(self, k: int, x: int) -> None:
1334        """位置 ``k`` に ``x`` を挿入します。
1335        :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1336        """
1337        assert (
1338            0 <= k <= self.size
1339        ), f"IndexError: {self.__class__.__name__}.insert({k}, {x}), n={self.size}"
1340        assert (
1341            0 <= x < 1 << self.log
1342        ), f"ValueError: {self.__class__.__name__}.insert({k}, {x}), LIM={1<<self.log}"
1343        mid = self.mid
1344        for bit in range(self.log - 1, -1, -1):
1345            v = self.v[bit]
1346            # if x >> bit & 1:
1347            #   v.insert(k, 1)
1348            #   k = v.rank1(k) + mid[bit]
1349            # else:
1350            #   v.insert(k, 0)
1351            #   mid[bit] += 1
1352            #   k = v.rank0(k)
1353            if x >> bit & 1:
1354                s = v._insert_and_rank1(k, 1)
1355                k = s + mid[bit]
1356            else:
1357                s = v._insert_and_rank1(k, 0)
1358                k -= s
1359                mid[bit] += 1
1360        self.size += 1
1361
1362    def pop(self, k: int) -> int:
1363        """位置 ``k`` の要素を削除し、その値を返します。
1364        :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1365        """
1366        assert (
1367            0 <= k < self.size
1368        ), f"IndexError: {self.__class__.__name__}.pop({k}), n={self.size}"
1369        mid = self.mid
1370        ans = 0
1371        for bit in range(self.log - 1, -1, -1):
1372            v = self.v[bit]
1373            # K = k
1374            # if v.access(k):
1375            #   ans |= 1 << bit
1376            #   k = v.rank1(k) + mid[bit]
1377            # else:
1378            #   mid[bit] -= 1
1379            #   k = v.rank0(k)
1380            # v.pop(K)
1381            sb = v._access_pop_and_rank1(k)
1382            s = sb >> 1
1383            if sb & 1:
1384                ans |= 1 << bit
1385                k = s + mid[bit]
1386            else:
1387                mid[bit] -= 1
1388                k -= s
1389        self.size -= 1
1390        return ans
1391
1392    def set(self, k: int, x: int) -> None:
1393        """位置 ``k`` の要素を ``x`` に更新します。
1394        :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1395        """
1396        assert (
1397            0 <= k < self.size
1398        ), f"IndexError: {self.__class__.__name__}.set({k}, {x}), n={self.size}"
1399        assert (
1400            0 <= x < 1 << self.log
1401        ), f"ValueError: {self.__class__.__name__}.set({k}, {x}), LIM={1<<self.log}"
1402        self.pop(k)
1403        self.insert(k, x)
1404
1405    def __setitem__(self, k: int, x: int):
1406        assert (
1407            0 <= k < self.size
1408        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self.size}"
1409        assert (
1410            0 <= x < 1 << self.log
1411        ), f"ValueError: {self.__class__.__name__}[{k}] = {x}, LIM={1<<self.log}"
1412        self.set(k, x)
1413
1414    def __str__(self):
1415        return f"{self.__class__.__name__}({[self[i] for i in range(self.size)]})"

仕様

class DynamicWaveletMatrix(sigma: int, a: Sequence[int] = [])[source]

Bases: WaveletMatrix

動的ウェーブレット行列です。

(静的)ウェーブレット行列でできる操作に加えて insert / pop / set 等ができます。
  • BitVector を平衡二分木にしています(AVLTreeBitVector)。あらゆる操作に平衡二分木の log がつきます。これヤバくね

\(O(n\log{(\sigma)})\) です。

insert(k: int, x: int) None[source]

位置 kx を挿入します。 \(O(\log{(n)}\log{(\sigma)})\) です。

pop(k: int) int[source]

位置 k の要素を削除し、その値を返します。 \(O(\log{(n)}\log{(\sigma)})\) です。

reserve(n: int) None[source]

n 要素分のメモリを確保します。 \(O(n)\) です。

set(k: int, x: int) None[source]

位置 k の要素を x に更新します。 \(O(\log{(n)}\log{(\sigma)})\) です。