lazy_splay_tree

ソースコード

from titan_pylib.data_structures.splay_tree.lazy_splay_tree import LazySplayTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.lazy_splay_tree import LazySplayTree
  2from typing import Generic, Union, TypeVar, Callable, Iterable, Optional
  3
  4T = TypeVar("T")
  5F = TypeVar("F")
  6
  7
  8class LazySplayTree(Generic[T, F]):
  9
 10    class _Node:
 11
 12        def __init__(self, key: T, lazy: F) -> None:
 13            self.key: T = key
 14            self.data: T = key
 15            self.rdata: T = key
 16            self.lazy: F = lazy
 17            self.left: Optional["LazySplayTree._Node"] = None
 18            self.right: Optional["LazySplayTree._Node"] = None
 19            self.par: Optional["LazySplayTree._Node"] = None
 20            self.size: int = 1
 21            self.rev: int = 0
 22
 23    def __init__(
 24        self,
 25        n_or_a: Union[int, Iterable[T]],
 26        op: Callable[[T, T], T],
 27        mapping: Callable[[F, T], T],
 28        composition: Callable[[F, F], F],
 29        e: T,
 30        id: F,
 31        _root: Optional[_Node] = None,
 32    ) -> None:
 33        """構築します。
 34        :math:`O(n)` です。
 35
 36        Args:
 37          n_or_a (Union[int, Iterable[T]]): ``n`` のとき、 ``e`` から長さ ``n`` で構築します。
 38                                            ``a`` のとき、 ``a`` から構築します。
 39          op (Callable[[T, T], T]): 遅延セグ木のあれです。
 40          mapping (Callable[[F, T], T]): 遅延セグ木のあれです。
 41          composition (Callable[[F, F], F]): 遅延セグ木のあれです。
 42          e (T): 遅延セグ木のあれです。
 43          id (F): 遅延セグ木のあれです。
 44        """
 45        self.op = op
 46        self.mapping = mapping
 47        self.composition = composition
 48        self.e = e
 49        self.id = id
 50        self.root = _root
 51        if _root:
 52            return
 53        a = n_or_a
 54        if isinstance(a, int):
 55            a = [e for _ in range(a)]
 56        elif not isinstance(a, list):
 57            a = list(a)
 58        if a:
 59            self._build(a)
 60
 61    def _build(self, a: list[T]) -> None:
 62        _Node = LazySplayTree._Node
 63        id = self.id
 64
 65        def build(l: int, r: int) -> LazySplayTree._Node:
 66            mid = (l + r) >> 1
 67            node = _Node(a[mid], id)
 68            if l != mid:
 69                node.left = build(l, mid)
 70                node.left.par = node
 71            if mid + 1 != r:
 72                node.right = build(mid + 1, r)
 73                node.right.par = node
 74            self._update(node)
 75            return node
 76
 77        self.root = build(0, len(a))
 78
 79    def _rotate(self, node: _Node) -> None:
 80        pnode = node.par
 81        gnode = pnode.par
 82        if gnode:
 83            if gnode.left is pnode:
 84                gnode.left = node
 85            else:
 86                gnode.right = node
 87        node.par = gnode
 88        if pnode.left is node:
 89            pnode.left = node.right
 90            if node.right:
 91                node.right.par = pnode
 92            node.right = pnode
 93        else:
 94            pnode.right = node.left
 95            if node.left:
 96                node.left.par = pnode
 97            node.left = pnode
 98        pnode.par = node
 99        self._update_double(pnode, node)
