splay_tree_bit_vector

ソースコード

from titan_pylib.data_structures.bit_vector.splay_tree_bit_vector import SplayTreeBitVector

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.bit_vector.splay_tree_bit_vector import SplayTreeBitVector
  2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  3#     BitVectorInterface,
  4# )
  5from abc import ABC, abstractmethod
  6
  7
  8class BitVectorInterface(ABC):
  9
 10    @abstractmethod
 11    def access(self, k: int) -> int:
 12        raise NotImplementedError
 13
 14    @abstractmethod
 15    def __getitem__(self, k: int) -> int:
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def rank0(self, r: int) -> int:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def rank1(self, r: int) -> int:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def rank(self, r: int, v: int) -> int:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def select0(self, k: int) -> int:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def select1(self, k: int) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def select(self, k: int, v: int) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def __len__(self) -> int:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __str__(self) -> str:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __repr__(self) -> str:
 52        raise NotImplementedError
 53from typing import Sequence
 54from array import array
 55
 56
 57class SplayTreeBitVector(BitVectorInterface):
 58
 59    @staticmethod
 60    def _popcount(x: int) -> int:
 61        x = x - ((x >> 1) & 0x55555555)
 62        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 63        x = x + (x >> 4) & 0x0F0F0F0F
 64        x += x >> 8
 65        x += x >> 16
 66        return x & 0x0000007F
 67
 68    def __init__(self, a: Sequence[int] = []):
 69        self.root = 0
 70        self.bit_len = array("B", bytes(1))
 71        self.key = array("I", bytes(4))
 72        self.size = array("I", bytes(4))
 73        self.total = array("I", bytes(4))
 74        self.child = array("I", bytes(8))
 75        self.end = 1
 76        self.w = 32
 77        if a:
 78            self._build(a)
 79
 80    def reserve(self, n: int) -> None:
 81        n = n // self.w + 1
 82        a = array("I", bytes(4 * n))
 83        self.bit_len += array("B", bytes(n))
 84        self.key += a
 85        self.size += a
 86        self.total += a
 87        self.child += array("I", bytes(8 * n))
 88
 89    def _build(self, a: Sequence[int]) -> None:
 90        key, bit_len, child, size, total = (
 91            self.key,
 92            self.bit_len,
 93            self.child,
 94            self.size,
 95            self.total,
 96        )
 97        _popcount = SplayTreeBitVector._popcount
 98
 99        def rec(l: int, r: int) -> int:
