splay_tree_bit_vector
ソースコード
from titan_pylib.data_structures.bit_vector.splay_tree_bit_vector import SplayTreeBitVector
展開済みコード
1# from titan_pylib.data_structures.bit_vector.splay_tree_bit_vector import SplayTreeBitVector
2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
3# BitVectorInterface,
4# )
5from abc import ABC, abstractmethod
6
7
class BitVectorInterface(ABC):
    """Abstract contract shared by the bit-vector implementations.

    Every method below is abstract, so a concrete subclass must override
    all of them before it can be instantiated. Calling one of these base
    implementations directly raises ``NotImplementedError``.
    """

    @abstractmethod
    def access(self, k: int) -> int:
        """Return the bit stored at index ``k``."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, k: int) -> int:
        """Return the bit stored at index ``k`` (``self[k]``)."""
        raise NotImplementedError()

    @abstractmethod
    def rank0(self, r: int) -> int:
        """Return the number of 0 bits in ``a[0, r)``."""
        raise NotImplementedError()

    @abstractmethod
    def rank1(self, r: int) -> int:
        """Return the number of 1 bits in ``a[0, r)``."""
        raise NotImplementedError()

    @abstractmethod
    def rank(self, r: int, v: int) -> int:
        """Return the number of bits equal to ``v`` in ``a[0, r)``."""
        raise NotImplementedError()

    @abstractmethod
    def select0(self, k: int) -> int:
        """Return the index of the ``k``-th 0 bit."""
        raise NotImplementedError()

    @abstractmethod
    def select1(self, k: int) -> int:
        """Return the index of the ``k``-th 1 bit."""
        raise NotImplementedError()

    @abstractmethod
    def select(self, k: int, v: int) -> int:
        """Return the index of the ``k``-th bit equal to ``v``."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored bits."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the bits."""
        raise NotImplementedError()

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation of the bits."""
        raise NotImplementedError()