100
101    def _propagate_rev(self, node: Optional[_Node]) -> None:
102        if not node:
103            return
104        node.rev ^= 1
105
106    def _propagate_lazy(self, node: Optional[_Node], f: F) -> None:
107        if not node:
108            return
109        node.key = self.mapping(f, node.key)
110        node.data = self.mapping(f, node.data)
111        node.rdata = self.mapping(f, node.rdata)
112        node.lazy = f if node.lazy == self.id else self.composition(f, node.lazy)
113
114    def _propagate(self, node: Optional[_Node]) -> None:
115        if not node:
116            return
117        if node.rev:
118            node.data, node.rdata = node.rdata, node.data
119            node.left, node.right = node.right, node.left
120            self._propagate_rev(node.left)
121            self._propagate_rev(node.right)
122            node.rev = 0
123        if node.lazy != self.id:
124            self._propagate_lazy(node.left, node.lazy)
125            self._propagate_lazy(node.right, node.lazy)
126            node.lazy = self.id
127
128    def _update_double(self, pnode: _Node, node: _Node) -> None:
129        node.data = pnode.data
130        node.rdata = pnode.rdata
131        node.size = pnode.size
132        self._update(pnode)
133
134    def _update(self, node: _Node) -> None:
135        node.data = node.key
136        node.rdata = node.key
137        node.size = 1
138        if node.left:
139            node.data = self.op(node.left.data, node.data)
140            node.rdata = self.op(node.rdata, node.left.rdata)
141            node.size += node.left.size
142        if node.right:
143            node.data = self.op(node.data, node.right.data)
144            node.rdata = self.op(node.right.rdata, node.rdata)
145            node.size += node.right.size
146
147    def _splay(self, node: _Node) -> None:
148        # while node.par and node.par.par:
149        #   pnode = node.par
150        #   self._rotate(pnode if (pnode.par.left is pnode) == (pnode.left is node) else node)
151        #   self._rotate(node)
152        # if node.par:
153        #   self._rotate(node)
154        while node.par:
155            pnode = node.par
156            if pnode:
157                self._rotate(
158                    pnode if (pnode.par.left is pnode) == (pnode.left is node) else node
159                )
160            self._rotate(node)
161
162    def kth_splay(self, node: Optional[_Node], k: int) -> None:
163        if k < 0:
164            k += len(self)
165        while True:
166            self._propagate(node)
167            t = node.left.size if node.left else 0
168            if t == k:
169                break
170            if t > k:
171                node = node.left
172            else:
173                node = node.right
174                k -= t + 1
175        self._splay(node)
176        return node
177
178    def _left_splay(self, node: Optional[_Node]) -> Optional[_Node]:
179        self._propagate(node)
180        if not node or not node.left:
181            return node
182        while node.left:
183            node = node.left
184            self._propagate(node)
185        self._splay(node)
186        return node
187
188    def _right_splay(self, node: Optional[_Node]) -> Optional[_Node]:
189        self._propagate(node)
190        if not node or not node.right:
191            return node
192        while node.right:
193            node = node.right
194            self._propagate(node)
195        self._splay(node)
196        return node
197
198    def merge(self, other: "LazySplayTree") -> None:
199        """``other`` を後ろに連結します。
200        償却 :math:`O(\\log{n})` です。
201
202        Args:
203          other (LazySplayTree):
204        """
205        if not self.root:
206            self.root = other.root
207            return
208        if not other.root:
209            return
210        self.root = self._right_splay(self.root)
211        self.root.right = other.root
212        other.root.par = self.root
213        self._update(self.root)
214
215    def split(self, k: int) -> tuple["LazySplayTree", "LazySplayTree"]:
216        """位置 ``k`` で split します。
217        償却 :math:`O(\\log{n})` です。
218
219        Returns:
220          tuple['LazySplayTree', 'LazySplayTree']:
221        """
222        left, right = self._internal_split(self.root, k)
223        left_splay = LazySplayTree(
224            0, self.op, self.mapping, self.composition, self.e, self.id, left
225        )
226        right_splay = LazySplayTree(
227            0, self.op, self.mapping, self.composition, self.e, self.id, right
228        )
229        return left_splay, right_splay
230
231    def _internal_split(self, k: int) -> tuple[_Node, _Node]:
232        if k == len(self):
233            return self.root, None
234        right = self.kth_splay(self.root, k)
235        left = right.left
236        if left:
237            left.par = None
238        right.left = None
239        self._update(right)
240        return left, right
241
242    def _internal_merge(
243        self, left: Optional[_Node], right: Optional[_Node]
244    ) -> Optional[_Node]:
245        # need (not right) or (not right.left)
246        if not right:
247            return left
248        assert right.left is None
249        right.left = left
250        if left:
251            left.par = right
252        self._update(right)
253        return right
254
255    def reverse(self, l: int, r: int) -> None:
256        """区間 ``[l, r)`` を反転します。
257        償却 :math:`O(\\log{n})` です。
258
259        Args:
260          l (int):
261          r (int):
262        """
263        assert (
264            0 <= l <= r <= len(self)
265        ), f"IndexError: {self.__class__.__name__}.reverse({l}, {r}), len={len(self)}"
266        left, right = self._internal_split(r)
267        if l == 0:
268            self._propagate_rev(left)
269        else:
270            left = self.kth_splay(left, l - 1)
271            self._propagate_rev(left.right)
272        self.root = self._internal_merge(left, right)
273
274    def all_reverse(self) -> None:
275        """区間 ``[0, n)`` を反転します。
276        :math:`O(1)` です。
277        """
278        self._propagate_rev(self.root)
279
280    def apply(self, l: int, r: int, f: F) -> None:
281        """区間 ``[l, r)`` に ``f`` を作用します。
282        償却 :math:`O(\\log{n})` です。
283
284        Args:
285          l (int):
286          r (int):
287          f (F): 作用素です。
288        """
289        assert (
290            0 <= l <= r <= len(self)
291        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), len={len(self)}"
292        left, right = self._internal_split(r)
293        if l == 0:
294            self._propagate_lazy(left, f)
295        else:
296            left = self.kth_splay(left, l - 1)
297            self._propagate_lazy(left.right, f)
298            self._update(left)
299        self.root = self._internal_merge(left, right)
300
301    def all_apply(self, f: F) -> None:
302        """区間 ``[0, n)`` に ``f`` を作用します。
303        :math:`O(1)` です。
304        """
305        self._propagate_lazy(self.root, f)
306
307    def prod(self, l: int, r: int) -> T:
308        """区間 ``[l, r)`` の総積を求めます。
309        償却 :math:`O(\\log{n})` です。
310        """
311        assert (
312            0 <= l <= r <= len(self)
313        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), len={len(self)}"
314        if l == r:
315            return self.e
316        left, right = self._internal_split(r)
317        if l == 0:
318            res = left.data
319        else:
320            left = self.kth_splay(left, l - 1)
321            res = left.right.data
322        self.root = self._internal_merge(left, right)
323        return res
324
325    def all_prod(self) -> T:
326        """区間 ``[0, n)`` の総積を求めます。
327        :math:`O(1)` です。
328        """
329        self._propagate(self.root)
330        return self.root.data if self.root else self.e
331
332    def insert(self, k: int, key: T) -> None:
333        """位置 ``k`` に ``key`` を挿入します。
334        償却 :math:`O(\\log{n})` です。
335
336        Args:
337          k (int):
338          key (T):
339        """
340        assert 0 <= k <= len(self)
341        node = self._Node(key, self.id)
342        if not self.root:
343            self.root = node
344            return
345        if k >= len(self):
346            root = self.kth_splay(self.root, len(self) - 1)
347            node.left = root
348        else:
349            root = self.kth_splay(self.root, k)
350            if root.left:
351                node.left = root.left
352                root.left.par = node
353                root.left = None
354                self._update(root)
355            node.right = root
356        root.par = node
357        self.root = node
358        self._update(self.root)
359
360    def append(self, key: T) -> None:
361        """末尾に ``key`` を追加します。
362        償却 :math:`O(\\log{n})` です。
363
364        Args:
365          key (T):
366        """
367        node = self._right_splay(self.root)
368        self.root = self._Node(key, self.id)
369        self.root.left = node
370        if node:
371            node.par = self.root
372        self._update(self.root)
373
374    def appendleft(self, key: T) -> None:
375        """先頭に ``key`` を追加します。
376        償却 :math:`O(\\log{n})` です。
377
378        Args:
379          key (T):
380        """
381        node = self._left_splay(self.root)
382        self.root = self._Node(key, self.id)
383        self.root.right = node
384        if node:
385            node.par = self.root
386        self._update(self.root)
387
388    def pop(self, k: int = -1) -> T:
389        """位置 ``k`` の要素を削除し、その値を返します。
390        償却 :math:`O(\\log{n})` です。
391
392        Args:
393          k (int, optional): 指定するインデックスです。 Defaults to -1.
394        """
395        if k == -1:
396            node = self._right_splay(self.root)
397            if node.left:
398                node.left.par = None
399            self.root = node.left
400            return node.key
401        root = self.kth_splay(self.root, k)
402        res = root.key
403        if root.left and root.right:
404            node = self._right_splay(root.left)
405            node.par = None
406            node.right = root.right
407            if node.right:
408                node.right.par = node
409            self._update(node)
410            self.root = node
411        else:
412            self.root = root.right if root.right else root.left
413            if self.root:
414                self.root.par = None
415        return res
416
417    def popleft(self) -> T:
418        """先頭の要素を削除し、その値を返します。
419        償却 :math:`O(\\log{n})` です。
420
421        Returns:
422          T:
423        """
424        node = self._left_splay(self.root)
425        self.root = node.right
426        if node.right:
427            node.right.par = None
428        return node.key
429
430    def copy(self) -> "LazySplayTree":
431        """コピーします。
432
433        Note:
434          償却 :math:`O(n)` です。
435
436        Returns:
437          LazySplayTree:
438        """
439        return LazySplayTree(
440            self.tolist(), self.op, self.mapping, self.composition, self.e, self.id
441        )
442
443    def clear(self) -> None:
444        """全ての要素を削除します。
445        :math:`O(1)` です。
446        """
447        self.root = None
448
449    def tolist(self) -> list[T]:
450        """``list`` にして返します。
451        :math:`O(n)` です。非再帰です。
452
453        Returns:
454          list[T]:
455        """
456        node = self.root
457        stack = []
458        a = []
459        while stack or node:
460            if node:
461                self._propagate(node)
462                stack.append(node)
463                node = node.left
464            else:
465                node = stack.pop()
466                a.append(node.key)
467                node = node.right
468        return a
469
470    def __setitem__(self, k: int, key: T) -> None:
471        """位置 ``k`` の要素を値 ``key`` で更新します。
472        償却 :math:`O(\\log{n})` です。
473
474        Args:
475          k (int):
476          key (T):
477        """
478        self.root = self.kth_splay(self.root, k)
479        self.root.key = key
480        self._update(self.root)
481
482    def __getitem__(self, k: int) -> T:
483        """位置 ``k`` の値を返します。
484        償却 :math:`O(\\log{n})` です。
485
486        Args:
487          k (int):
488          key (T):
489        """
490        self.root = self.kth_splay(self.root, k)
491        return self.root.key
492
493    def __iter__(self):
494        self.__iter = 0
495        return self
496
497    def __next__(self):
498        if self.__iter == len(self):
499            raise StopIteration
500        res = self[self.__iter]
501        self.__iter += 1
502        return res
503
504    def __reversed__(self):
505        for i in range(len(self)):
506            yield self[-i - 1]
507
508    def __len__(self):
509        """要素数を返します。
510        :math:`O(1)` です。
511
512        Returns:
513          int:
514        """
515        return self.root.size if self.root else 0
516
517    def __str__(self):
518        return str(self.tolist())
519
520    def __bool__(self):
521        return self.root is not None
522
523    def __repr__(self):
524        return f"{self.__class__.__name__}({self})"

