reversible_lazy_splay_tree_array

ソースコード

from titan_pylib.data_structures.splay_tree.reversible_lazy_splay_tree_array import ReversibleLazySplayTreeArrayData
from titan_pylib.data_structures.splay_tree.reversible_lazy_splay_tree_array import ReversibleLazySplayTreeArray

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.reversible_lazy_splay_tree_array import ReversibleLazySplayTreeArray
  2from array import array
  3from typing import (
  4    Generic,
  5    TypeVar,
  6    Callable,
  7    Iterable,
  8    Optional,
  9    Union,
 10    Sequence,
 11)
 12
 13T = TypeVar("T")
 14F = TypeVar("F")
 15
 16
 17class ReversibleLazySplayTreeArrayData(Generic[T, F]):
 18
 19    def __init__(
 20        self,
 21        op: Optional[Callable[[T, T], T]] = None,
 22        mapping: Optional[Callable[[F, T], T]] = None,
 23        composition: Optional[Callable[[F, F], F]] = None,
 24        e: T = None,
 25        id: F = None,
 26    ) -> None:
 27        self.op: Callable[[T, T], T] = (lambda s, t: e) if op is None else op
 28        self.mapping: Callable[[F, T], T] = (lambda f, s: e) if op is None else mapping
 29        self.composition: Callable[[F, F], F] = (
 30            (lambda f, g: id) if op is None else composition
 31        )
 32        self.e: T = e
 33        self.id: F = id
 34        self.keydata: list[T] = [e, e, e]
 35        self.lazy: list[F] = [id]
 36        self.arr: array[int] = array("I", bytes(16))
 37        # left:  arr[node<<2]
 38        # right: arr[node<<2|1]
 39        # size:  arr[node<<2|2]
 40        # rev:   arr[node<<2|3]
 41        self.end: int = 1
 42
 43    def reserve(self, n: int) -> None:
 44        if n <= 0:
 45            return
 46        self.keydata += [self.e] * (3 * n)
 47        self.lazy += [self.id] * n
 48        self.arr += array("I", bytes(16 * n))
 49
 50
 51class ReversibleLazySplayTreeArray(Generic[T, F]):
 52
 53    def __init__(
 54        self,
 55        data: "ReversibleLazySplayTreeArrayData",
 56        n_or_a: Union[int, Iterable[T]] = 0,
 57        _root: int = 0,
 58    ):
 59        self.data = data
 60        self.root = _root
 61        if not n_or_a:
 62            return
 63        if isinstance(n_or_a, int):
 64            a = [data.e for _ in range(n_or_a)]
 65        elif not isinstance(n_or_a, Sequence):
 66            a = list(n_or_a)
 67        else:
 68            a = n_or_a
 69        if a:
 70            self._build(a)
 71
 72    def _build(self, a: Sequence[T]) -> None:
 73        def rec(l: int, r: int) -> int:
 74            mid = (l + r) >> 1
 75            if l != mid:
 76                arr[mid << 2] = rec(l, mid)
 77            if mid + 1 != r:
 78                arr[mid << 2 | 1] = rec(mid + 1, r)
 79            self._update(mid)
 80            return mid
 81
 82        n = len(a)
 83        keydata, arr = self.data.keydata, self.data.arr
 84        end = self.data.end
 85        self.data.reserve(n + end - len(keydata) // 3 + 1)
 86        self.data.end += n
 87        for i, e in enumerate(a):
 88            keydata[(end + i) * 3 + 0] = e
 89            keydata[(end + i) * 3 + 1] = e
 90            keydata[(end + i) * 3 + 2] = e
 91        self.root = rec(end, n + end)
 92
 93    def _make_node(self, key: T) -> int:
 94        data = self.data
 95        if data.end >= len(data.arr) // 4:
 96            data.keydata.append(key)
 97            data.keydata.append(key)
 98            data.keydata.append(key)
 99            data.lazy.append(data.id)
100            data.arr.append(0)
101            data.arr.append(0)
102            data.arr.append(1)
103            data.arr.append(0)
104        else:
105            data.keydata[data.end * 3 + 0] = key
106            data.keydata[data.end * 3 + 1] = key
107            data.keydata[data.end * 3 + 2] = key
108        data.end += 1
109        return data.end - 1
110
111    def _propagate(self, node: int) -> None:
112        data = self.data
113        arr = data.arr
114        if arr[node << 2 | 3]:
115            keydata = data.keydata
116            keydata[node * 3 + 1], keydata[node * 3 + 2] = (
117                keydata[node * 3 + 2],
118                keydata[node * 3 + 1],
119            )
120            arr[node << 2], arr[node << 2 | 1] = arr[node << 2 | 1], arr[node << 2]
121            arr[node << 2 | 3] = 0
122            arr[arr[node << 2] << 2 | 3] ^= 1
123            arr[arr[node << 2 | 1] << 2 | 3] ^= 1
124        nlazy = data.lazy[node]
125        if nlazy == data.id:
126            return
127        lnode, rnode = arr[node << 2], arr[node << 2 | 1]
128        keydata, lazy = data.keydata, data.lazy
129        lazy[node] = data.id
130        if lnode:
131            lazy[lnode] = data.composition(nlazy, lazy[lnode])
132            keydata[lnode * 3 + 0] = data.mapping(nlazy, keydata[lnode * 3 + 0])
133            keydata[lnode * 3 + 1] = data.mapping(nlazy, keydata[lnode * 3 + 1])
134            keydata[lnode * 3 + 2] = data.mapping(nlazy, keydata[lnode * 3 + 2])
135        if rnode:
136            lazy[rnode] = data.composition(nlazy, lazy[rnode])
137            keydata[rnode * 3 + 0] = data.mapping(nlazy, keydata[rnode * 3 + 0])
138            keydata[rnode * 3 + 1] = data.mapping(nlazy, keydata[rnode * 3 + 1])
139            keydata[rnode * 3 + 2] = data.mapping(nlazy, keydata[rnode * 3 + 2])
140
141    def _update_triple(self, x: int, y: int, z: int) -> None:
142        # data = self.data
143        # keydata, arr = data.keydata, data.arr
144        # lx, rx = arr[x<<2], arr[x<<2|1]
145        # ly, ry = arr[y<<2], arr[y<<2|1]
146        # self._propagate(lx)
147        # self._propagate(rx)
148        # self._propagate(ly)
149        # self._propagate(ry)
150        # arr[z<<2|2] = arr[x<<2|2]
151        # arr[x<<2|2] = 1 + arr[lx<<2|2] + arr[rx<<2|2]
152        # arr[y<<2|2] = 1 + arr[ly<<2|2] + arr[ry<<2|2]
153        # keydata[z*3+1] = keydata[x*3+1]
154        # keydata[z*3+2] = keydata[x*3+2]
155        # keydata[x*3+1] = data.op(data.op(keydata[lx*3+1], keydata[x*3]), keydata[rx*3+1])
156        # keydata[x*3+2] = data.op(data.op(keydata[rx*3+2], keydata[x*3]), keydata[lx*3+2])
157        # keydata[y*3+1] = data.op(data.op(keydata[ly*3+1], keydata[y*3]), keydata[ry*3+1])
158        # keydata[y*3+2] = data.op(data.op(keydata[ry*3+2], keydata[y*3]), keydata[ly*3+2])
159        self._update(x)
160        self._update(y)
161        self._update(z)
162
163    def _update_double(self, x: int, y: int) -> None:
164        # data = self.data
165        # keydata, arr = data.keydata, data.arr
166        # lx, rx = arr[x<<2], arr[x<<2|1]
167        # self._propagate(lx)
168        # self._propagate(rx)
169        # arr[y<<2|2] = arr[x<<2|2]
170        # arr[x<<2|2] = 1 + arr[lx<<2|2] + arr[rx<<2|2]
171        # keydata[y*3+1] = keydata[x*3+1]
172        # keydata[y*3+2] = keydata[x*3+2]
173        # keydata[x*3+1] = data.op(data.op(keydata[lx*3+1], keydata[x*3]), keydata[rx*3+1])
174        # keydata[x*3+2] = data.op(data.op(keydata[rx*3+2], keydata[x*3]), keydata[lx*3+2])
175        self._update(x)
176        self._update(y)
177
178    def _update(self, node: int) -> None:
179        data = self.data
180        keydata, arr = data.keydata, data.arr
181        lnode, rnode = arr[node << 2], arr[node << 2 | 1]
182        self._propagate(lnode)
183        self._propagate(rnode)
184        arr[node << 2 | 2] = 1 + arr[lnode << 2 | 2] + arr[rnode << 2 | 2]
185        keydata[node * 3 + 1] = data.op(
186            data.op(keydata[lnode * 3 + 1], keydata[node * 3 + 0]),
187            keydata[rnode * 3 + 1],
188        )
189        keydata[node * 3 + 2] = data.op(
190            data.op(keydata[rnode * 3 + 2], keydata[node * 3 + 0]),
191            keydata[lnode * 3 + 2],
192        )
193
194    def _splay(self, path: list[int], d: int) -> None:
195        arr = self.data.arr
196        g = d & 1
197        while len(path) > 1:
198            pnode = path.pop()
199            gnode = path.pop()
200            f = d >> 1 & 1
201            node = arr[pnode << 2 | g ^ 1]
202            nnode = (pnode if g == f else node) << 2 | f
203            arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
204            arr[node << 2 | g] = pnode
205            arr[gnode << 2 | f ^ 1] = arr[nnode]
206            arr[nnode] = gnode
207            self._update_triple(gnode, pnode, node)
208            if not path:
209                return
210            d >>= 2
211            g = d & 1
212            arr[path[-1] << 2 | g ^ 1] = node
213        pnode = path.pop()
214        node = arr[pnode << 2 | g ^ 1]
215        arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
216        arr[node << 2 | g] = pnode
217        self._update_double(pnode, node)
218
219    def _kth_elm_splay(self, node: int, k: int) -> int:
220        arr = self.data.arr
221        if k < 0:
222            k += arr[node << 2 | 2]
223        d = 0
224        path = []
225        while True:
226            self._propagate(node)
227            t = arr[arr[node << 2] << 2 | 2]
228            if t == k:
229                if path:
230                    self._splay(path, d)
231                return node
232            d = d << 1 | (t > k)
233            path.append(node)
234            node = arr[node << 2 | (t < k)]
235            if t < k:
236                k -= t + 1
237
238    def _left_splay(self, node: int) -> int:
239        if not node:
240            return 0
241        self._propagate(node)
242        arr = self.data.arr
243        if not arr[node << 2]:
244            return node
245        path = []
246        while arr[node << 2]:
247            path.append(node)
248            node = arr[node << 2]
249            self._propagate(node)
250        self._splay(path, (1 << len(path)) - 1)
251        return node
252
253    def _right_splay(self, node: int) -> int:
254        if not node:
255            return 0
256        self._propagate(node)
257        arr = self.data.arr
258        if not arr[node << 2 | 1]:
259            return node
260        path = []
261        while arr[node << 2 | 1]:
262            path.append(node)
263            node = arr[node << 2 | 1]
264            self._propagate(node)
265        self._splay(path, 0)
266        return node
267
268    def reserve(self, n: int) -> None:
269        self.data.reserve(n)
270
271    def merge(self, other: "ReversibleLazySplayTreeArray") -> None:
272        assert self.data is other.data
273        if not other.root:
274            return
275        if not self.root:
276            self.root = other.root
277            return
278        self.root = self._right_splay(self.root)
279        self.data.arr[self.root << 2 | 1] = other.root
280        self._update(self.root)
281
282    def split(
283        self, k: int
284    ) -> tuple["ReversibleLazySplayTreeArray", "ReversibleLazySplayTreeArray"]:
285        assert (
286            -len(self) < k <= len(self)
287        ), f"IndexError: ReversibleLazySplayTreeArray.split({k}), len={len(self)}"
288        if k < 0:
289            k += len(self)
290        if k >= self.data.arr[self.root << 2 | 2]:
291            return self, ReversibleLazySplayTreeArray(self.data, _root=0)
292        self.root = self._kth_elm_splay(self.root, k)
293        left = ReversibleLazySplayTreeArray(
294            self.data, _root=self.data.arr[self.root << 2]
295        )
296        self.data.arr[self.root << 2] = 0
297        self._update(self.root)
298        return left, self
299
300    def _internal_split(self, k: int) -> tuple[int, int]:
301        if k >= self.data.arr[self.root << 2 | 2]:
302            return self.root, 0
303        self.root = self._kth_elm_splay(self.root, k)
304        left = self.data.arr[self.root << 2]
305        self._propagate(left)
306        self.data.arr[self.root << 2] = 0
307        self._update(self.root)
308        return left, self.root
309
310    def reverse(self, l: int, r: int) -> None:
311        assert (
312            0 <= l <= r <= len(self)
313        ), f"IndexError: ReversibleLazySplayTreeArray.reverse({l}, {r}), len={len(self)}"
314        if l == r:
315            return
316        data = self.data
317        left, right = self._internal_split(r)
318        if l:
319            left = self._kth_elm_splay(left, l - 1)
320        data.arr[(data.arr[left << 2 | 1] if l else left) << 2 | 3] ^= 1
321        if right:
322            data.arr[right << 2] = left
323            self._update(right)
324        self.root = right if right else left
325
326    def all_reverse(self) -> None:
327        self.data.arr[self.root << 2 | 3] ^= 1
328        self._propagate(self.root)
329
330    def apply(self, l: int, r: int, f: F) -> None:
331        assert (
332            0 <= l <= r <= len(self)
333        ), f"IndexError: ReversibleLazySplayTreeArray.apply({l}, {r}), len={len(self)}"
334        data = self.data
335        left, right = self._internal_split(r)
336        keydata, lazy = data.keydata, data.lazy
337        if l:
338            left = self._kth_elm_splay(left, l - 1)
339        node = data.arr[left << 2 | 1] if l else left
340        keydata[node * 3 + 0] = data.mapping(f, keydata[node * 3 + 0])
341        keydata[node * 3 + 1] = data.mapping(f, keydata[node * 3 + 1])
342        keydata[node * 3 + 2] = data.mapping(f, keydata[node * 3 + 2])
343        lazy[node] = data.composition(f, lazy[node])
344        if l:
345            self._update(left)
346        if right:
347            data.arr[right << 2] = left
348            self._update(right)
349        self.root = right if right else left
350
351    def all_apply(self, f: F) -> None:
352        if not self.root:
353            return
354        data, node = self.data, self.root
355        data.keydata[node * 3 + 0] = data.mapping(f, data.keydata[node * 3 + 0])
356        data.keydata[node * 3 + 1] = data.mapping(f, data.keydata[node * 3 + 1])
357        data.keydata[node * 3 + 2] = data.mapping(f, data.keydata[node * 3 + 2])
358        data.lazy[node] = data.composition(f, data.lazy[node])
359
360    def prod(self, l: int, r: int) -> T:
361        assert (
362            0 <= l <= r <= len(self)
363        ), f"IndexError: LazySplayTree.prod({l}, {r}), len={len(self)}"
364        data = self.data
365        left, right = self._internal_split(r)
366        if l:
367            left = self._kth_elm_splay(left, l - 1)
368        node = data.arr[left << 2 | 1] if l else left
369        self._propagate(node)
370        res = data.keydata[node * 3 + 1]
371        if right:
372            data.arr[right << 2] = left
373            self._update(right)
374        self.root = right if right else left
375        return res
376
377    def all_prod(self) -> T:
378        return self.data.keydata[self.root * 3 + 1]
379
380    def insert(self, k: int, key: T) -> None:
381        assert (
382            -len(self) <= k <= len(self)
383        ), f"IndexError: ReversibleLazySplayTreeArray.insert({k}, {key}), len={len(self)}"
384        if k < 0:
385            k += len(self)
386        data = self.data
387        node = self._make_node(key)
388        if not self.root:
389            self._update(node)
390            self.root = node
391            return
392        arr = data.arr
393        if k == data.arr[self.root << 2 | 2]:
394            arr[node << 2] = self._right_splay(self.root)
395        else:
396            node_ = self._kth_elm_splay(self.root, k)
397            if arr[node_ << 2]:
398                arr[node << 2] = arr[node_ << 2]
399                arr[node_ << 2] = 0
400                self._update(node_)
401            arr[node << 2 | 1] = node_
402        self._update(node)
403        self.root = node
404
405    def append(self, key: T) -> None:
406        data = self.data
407        node = self._right_splay(self.root)
408        self.root = self._make_node(key)
409        data.arr[self.root << 2] = node
410        self._update(self.root)
411
412    def appendleft(self, key: T) -> None:
413        node = self._left_splay(self.root)
414        self.root = self._make_node(key)
415        self.data.arr[self.root << 2 | 1] = node
416        self._update(self.root)
417
418    def pop(self, k: int = -1) -> T:
419        assert (
420            -len(self) <= k < len(self)
421        ), f"IndexError: ReversibleLazySplayTreeArray.pop({k})"
422        data = self.data
423        if k == -1:
424            node = self._right_splay(self.root)
425            self._propagate(node)
426            self.root = data.arr[node << 2]
427            return data.keydata[node * 3 + 0]
428        self.root = self._kth_elm_splay(self.root, k)
429        res = data.keydata[self.root * 3 + 0]
430        if not data.arr[self.root << 2]:
431            self.root = data.arr[self.root << 2 | 1]
432        elif not data.arr[self.root << 2 | 1]:
433            self.root = data.arr[self.root << 2]
434        else:
435            node = self._right_splay(data.arr[self.root << 2])
436            data.arr[node << 2 | 1] = data.arr[self.root << 2 | 1]
437            self.root = node
438            self._update(self.root)
439        return res
440
441    def popleft(self) -> T:
442        assert self, "IndexError: ReversibleLazySplayTreeArray.popleft()"
443        node = self._left_splay(self.root)
444        self.root = self.data.arr[node << 2 | 1]
445        return self.data.keydata[node * 3 + 0]
446
447    def rotate(self, x: int) -> None:
448        # 「末尾をを削除し先頭に挿入」をx回
449        n = self.data.arr[self.root << 2 | 2]
450        l, self = self.split(n - (x % n))
451        self.merge(l)
452
453    def tolist(self) -> list[T]:
454        node = self.root
455        arr, keydata = self.data.arr, self.data.keydata
456        stack = []
457        res = []
458        while stack or node:
459            if node:
460                self._propagate(node)
461                stack.append(node)
462                node = arr[node << 2]
463            else:
464                node = stack.pop()
465                res.append(keydata[node * 3 + 0])
466                node = arr[node << 2 | 1]
467        return res
468
469    def clear(self) -> None:
470        self.root = 0
471
472    def __setitem__(self, k: int, key: T):
473        assert (
474            -len(self) <= k < len(self)
475        ), f"IndexError: ReversibleLazySplayTreeArray.__setitem__({k})"
476        self.root = self._kth_elm_splay(self.root, k)
477        self.data.keydata[self.root * 3 + 0] = key
478        self._update(self.root)
479
480    def __getitem__(self, k: int) -> T:
481        assert (
482            -len(self) <= k < len(self)
483        ), f"IndexError: ReversibleLazySplayTreeArray.__getitem__({k})"
484        self.root = self._kth_elm_splay(self.root, k)
485        return self.data.keydata[self.root * 3 + 0]
486
487    def __iter__(self):
488        self.__iter = 0
489        return self
490
491    def __next__(self):
492        if self.__iter == self.data.arr[self.root << 2 | 2]:
493            raise StopIteration
494        res = self.__getitem__(self.__iter)
495        self.__iter += 1
496        return res
497
498    def __reversed__(self):
499        for i in range(len(self)):
500            yield self.__getitem__(-i - 1)
501
502    def __len__(self):
503        return self.data.arr[self.root << 2 | 2]
504
505    def __str__(self):
506        return str(self.tolist())
507
508    def __bool__(self):
509        return self.root != 0
510
511    def __repr__(self):
512        return f"ReversibleLazySplayTreeArray({self})"

仕様

class ReversibleLazySplayTreeArray(data: ReversibleLazySplayTreeArrayData, n_or_a: int | Iterable[T] = 0, _root: int = 0) [source]

Bases: Generic[T, F]

all_apply(f: F) → None [source]
all_prod() → T [source]
all_reverse() → None [source]
append(key: T) → None [source]
appendleft(key: T) → None [source]
apply(l: int, r: int, f: F) → None [source]
clear() → None [source]
insert(k: int, key: T) → None [source]
merge(other: ReversibleLazySplayTreeArray) → None [source]
pop(k: int = -1) → T [source]
popleft() → T [source]
prod(l: int, r: int) → T [source]
reserve(n: int) → None [source]
reverse(l: int, r: int) → None [source]
rotate(x: int) → None [source]
split(k: int) → tuple[ReversibleLazySplayTreeArray, ReversibleLazySplayTreeArray] [source]
tolist() → list[T] [source]
class ReversibleLazySplayTreeArrayData(op: Callable[[T, T], T] | None = None, mapping: Callable[[F, T], T] | None = None, composition: Callable[[F, F], F] | None = None, e: T = None, id: F = None) [source]

Bases: Generic[T, F]

reserve(n: int) → None [source]