persistent_lazy_avl_tree

ソースコード

from titan_pylib.data_structures.avl_tree.persistent_lazy_avl_tree import PersistentLazyAVLTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.persistent_lazy_avl_tree import PersistentLazyAVLTree
  2from typing import Generic, Iterable, Optional, TypeVar, Callable, Union
  3
  4T = TypeVar("T")
  5F = TypeVar("F")
  6
  7
  8class PersistentLazyAVLTree(Generic[T, F]):
  9
 10    class Node:
 11
 12        def __init__(self, key: T, lazy: F):
 13            self.key: T = key
 14            self.data: T = key
 15            self.left: Optional[PersistentLazyAVLTree.Node] = None
 16            self.right: Optional[PersistentLazyAVLTree.Node] = None
 17            self.lazy: F = lazy
 18            self.rev: int = 0
 19            self.height: int = 1
 20            self.size: int = 1
 21
 22        def copy(self) -> "PersistentLazyAVLTree.Node":
 23            node = PersistentLazyAVLTree.Node(self.key, self.lazy)
 24            node.data = self.data
 25            node.left = self.left
 26            node.right = self.right
 27            node.rev = self.rev
 28            node.height = self.height
 29            node.size = self.size
 30            return node
 31
 32        def balance(self) -> int:
 33            return (
 34                (0 if self.right is None else -self.right.height)
 35                if self.left is None
 36                else (
 37                    self.left.height
 38                    if self.right is None
 39                    else self.left.height - self.right.height
 40                )
 41            )
 42
 43        def __str__(self):
 44            if self.left is None and self.right is None:
 45                return f"key={self.key}, height={self.height}, size={self.size}, data={self.data}, lazy={self.lazy}, rev={self.rev}\n"
 46            return f"key={self.key}, height={self.height}, size={self.size}, data={self.data}, lazy={self.lazy}, rev={self.rev},\n left:{self.left},\n right:{self.right}\n"
 47
 48        __repr__ = __str__
 49
 50    def __init__(
 51        self,
 52        a: Iterable[T],
 53        op: Callable[[T, T], T],
 54        mapping: Callable[[F, T], T],
 55        composition: Callable[[F, F], F],
 56        e: T,
 57        id: F,
 58        _root: Optional[Node] = None,
 59    ) -> None:
 60        self.root: Optional[PersistentLazyAVLTree.Node] = _root
 61        self.op: Callable[[T, T], T] = op
 62        self.mapping: Callable[[F, T], T] = mapping
 63        self.composition: Callable[[F, F], F] = composition
 64        self.e: T = e
 65        self.id: F = id
 66        a = list(a)
 67        if a:
 68            self._build(list(a))
 69
 70    def _build(self, a: list[T]) -> None:
 71        Node = PersistentLazyAVLTree.Node
 72
 73        def build(l: int, r: int) -> PersistentLazyAVLTree.Node:
 74            mid = (l + r) >> 1
 75            node = Node(a[mid], id)
 76            if l != mid:
 77                node.left = build(l, mid)
 78            if mid + 1 != r:
 79                node.right = build(mid + 1, r)
 80            self._update(node)
 81            return node
 82
 83        id = self.id
 84        self.root = build(0, len(a))
 85
 86    def _propagate(self, node: Node) -> None:
 87        if node.rev:
 88            node.rev = 0
 89            l = node.left.copy() if node.left else None
 90            r = node.right.copy() if node.right else None
 91            node.left, node.right = r, l
 92            if l:
 93                l.rev ^= 1
 94            if r:
 95                r.rev ^= 1
 96        if node.lazy != self.id:
 97            lazy = node.lazy
 98            node.lazy = self.id
 99            if node.left:
100                l = node.left.copy()
101                l.data = self.mapping(lazy, l.data)
102                l.key = self.mapping(lazy, l.key)
103                l.lazy = self.composition(lazy, l.lazy)
104                node.left = l
105            if node.right:
106                r = node.right.copy()
107                r.data = self.mapping(lazy, r.data)
108                r.key = self.mapping(lazy, r.key)
109                r.lazy = self.composition(lazy, r.lazy)
110                node.right = r
111
112    def _update(self, node: Node) -> None:
113        node.size = 1
114        node.data = node.key
115        node.height = 1
116        if node.left:
117            node.size += node.left.size
118            node.data = self.op(node.left.data, node.data)
119            node.height = max(node.left.height + 1, 1)
120        if node.right:
121            node.size += node.right.size
122            node.data = self.op(node.data, node.right.data)
123            node.height = max(node.height, node.right.height + 1)
124
125    def _rotate_right(self, node: Node) -> Node:
126        assert node.left
127        u = node.left.copy()
128        node.left = u.right
129        u.right = node
130        self._update(node)
131        self._update(u)
132        return u
133
134    def _rotate_left(self, node: Node) -> Node:
135        assert node.right
136        u = node.right.copy()
137        node.right = u.left
138        u.left = node
139        self._update(node)
140        self._update(u)
141        return u
142
143    def _balance_left(self, node: Node) -> Node:
144        assert node.right
145        self._propagate(node.right)
146        node.right = node.right.copy()
147        u = node.right
148        if u.balance() == 1:
149            assert u.left
150            self._propagate(u.left)
151            node.right = self._rotate_right(u)
152        u = self._rotate_left(node)
153        return u
154
155    def _balance_right(self, node: Node) -> Node:
156        assert node.left
157        self._propagate(node.left)
158        node.left = node.left.copy()
159        u = node.left
160        if u.balance() == -1:
161            assert u.right
162            self._propagate(u.right)
163            node.left = self._rotate_left(u)
164        u = self._rotate_right(node)
165        return u
166
167    def _merge_with_root(
168        self, l: Optional[Node], root: Node, r: Optional[Node]
169    ) -> Node:
170        diff = 0
171        if l:
172            diff += l.height
173        if r:
174            diff -= r.height
175        if diff > 1:
176            assert l
177            self._propagate(l)
178            l = l.copy()
179            l.right = self._merge_with_root(l.right, root, r)
180            self._update(l)
181            if l.balance() == -2:
182                return self._balance_left(l)
183            return l
184        if diff < -1:
185            assert r
186            self._propagate(r)
187            r = r.copy()
188            r.left = self._merge_with_root(l, root, r.left)
189            self._update(r)
190            if r.balance() == 2:
191                return self._balance_right(r)
192            return r
193        root = root.copy()
194        root.left = l
195        root.right = r
196        self._update(root)
197        return root
198
199    def _merge_node(self, l: Optional[Node], r: Optional[Node]) -> Optional[Node]:
200        if l is None and r is None:
201            return None
202        if l is None:
203            return r.copy()
204        if r is None:
205            return l.copy()
206        l = l.copy()
207        r = r.copy()
208        l, root = self._pop_right(l)
209        return self._merge_with_root(l, root, r)
210
211    def merge(self, other: "PersistentLazyAVLTree") -> "PersistentLazyAVLTree":
212        root = self._merge_node(self.root, other.root)
213        return self._new(root)
214
215    def _pop_right(self, node: Node) -> tuple[Node, Node]:
216        path = []
217        self._propagate(node)
218        node = node.copy()
219        mx = node
220        while node.right:
221            path.append(node)
222            self._propagate(node.right)
223            node = node.right.copy()
224            mx = node
225        path.append(node.left.copy() if node.left else None)
226        for _ in range(len(path) - 1):
227            node = path.pop()
228            if node is None:
229                path[-1].right = None
230                self._update(path[-1])
231                continue
232            b = node.balance()
233            if b == 2:
234                path[-1].right = self._balance_right(node)
235            elif b == -2:
236                path[-1].right = self._balance_left(node)
237            else:
238                path[-1].right = node
239            self._update(path[-1])
240        if path[0] is not None:
241            b = path[0].balance()
242            if b == 2:
243                path[0] = self._balance_right(path[0])
244            elif b == -2:
245                path[0] = self._balance_left(path[0])
246        mx.left = None
247        self._update(mx)
248        return path[0], mx
249
250    def _split_node(
251        self, node: Optional[Node], k: int
252    ) -> tuple[Optional[Node], Optional[Node]]:
253        if node is None:
254            return None, None
255        self._propagate(node)
256        tmp = k if node.left is None else k - node.left.size
257        l, r = None, None
258        if tmp == 0:
259            return node.left, self._merge_with_root(None, node, node.right)
260        elif tmp < 0:
261            l, r = self._split_node(node.left, k)
262            return l, self._merge_with_root(r, node, node.right)
263        else:
264            l, r = self._split_node(node.right, tmp - 1)
265            return self._merge_with_root(node.left, node, l), r
266
267    def split(self, k: int) -> tuple["PersistentLazyAVLTree", "PersistentLazyAVLTree"]:
268        l, r = self._split_node(self.root, k)
269        return self._new(l), self._new(r)
270
271    def _new(
272        self, root: Optional["PersistentLazyAVLTree.Node"]
273    ) -> "PersistentLazyAVLTree":
274        return PersistentLazyAVLTree(
275            [], self.op, self.mapping, self.composition, self.e, self.id, root
276        )
277
278    def apply(self, l: int, r: int, f: F) -> "PersistentLazyAVLTree":
279        if l >= r or (not self.root):
280            return self._new(self.root.copy() if self.root else None)
281        root = self.root.copy()
282        stack: list[
283            Union[
284                PersistentLazyAVLTree.Node, tuple[PersistentLazyAVLTree.Node, int, int]
285            ]
286        ] = [(root), (root, 0, len(self))]
287        while stack:
288            if isinstance(stack[-1], tuple):
289                node, left, right = stack.pop()
290                if right <= l or r <= left:
291                    continue
292                self._propagate(node)
293                if l <= left and right < r:
294                    node.key = self.mapping(f, node.key)
295                    node.data = self.mapping(f, node.data)
296                    node.lazy = (
297                        f if node.lazy == self.id else self.composition(f, node.lazy)
298                    )
299                else:
300                    lsize = node.left.size if node.left else 0
301                    stack.append(node)
302                    if node.left:
303                        left_copy = node.left.copy()
304                        node.left = left_copy
305                        stack.append((left_copy, left, left + lsize))
306                    if l <= left + lsize < r:
307                        node.key = self.mapping(f, node.key)
308                    if node.right:
309                        r_copy = node.right.copy()
310                        node.right = r_copy
311                        stack.append((r_copy, left + lsize + 1, right))
312            else:
313                self._update(stack.pop())
314        return self._new(root)
315
316    def prod(self, l: int, r) -> T:
317        if l >= r or (not self.root):
318            return self.e
319
320        def dfs(node: PersistentLazyAVLTree.Node, left: int, right: int) -> T:
321            if right <= l or r <= left:
322                return self.e
323            self._propagate(node)
324            if l <= left and right < r:
325                return node.data
326            lsize = node.left.size if node.left else 0
327            res = self.e
328            if node.left:
329                res = dfs(node.left, left, left + lsize)
330            if l <= left + lsize < r:
331                res = self.op(res, node.key)
332            if node.right:
333                res = self.op(res, dfs(node.right, left + lsize + 1, right))
334            return res
335
336        return dfs(self.root, 0, len(self))
337
338    def insert(self, k: int, key: T) -> "PersistentLazyAVLTree":
339        s, t = self._split_node(self.root, k)
340        root = self._merge_with_root(s, PersistentLazyAVLTree.Node(key, self.id), t)
341        return self._new(root)
342
343    def pop(self, k: int) -> tuple["PersistentLazyAVLTree", T]:
344        s, t = self._split_node(self.root, k + 1)
345        assert s
346        s, tmp = self._pop_right(s)
347        root = self._merge_node(s, t)
348        return self._new(root), tmp.key
349
350    def reverse(self, l: int, r: int) -> "PersistentLazyAVLTree":
351        if l >= r:
352            return self._new(self.root.copy() if self.root else None)
353        s, t = self._split_node(self.root, r)
354        u, s = self._split_node(s, l)
355        assert s
356        s.rev ^= 1
357        root = self._merge_node(self._merge_node(u, s), t)
358        return self._new(root)
359
360    def tolist(self) -> list[T]:
361        node = self.root
362        stack: list[PersistentLazyAVLTree.Node] = []
363        a: list[T] = []
364        while stack or node:
365            if node:
366                self._propagate(node)
367                stack.append(node)
368                node = node.left
369            else:
370                node = stack.pop()
371                a.append(node.key)
372                node = node.right
373        return a
374
375    def __getitem__(self, k: int) -> T:
376        if k < 0:
377            k += len(self)
378        node = self.root
379        while True:
380            assert node
381            self._propagate(node)
382            t = 0 if node.left is None else node.left.size
383            if t == k:
384                return node.key
385            elif t < k:
386                k -= t + 1
387                node = node.right
388            else:
389                node = node.left
390
391    def __len__(self):
392        return self.root.size if self.root else 0
393
394    def __str__(self):
395        return str(self.tolist())
396
397    def __repr__(self):
398        return f"PersistentLazyAVLTree({self})"

仕様

class PersistentLazyAVLTree(a: Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F, _root: Node | None = None)[source]

Bases: Generic[T, F]

class Node(key: T, lazy: F)[source]

Bases: object

balance() → int[source]
copy() → Node[source]
apply(l: int, r: int, f: F) → PersistentLazyAVLTree[source]
insert(k: int, key: T) → PersistentLazyAVLTree[source]
merge(other: PersistentLazyAVLTree) → PersistentLazyAVLTree[source]
pop(k: int) → tuple[PersistentLazyAVLTree, T][source]
prod(l: int, r) → T[source]
reverse(l: int, r: int) → PersistentLazyAVLTree[source]
split(k: int) → tuple[PersistentLazyAVLTree, PersistentLazyAVLTree][source]
tolist() → list[T][source]