仕様

class LazySplayTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F, _root: _Node | None = None)[source]

Bases: Generic[T, F]

__getitem__(k: int) T[source]

位置 k の値を返します。 償却 \(O(\log{n})\) です。

Parameters:
  • k (int)

  • key (T)

__len__()[source]

要素数を返します。 \(O(1)\) です。

Return type:

int

__setitem__(k: int, key: T) None[source]

位置 k の要素を値 key で更新します。 償却 \(O(\log{n})\) です。

Parameters:
  • k (int)

  • key (T)

all_apply(f: F) None[source]

区間 [0, n)f を作用します。 \(O(1)\) です。

all_prod() T[source]

区間 [0, n) の総積を求めます。 \(O(1)\) です。

all_reverse() None[source]

区間 [0, n) を反転します。 \(O(1)\) です。

append(key: T) None[source]

末尾に key を追加します。 償却 \(O(\log{n})\) です。

Parameters:

key (T)

appendleft(key: T) None[source]

先頭に key を追加します。 償却 \(O(\log{n})\) です。

Parameters:

key (T)

apply(l: int, r: int, f: F) None[source]

区間 [l, r)f を作用します。 償却 \(O(\log{n})\) です。

Parameters:
  • l (int)

  • r (int)

  • f (F) – 作用素です。

