1from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
2from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
3from typing import Sequence
4from array import array
5
6
[docs]
class DynamicWaveletMatrix(WaveletMatrix):
    """Dynamic wavelet matrix.

    In addition to the operations of the (static) wavelet matrix, it supports
    ``insert / pop / set``.

    - Each level's ``BitVector`` is a balanced binary tree (``AVLTreeBitVector``),
      so every operation carries an extra logarithmic factor from the tree.

    Construction is :math:`O(n\\log{(\\sigma)})`.
    """

    def __init__(self, sigma: int, a: Sequence[int] = ()) -> None:
        """Build the matrix over ``a``.

        :param sigma: Value bound; the matrix uses ``(sigma - 1).bit_length()`` levels.
        :param a: Initial elements, each in ``[0, 1 << log)``. Defaults to empty.
        """
        self.sigma: int = sigma
        self.log: int = (sigma - 1).bit_length()
        # Cheap placeholder slots (one shared empty object); _build overwrites
        # every level with its own bit vector.
        self.v: list[AVLTreeBitVector] = [AVLTreeBitVector()] * self.log
        # mid[bit] = number of 0 bits at level ``bit``. Built from a list of
        # zeros so the length does not depend on the platform's itemsize of "I".
        self.mid: array[int] = array("I", [0]) * self.log
        self.size: int = len(a)
        self._build(a)

    def _build(self, a: Sequence[int]) -> None:
        """Fill every level's bit vector and ``mid`` from ``a``, top bit first."""
        lim = 1 << self.log
        assert all(
            0 <= e < lim for e in a
        ), f"ValueError: {self.__class__.__name__}._build(a), LIM={lim}"
        n = len(a)
        for bit in range(self.log - 1, -1, -1):
            # A fresh buffer per level: the bit vector must not share storage
            # with the next level's bits (NOTE(review): whether
            # AVLTreeBitVector copies its input is not visible here).
            bits = array("B", bytes(n))
            zero, one = [], []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    bits[i] = 1
                    one.append(e)
                else:
                    zero.append(e)
            self.mid[bit] = len(zero)  # boundary between the 0-part and 1-part
            self.v[bit] = AVLTreeBitVector(bits)
            # Stable partition by this bit gives the order for the next level.
            a = zero + one

    def reserve(self, n: int) -> None:
        """Reserve memory for ``n`` elements on every level.
        :math:`O(n)`.
        """
        assert n >= 0, f"ValueError: {self.__class__.__name__}.reserve({n})"
        for v in self.v:
            v.reserve(n)

    def insert(self, k: int, x: int) -> None:
        """Insert ``x`` at position ``k``.
        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k <= self.size
        ), f"IndexError: {self.__class__.__name__}.insert({k}, {x}), n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}.insert({k}, {x}), LIM={1<<self.log}"
        mid = self.mid
        for bit in range(self.log - 1, -1, -1):
            v = self.v[bit]
            # _insert_and_rank1(k, b) is a fused form of ``v.insert(k, b)``
            # followed by ``v.rank1(k)`` (count of 1s in positions [0, k)).
            if x >> bit & 1:
                s = v._insert_and_rank1(k, 1)
                # 1s are placed after all mid[bit] zeros on the next level.
                k = s + mid[bit]
            else:
                s = v._insert_and_rank1(k, 0)
                k -= s  # rank0(k) == k - rank1(k)
                mid[bit] += 1  # one more zero at this level
        self.size += 1

    def pop(self, k: int) -> int:
        """Remove the element at position ``k`` and return its value.
        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}.pop({k}), n={self.size}"
        mid = self.mid
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            v = self.v[bit]
            # _access_pop_and_rank1(k) removes position k and returns
            # (rank1(k) << 1) | bit_at_k, where rank1(k) counts 1s in [0, k).
            sb = v._access_pop_and_rank1(k)
            s = sb >> 1
            if sb & 1:
                ans |= 1 << bit
                k = s + mid[bit]
            else:
                mid[bit] -= 1  # one fewer zero at this level
                k -= s  # rank0(k) == k - rank1(k)
        self.size -= 1
        return ans

    def set(self, k: int, x: int) -> None:
        """Update the element at position ``k`` to ``x`` (pop followed by insert).
        :math:`O(\\log{(n)}\\log{(\\sigma)})`.
        """
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}.set({k}, {x}), n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}.set({k}, {x}), LIM={1<<self.log}"
        self.pop(k)
        self.insert(k, x)

    def __setitem__(self, k: int, x: int) -> None:
        """``self[k] = x``; same as :meth:`set`."""
        assert (
            0 <= k < self.size
        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self.size}"
        assert (
            0 <= x < 1 << self.log
        ), f"ValueError: {self.__class__.__name__}[{k}] = {x}, LIM={1<<self.log}"
        self.set(k, x)

    def __str__(self) -> str:
        # Element access goes through the inherited ``__getitem__``.
        return f"{self.__class__.__name__}({[self[i] for i in range(self.size)]})"