100            mid = (l + r) >> 1
101            if l != mid:
102                child[mid << 1] = rec(l, mid)
103                size[mid] += size[child[mid << 1]]
104                total[mid] += total[child[mid << 1]]
105            if mid + 1 != r:
106                child[mid << 1 | 1] = rec(mid + 1, r)
107                size[mid] += size[child[mid << 1 | 1]]
108                total[mid] += total[child[mid << 1 | 1]]
109            return mid
110
111        if not (hasattr(a, "__getitem__") and hasattr(a, "__len__")):
112            a = list(a)
113        n = len(a)
114        end = self.end
115        self.reserve(n)
116        i = 0
117        indx = end
118        for i in range(0, n, self.w):
119            j = 0
120            v = 0
121            while j < self.w and i + j < n:
122                v <<= 1
123                v |= a[i + j]
124                j += 1
125            key[indx] = v
126            bit_len[indx] = j
127            size[indx] = j
128            total[indx] = _popcount(v)
129            indx += 1
130        self.end = indx
131        self.root = rec(end, self.end)
132
133    def _make_node(self, key: int, bit_len: int) -> int:
134        end = self.end
135        if end >= len(self.key):
136            self.key.append(key)
137            self.bit_len.append(bit_len)
138            self.size.append(bit_len)
139            self.total.append(SplayTreeBitVector._popcount(key))
140            self.child.append(0)
141            self.child.append(0)
142        else:
143            self.key[end] = key
144            self.bit_len[end] = bit_len
145            self.size[end] = bit_len
146            self.total[end] = SplayTreeBitVector._popcount(key)
147        self.end += 1
148        return end
149
150    def _update_triple(self, x: int, y: int, z: int) -> None:
151        child, bit_len, size, total = self.child, self.bit_len, self.size, self.total
152        lx, rx = child[x << 1], child[x << 1 | 1]
153        ly, ry = child[y << 1], child[y << 1 | 1]
154        size[z] = size[x]
155        size[x] = bit_len[x] + size[lx] + size[rx]
156        size[y] = bit_len[y] + size[ly] + size[ry]
157        total[z] = total[x]
158        total[x] = total[lx] + SplayTreeBitVector._popcount(self.key[x]) + total[rx]
159        total[y] = total[ly] + SplayTreeBitVector._popcount(self.key[y]) + total[ry]
160
161    def _update_double(self, x: int, y: int) -> None:
162        child, bit_len, size, total = self.child, self.bit_len, self.size, self.total
163        lx, rx = child[x << 1], child[x << 1 | 1]
164        size[y] = size[x]
165        size[x] = bit_len[x] + size[lx] + size[rx]
166        total[y] = total[x]
167        total[x] = total[lx] + SplayTreeBitVector._popcount(self.key[x]) + total[rx]
168
169    def _update(self, node: int) -> None:
170        lnode, rnode = self.child[node << 1], self.child[node << 1 | 1]
171        self.size[node] = self.bit_len[node] + self.size[lnode] + self.size[rnode]
172        self.total[node] = (
173            SplayTreeBitVector._popcount(self.key[node])
174            + self.total[lnode]
175            + self.total[rnode]
176        )
177
178    def _splay(self, path: list[int], d: int) -> None:
179        child = self.child
180        g = d & 1
181        while len(path) > 1:
182            pnode = path.pop()
183            gnode = path.pop()
184            f = d >> 1 & 1
185            node = child[pnode << 1 | g ^ 1]
186            nnode = (pnode if g == f else node) << 1 | f
187            child[pnode << 1 | g ^ 1] = child[node << 1 | g]
188            child[node << 1 | g] = pnode
189            child[gnode << 1 | f ^ 1] = child[nnode]
190            child[nnode] = gnode
191            self._update_triple(gnode, pnode, node)
192            if not path:
193                return
194            d >>= 2
195            g = d & 1
196            child[path[-1] << 1 | g ^ 1] = node
197        pnode = path.pop()
198        node = child[pnode << 1 | g ^ 1]
199        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
200        child[node << 1 | g] = pnode
201        self._update_double(pnode, node)
202
203    def _kth_elm_splay(self, node: int, k: int) -> int:
204        child, bit_len, size = self.child, self.bit_len, self.size
205        d = 0
206        path = []
207        while True:
208            t = size[child[node << 1]] + bit_len[node]
209            if t - bit_len[node] <= k < t:
210                if path:
211                    self._splay(path, d)
212                return node
213            d = d << 1 | (t > k)
214            path.append(node)
215            node = child[node << 1 | (t <= k)]
216            if t <= k:
217                k -= t
218
219    def _left_splay(self, node: int) -> int:
220        if not node:
221            return 0
222        child = self.child
223        if not child[node << 1]:
224            return node
225        path = []
226        while child[node << 1]:
227            path.append(node)
228            node = child[node << 1]
229        self._splay(path, (1 << len(path)) - 1)
230        return node
231
232    def _right_splay(self, node: int) -> int:
233        if not node:
234            return 0
235        child = self.child
236        if not child[node << 1 | 1]:
237            return node
238        path = []
239        while child[node << 1 | 1]:
240            path.append(node)
241            node = child[node << 1 | 1]
242        self._splay(path, 0)
243        return node
244
245    def insert(self, k: int, key: int) -> None:
246        assert (
247            0 <= k <= len(self)
248        ), f"IndexError: SplayTreeBitVector.insert({k}, {key}), len={len(self)}"
249        if not self.root:
250            node = self._make_node(key, 1)
251            self.root = node
252            return
253        bit_len, child, size, keys, total = (
254            self.bit_len,
255            self.child,
256            self.size,
257            self.key,
258            self.total,
259        )
260        if k == size[self.root]:
261            node = self._right_splay(self.root)
262            if bit_len[node] == self.w:
263                v = keys[node] << 1 | key
264                new_node = self._make_node(v & 1, 1)
265                keys[node] = v >> 1
266                child[new_node << 1] = node
267                self._update(node)
268                size[new_node] += size[node]
269                total[new_node] += total[node]
270                self.root = new_node
271            else:
272                v = keys[node]
273                bl = k - bit_len[node] - size[child[node << 1]]
274                keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
275                bit_len[node] += 1
276                size[node] += 1
277                total[node] += key
278                self.root = node
279        else:
280            node = self._kth_elm_splay(self.root, k)
281            if bit_len[node] == self.w:
282                k -= size[child[node << 1]]
283                v = keys[node]
284                bl = bit_len[node] - k
285                v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
286                new_node = self._make_node(v >> self.w, 1)
287                keys[node] = v & ((1 << self.w) - 1)
288                self._update(node)
289                if child[node << 1]:
290                    child[new_node << 1] = child[node << 1]
291                    child[node << 1] = 0
292                    self._update(node)
293                child[new_node << 1 | 1] = node
294                self._update(new_node)
295                self.root = new_node
296            else:
297                v = keys[node]
298                bl = bit_len[node] - k + size[child[node << 1]]
299                keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
300                bit_len[node] += 1
301                size[node] += 1
302                total[node] += key
303                self.root = node
304
305    def pop(self, k: int = -1) -> int:
306        assert 0 <= k < len(self), f"IndexError: SplayTreeBitVector.pop({k})"
307        root = self._kth_elm_splay(self.root, k)
308        size, child, key, bit_len, total = (
309            self.size,
310            self.child,
311            self.key,
312            self.bit_len,
313            self.total,
314        )
315        k -= size[child[root << 1]]
316        v = key[root]
317        res = v >> (bit_len[root] - k - 1) & 1
318        if bit_len[root] == 1:
319            if not child[root << 1]:
320                self.root = child[root << 1 | 1]
321            elif not child[root << 1 | 1]:
322                self.root = child[root << 1]
323            else:
324                node = self._right_splay(child[root << 1])
325                child[node << 1 | 1] = child[root << 1 | 1]
326                self._update(node)
327                self.root = node
328        else:
329            key[root] = ((v >> (bit_len[root] - k)) << ((bit_len[root] - k - 1))) | (
330                v & ((1 << (bit_len[root] - k - 1)) - 1)
331            )
332            bit_len[root] -= 1
333            size[root] -= 1
334            total[root] -= res
335            self.root = root
336        return res
337
338    def _pref(self, r: int) -> int:
339        assert (
340            0 <= r <= len(self)
341        ), f"IndexError: SplayTreeBitVector._pref({r}), len={len(self)}"
342        if r == 0:
343            return 0
344        if r == len(self):
345            return self.total[self.root]
346        self.root = self._kth_elm_splay(self.root, r - 1)
347        r -= self.size[self.child[self.root << 1]]
348        return (
349            self.total[self.root]
350            - SplayTreeBitVector._popcount(
351                self.key[self.root] & ((1 << (self.bit_len[self.root] - r)) - 1)
352            )
353            - self.total[self.child[self.root << 1 | 1]]
354        )
355
356    def __getitem__(self, k: int) -> int:
357        assert 0 <= k < len(self), f"IndexError: SplayTreeBitVector.__getitem__({k})"
358        self.root = self._kth_elm_splay(self.root, k)
359        k -= self.size[self.child[self.root << 1]]
360        return (self.key[self.root] >> (self.bit_len[self.root] - k - 1)) & 1
361
362    def debug(self):
363        print("### debug")
364        print(f"{self.root=}")
365        print(f"{self.key=}")
366        print(f"{self.bit_len=}")
367        print(f"{self.size=}")
368        print(f"{self.total=}")
369        print(f"{self.child=}")
370
371    def __len__(self):
372        return self.size[self.root]
373
374    def tolist(self) -> list[int]:
375        child, key, bit_len = self.child, self.key, self.bit_len
376        a = []
377        if not self.root:
378            return a
379
380        def rec(node):
381            if child[node << 1]:
382                rec(child[node << 1])
383            for i in range(bit_len[node] - 1, -1, -1):
384                a.append(key[node] >> i & 1)
385            if child[node << 1 | 1]:
386                rec(child[node << 1 | 1])
387
388        rec(self.root)
389        return a
390
391    def __str__(self):
392        return str(self.tolist())
393
394    __repr__ = __str__
395
396    def debug_acc(self) -> None:
397        child = self.child
398        key = self.key
399
400        def rec(node):
401            acc = self._popcount(key[node])
402            if child[node << 1]:
403                acc += rec(child[node << 1])
404            if child[node << 1 | 1]:
405                acc += rec(child[node << 1 | 1])
406            if acc != self.total[node]:
407                # self.debug()
408                assert False, "acc Error"
409            return acc
410
411        rec(self.root)
412
413    def access(self, k: int) -> int:
414        return self.__getitem__(k)
415
416    def rank0(self, r: int) -> int:
417        # a[0, r) に含まれる 0 の個数
418        return r - self._pref(r)
419
420    def rank1(self, r: int) -> int:
421        # a[0, r) に含まれる 1 の個数
422        return self._pref(r)
423
424    def rank(self, r: int, v: int) -> int:
425        # a[0, r) に含まれる v の個数
426        return self.rank1(r) if v else self.rank0(r)
427
428    def select0(self, k: int) -> int:
429        # k 番目の 0 のindex
430        # O(log(N))
431        if k < 0 or self.rank0(len(self)) <= k:
432            return -1
433        l, r = 0, len(self)
434        while r - l > 1:
435            m = (l + r) >> 1
436            if m - self._pref(m) > k:
437                r = m
438            else:
439                l = m
440        return l
441
442    def select1(self, k: int) -> int:
443        # k 番目の 1 のindex
444        # O(log(N))
445        if k < 0 or self.rank1(len(self)) <= k:
446            return -1
447        l, r = 0, len(self)
448        while r - l > 1:
449            m = (l + r) >> 1
450            if self._pref(m) > k:
451                r = m
452            else:
453                l = m
454        return l
455
456    def select(self, k: int, v: int) -> int:
457        # k 番目の v のindex
458        # O(log(N))
459        return self.select1(k) if v else self.select0(k)

仕様

class SplayTreeBitVector(a: Sequence[int] = []) [source]

Bases: BitVectorInterface

access(k: int) -> int [source]
debug() [source]
debug_acc() -> None [source]
insert(k: int, key: int) -> None [source]
pop(k: int = -1) -> int [source]
rank(r: int, v: int) -> int [source]
rank0(r: int) -> int [source]
rank1(r: int) -> int [source]
reserve(n: int) -> None [source]
select(k: int, v: int) -> int [source]
select0(k: int) -> int [source]
select1(k: int) -> int [source]
tolist() -> list[int] [source]