clear() None[source]

全ての要素を削除します。 \(O(1)\) です。

copy() LazySplayTree[source]

コピーします。

Note

償却 \(O(n)\) です。

Return type:

LazySplayTree

insert(k: int, key: T) None[source]

位置 kkey を挿入します。 償却 \(O(\log{n})\) です。

Parameters:
  • k (int)

  • key (T)

kth_splay(node: _Node | None, k: int) None[source]
merge(other: LazySplayTree) None[source]

other を後ろに連結します。 償却 \(O(\log{n})\) です。

Parameters:

other (LazySplayTree)

pop(k: int = -1) T[source]

位置 k の要素を削除し、その値を返します。 償却 \(O(\log{n})\) です。

Parameters:

k (int, optional) – 指定するインデックスです。 Defaults to -1.

popleft() T[source]

先頭の要素を削除し、その値を返します。 償却 \(O(\log{n})\) です。

Return type:

T

prod(l: int, r: int) T[source]

区間 [l, r) の総積を求めます。 償却 \(O(\log{n})\) です。

reverse(l: int, r: int) None[source]

区間 [l, r) を反転します。 償却 \(O(\log{n})\) です。

Parameters:
  • l (int)

  • r (int)

split(k: int) tuple[LazySplayTree, LazySplayTree][source]

位置 k で split します。 償却 \(O(\log{n})\) です。

Return type:

tuple[‘LazySplayTree’, ‘LazySplayTree’]

tolist() list[T][source]

list にして返します。 \(O(n)\) です。非再帰です。

Return type:

list[T]