53from typing import Sequence
54from array import array
55
56
class SplayTreeBitVector(BitVectorInterface):
    """Dynamic bit vector backed by a splay tree of packed 32-bit chunks.

    Each tree node stores up to ``w`` (= 32) consecutive bits in ``key``,
    packed MSB-first: the first bit of a node is the highest of its
    ``bit_len`` low bits. Nodes live in parallel ``array`` buffers indexed by
    node id. Id 0 is a sentinel "null" node whose ``size`` and ``total`` stay
    0, so a missing child (id 0) contributes nothing to the aggregates.

    Per-node fields:
      key[v]                   packed bits of node v
      bit_len[v]               number of bits stored in node v
      size[v]                  number of bits in the subtree rooted at v
      total[v]                 number of 1 bits in the subtree rooted at v
      child[2v], child[2v+1]   left / right child ids (0 = none)

    Nodes are never freed; ``end`` is the next unused node id.
    """

    @staticmethod
    def _popcount(x: int) -> int:
        """Return the number of set bits of ``x`` (valid for 0 <= x < 2**32)."""
        x = x - ((x >> 1) & 0x55555555)
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
        x = (x + (x >> 4)) & 0x0F0F0F0F
        x += x >> 8
        x += x >> 16
        return x & 0x0000007F

    def __init__(self, a: Sequence[int] = []):
        """Create a bit vector holding the bits of ``a`` (each expected 0 or 1).

        The default ``[]`` is only read, never mutated, so sharing it is safe.
        """
        self.root = 0
        self.bit_len = array("B", bytes(1))
        self.key = array("I", bytes(4))
        self.size = array("I", bytes(4))
        self.total = array("I", bytes(4))
        self.child = array("I", bytes(8))
        self.end = 1
        self.w = 32
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n // w + 1`` additional node slots (about ``n`` bits)."""
        n = n // self.w + 1
        a = array("I", bytes(4 * n))
        self.bit_len += array("B", bytes(n))
        self.key += a
        self.size += a
        self.total += a
        self.child += array("I", bytes(8 * n))

    def _build(self, a: Sequence[int]) -> None:
        """Pack the bits of ``a`` into fresh nodes and link them as a balanced tree.

        Called from ``__init__``; the resulting tree replaces ``self.root``.
        """
        if not (hasattr(a, "__getitem__") and hasattr(a, "__len__")):
            a = list(a)
        n = len(a)
        if n == 0:
            # e.g. an empty generator (truthy, so __init__ still calls us):
            # without this guard rec(start, start) would recurse forever.
            return
        w = self.w
        start = self.end
        self.reserve(n)
        key, bit_len, child, size, total = (
            self.key,
            self.bit_len,
            self.child,
            self.size,
            self.total,
        )
        _popcount = SplayTreeBitVector._popcount

        def rec(l: int, r: int) -> int:
            # Make node (l + r) // 2 the root of the nodes [l, r); return it.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = rec(l, mid)
                size[mid] += size[child[mid << 1]]
                total[mid] += total[child[mid << 1]]
            if mid + 1 != r:
                child[mid << 1 | 1] = rec(mid + 1, r)
                size[mid] += size[child[mid << 1 | 1]]
                total[mid] += total[child[mid << 1 | 1]]
            return mid

        node = start
        for i in range(0, n, w):
            length = min(w, n - i)
            v = 0
            for j in range(i, i + length):
                v = v << 1 | a[j]  # MSB-first packing
            key[node] = v
            bit_len[node] = length
            size[node] = length  # subtree sums are added by rec()
            total[node] = _popcount(v)
            node += 1
        self.end = node
        self.root = rec(start, node)

    def _make_node(self, key: int, bit_len: int) -> int:
        """Allocate a childless node holding ``key`` / ``bit_len``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.bit_len.append(bit_len)
            self.size.append(bit_len)
            self.total.append(SplayTreeBitVector._popcount(key))
            self.child.append(0)
            self.child.append(0)
        else:
            # Slots past ``end`` were zero-filled by reserve() and never used,
            # so their child links are already 0.
            self.key[end] = key
            self.bit_len[end] = bit_len
            self.size[end] = bit_len
            self.total[end] = SplayTreeBitVector._popcount(key)
        self.end += 1
        return end

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Fix aggregates after a double rotation that lifted ``z`` above ``x``, ``y``.

        ``z`` takes over ``x``'s old subtree totals; then ``x`` (which may be
        ``y``'s child) and ``y`` are recomputed from their children, in that order.
        """
        child, bit_len, size, total = self.child, self.bit_len, self.size, self.total
        lx, rx = child[x << 1], child[x << 1 | 1]
        ly, ry = child[y << 1], child[y << 1 | 1]
        size[z] = size[x]
        size[x] = bit_len[x] + size[lx] + size[rx]
        size[y] = bit_len[y] + size[ly] + size[ry]
        total[z] = total[x]
        total[x] = total[lx] + SplayTreeBitVector._popcount(self.key[x]) + total[rx]
        total[y] = total[ly] + SplayTreeBitVector._popcount(self.key[y]) + total[ry]

    def _update_double(self, x: int, y: int) -> None:
        """Fix aggregates after a single rotation that lifted ``y`` above ``x``."""
        child, bit_len, size, total = self.child, self.bit_len, self.size, self.total
        lx, rx = child[x << 1], child[x << 1 | 1]
        size[y] = size[x]
        size[x] = bit_len[x] + size[lx] + size[rx]
        total[y] = total[x]
        total[x] = total[lx] + SplayTreeBitVector._popcount(self.key[x]) + total[rx]

    def _update(self, node: int) -> None:
        """Recompute ``size`` / ``total`` of ``node`` from its own data and children."""
        lnode, rnode = self.child[node << 1], self.child[node << 1 | 1]
        self.size[node] = self.bit_len[node] + self.size[lnode] + self.size[rnode]
        self.total[node] = (
            SplayTreeBitVector._popcount(self.key[node])
            + self.total[lnode]
            + self.total[rnode]
        )

    def _splay(self, path: list[int], d: int) -> None:
        """Rotate the node below ``path[-1]`` up to the top of ``path[0]``'s subtree.

        ``path`` lists the ancestors (top first) and is consumed. Bit i of ``d``
        (LSB = deepest step) is 1 if that step went to the left child. The
        caller must re-link the lifted node (it is the ``node`` returned by the
        ``*_splay`` helpers).
        """
        child = self.child
        g = d & 1
        while len(path) > 1:
            # zig-zig (g == f) or zig-zag (g != f) step.
            pnode = path.pop()
            gnode = path.pop()
            f = d >> 1 & 1
            node = child[pnode << 1 | g ^ 1]
            nnode = (pnode if g == f else node) << 1 | f
            child[pnode << 1 | g ^ 1] = child[node << 1 | g]
            child[node << 1 | g] = pnode
            child[gnode << 1 | f ^ 1] = child[nnode]
            child[nnode] = gnode
            self._update_triple(gnode, pnode, node)
            if not path:
                return
            d >>= 2
            g = d & 1
            child[path[-1] << 1 | g ^ 1] = node
        # Final single (zig) rotation.
        pnode = path.pop()
        node = child[pnode << 1 | g ^ 1]
        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
        child[node << 1 | g] = pnode
        self._update_double(pnode, node)

    def _kth_elm_splay(self, node: int, k: int) -> int:
        """Splay the node containing bit position ``k`` (0 <= k < size) and return it."""
        child, bit_len, size = self.child, self.bit_len, self.size
        d = 0
        path: list[int] = []
        while True:
            t = size[child[node << 1]] + bit_len[node]
            if t - bit_len[node] <= k < t:
                if path:
                    self._splay(path, d)
                return node
            d = d << 1 | (t > k)
            path.append(node)
            node = child[node << 1 | (t <= k)]
            if t <= k:
                k -= t

    def _left_splay(self, node: int) -> int:
        """Splay the leftmost node of ``node``'s subtree to its top and return it."""
        if not node:
            return 0
        child = self.child
        if not child[node << 1]:
            return node
        path: list[int] = []
        while child[node << 1]:
            path.append(node)
            node = child[node << 1]
        self._splay(path, (1 << len(path)) - 1)
        return node

    def _right_splay(self, node: int) -> int:
        """Splay the rightmost node of ``node``'s subtree to its top and return it."""
        if not node:
            return 0
        child = self.child
        if not child[node << 1 | 1]:
            return node
        path: list[int] = []
        while child[node << 1 | 1]:
            path.append(node)
            node = child[node << 1 | 1]
        self._splay(path, 0)
        return node

    def insert(self, k: int, key: int) -> None:
        """Insert bit ``key`` so that it ends up at index ``k`` (0 <= k <= len)."""
        assert (
            0 <= k <= len(self)
        ), f"IndexError: SplayTreeBitVector.insert({k}, {key}), len={len(self)}"
        if not self.root:
            node = self._make_node(key, 1)
            self.root = node
            return
        bit_len, child, size, keys, total = (
            self.bit_len,
            self.child,
            self.size,
            self.key,
            self.total,
        )
        if k == size[self.root]:
            # Append: bring the rightmost node to the root.
            node = self._right_splay(self.root)
            if bit_len[node] == self.w:
                # Full: put the bit in a new 1-bit root whose left child is node.
                v = keys[node] << 1 | key
                new_node = self._make_node(v & 1, 1)
                keys[node] = v >> 1
                child[new_node << 1] = node
                self._update(node)
                size[new_node] += size[node]
                total[new_node] += total[node]
                self.root = new_node
            else:
                v = keys[node]
                bl = k - bit_len[node] - size[child[node << 1]]
                keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
                bit_len[node] += 1
                size[node] += 1
                total[node] += key
                self.root = node
        else:
            node = self._kth_elm_splay(self.root, k)
            if bit_len[node] == self.w:
                # Full: insert into a (w+1)-bit value, then move its first bit
                # into a new node placed just before ``node``.
                k -= size[child[node << 1]]
                v = keys[node]
                bl = bit_len[node] - k
                v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
                new_node = self._make_node(v >> self.w, 1)
                keys[node] = v & ((1 << self.w) - 1)
                self._update(node)
                if child[node << 1]:
                    child[new_node << 1] = child[node << 1]
                    child[node << 1] = 0
                    self._update(node)
                child[new_node << 1 | 1] = node
                self._update(new_node)
                self.root = new_node
            else:
                v = keys[node]
                bl = bit_len[node] - k + size[child[node << 1]]
                keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
                bit_len[node] += 1
                size[node] += 1
                total[node] += key
                self.root = node

    def pop(self, k: int = -1) -> int:
        """Remove and return the bit at index ``k``.

        Negative ``k`` counts from the end, like ``list.pop`` (so the default
        ``-1`` removes the last bit). Out-of-range ``k`` fails the assertion.
        """
        n = len(self)
        i = k + n if k < 0 else k
        assert 0 <= i < n, f"IndexError: SplayTreeBitVector.pop({k})"
        root = self._kth_elm_splay(self.root, i)
        size, child, key, bit_len, total = (
            self.size,
            self.child,
            self.key,
            self.bit_len,
            self.total,
        )
        i -= size[child[root << 1]]
        v = key[root]
        bl = bit_len[root]
        res = v >> (bl - i - 1) & 1
        if bl == 1:
            # The node becomes empty: unlink it and join its subtrees.
            if not child[root << 1]:
                self.root = child[root << 1 | 1]
            elif not child[root << 1 | 1]:
                self.root = child[root << 1]
            else:
                node = self._right_splay(child[root << 1])
                child[node << 1 | 1] = child[root << 1 | 1]
                self._update(node)
                self.root = node
        else:
            low = bl - i - 1  # number of bits after the removed one
            key[root] = ((v >> (low + 1)) << low) | (v & ((1 << low) - 1))
            bit_len[root] -= 1
            size[root] -= 1
            total[root] -= res
            self.root = root
        return res

    def _pref(self, r: int) -> int:
        """Return the number of 1 bits in ``a[0, r)`` (0 <= r <= len)."""
        assert (
            0 <= r <= len(self)
        ), f"IndexError: SplayTreeBitVector._pref({r}), len={len(self)}"
        if r == 0:
            return 0
        if r == len(self):
            return self.total[self.root]
        self.root = self._kth_elm_splay(self.root, r - 1)
        r -= self.size[self.child[self.root << 1]]
        # Subtract the ones in the node's bits after position r and in the
        # right subtree from the whole-tree total.
        return (
            self.total[self.root]
            - SplayTreeBitVector._popcount(
                self.key[self.root] & ((1 << (self.bit_len[self.root] - r)) - 1)
            )
            - self.total[self.child[self.root << 1 | 1]]
        )

    def __getitem__(self, k: int) -> int:
        """Return the bit at index ``k``; negative ``k`` counts from the end."""
        n = len(self)
        i = k + n if k < 0 else k
        assert 0 <= i < n, f"IndexError: SplayTreeBitVector.__getitem__({k})"
        self.root = self._kth_elm_splay(self.root, i)
        i -= self.size[self.child[self.root << 1]]
        return (self.key[self.root] >> (self.bit_len[self.root] - i - 1)) & 1

    def debug(self):
        """Print the raw node buffers."""
        print("### debug")
        print(f"{self.root=}")
        print(f"{self.key=}")
        print(f"{self.bit_len=}")
        print(f"{self.size=}")
        print(f"{self.total=}")
        print(f"{self.child=}")

    def __len__(self):
        """Return the number of stored bits (``size`` of the root; 0 if empty)."""
        return self.size[self.root]

    def __iter__(self):
        """Iterate over the bits in order.

        Without this, iteration falls back to ``__getitem__``, whose
        out-of-range assertion would abort every loop at its end.
        """
        return iter(self.tolist())

    def tolist(self) -> list[int]:
        """Return all bits in order as a list.

        Uses an explicit stack: sequential appends build a left chain about
        len/32 nodes deep, which overflows Python's recursion limit.
        """
        child, key, bit_len = self.child, self.key, self.bit_len
        res: list[int] = []
        stack: list[int] = []
        node = self.root
        while stack or node:
            while node:
                stack.append(node)
                node = child[node << 1]
            node = stack.pop()
            v = key[node]
            res.extend(v >> i & 1 for i in range(bit_len[node] - 1, -1, -1))
            node = child[node << 1 | 1]
        return res

    def __str__(self):
        return str(self.tolist())

    __repr__ = __str__

    def debug_acc(self) -> None:
        """Check that every ``total[v]`` equals the popcount of v's subtree.

        Fails with ``AssertionError("acc Error")`` on a mismatch (when
        assertions are enabled). Iterative so deep trees do not overflow
        the recursion limit.
        """
        child, key, total = self.child, self.key, self.total
        order: list[int] = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            order.append(node)
            for c in (child[node << 1], child[node << 1 | 1]):
                if c:
                    stack.append(c)
        acc: dict[int, int] = {}
        # Reversed pre-order visits every child before its parent.
        for node in reversed(order):
            s = (
                self._popcount(key[node])
                + acc.get(child[node << 1], 0)
                + acc.get(child[node << 1 | 1], 0)
            )
            if s != total[node]:
                assert False, "acc Error"
            acc[node] = s

    def access(self, k: int) -> int:
        """Return the bit at index ``k`` (same as ``self[k]``)."""
        return self.__getitem__(k)

    def rank0(self, r: int) -> int:
        """Return the number of 0 bits in a[0, r)."""
        return r - self._pref(r)

    def rank1(self, r: int) -> int:
        """Return the number of 1 bits in a[0, r)."""
        return self._pref(r)

    def rank(self, r: int, v: int) -> int:
        """Return the number of bits equal to ``v`` in a[0, r)."""
        return self.rank1(r) if v else self.rank0(r)

    def select0(self, k: int) -> int:
        """Return the index of the k-th (0-indexed) 0 bit, or -1 if none.

        Binary search over ``_pref``: O(log^2 N) amortized.
        """
        if k < 0 or self.rank0(len(self)) <= k:
            return -1
        l, r = 0, len(self)
        while r - l > 1:
            m = (l + r) >> 1
            if m - self._pref(m) > k:
                r = m
            else:
                l = m
        return l

    def select1(self, k: int) -> int:
        """Return the index of the k-th (0-indexed) 1 bit, or -1 if none.

        Binary search over ``_pref``: O(log^2 N) amortized.
        """
        if k < 0 or self.rank1(len(self)) <= k:
            return -1
        l, r = 0, len(self)
        while r - l > 1:
            m = (l + r) >> 1
            if self._pref(m) > k:
                r = m
            else:
                l = m
        return l

    def select(self, k: int, v: int) -> int:
        """Return the index of the k-th bit equal to ``v``, or -1 if none."""
        return self.select1(k) if v else self.select0(k)
仕様
- class SplayTreeBitVector(a: Sequence[int] = []) [source]
  Bases: BitVectorInterface