lazy_avl_tree

ソースコード

from titan_pylib.data_structures.avl_tree.lazy_avl_tree import LazyAVLTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.lazy_avl_tree import LazyAVLTree
  2from typing import Generic, Iterable, TypeVar, Callable, Optional
  3
  4T = TypeVar("T")
  5F = TypeVar("F")
  6
  7
  8class LazyAVLTree(Generic[T, F]):
  9    """遅延伝播反転可能平衡二分木です。"""
 10
 11    class Node:
 12
 13        def __init__(self, key: T, id: F):
 14            self.key: T = key
 15            self.data: T = key
 16            self.left: Optional[LazyAVLTree.Node] = None
 17            self.right: Optional[LazyAVLTree.Node] = None
 18            self.lazy: F = id
 19            self.rev: int = 0
 20            self.height: int = 1
 21            self.size: int = 1
 22
 23        def __str__(self):
 24            if self.left is None and self.right is None:
 25                return f"key:{self.key, self.height, self.size, self.data, self.lazy, self.rev}\n"
 26            return f"key:{self.key, self.height, self.size, self.data, self.lazy, self.rev},\n left:{self.left},\n right:{self.right}\n"
 27
 28    def __init__(
 29        self,
 30        a: Iterable[T],
 31        op: Callable[[T, T], T],
 32        mapping: Callable[[F, T], T],
 33        composition: Callable[[F, F], F],
 34        e: T,
 35        id: F,
 36        node: Node = None,
 37    ) -> None:
 38        self.root: Optional[LazyAVLTree.Node] = node
 39        self.op: Callable[[T, T], T] = op
 40        self.mapping: Callable[[F, T], T] = mapping
 41        self.composition: Callable[[F, F], F] = composition
 42        self.e: T = e
 43        self.id: F = id
 44        a = list(a)
 45        if a:
 46            self._build(a)
 47
 48    def _build(self, a: list[T]) -> None:
 49        Node = LazyAVLTree.Node
 50        id = self.id
 51
 52        def sort(l: int, r: int) -> LazyAVLTree.Node:
 53            mid = (l + r) >> 1
 54            node = Node(a[mid], id)
 55            if l != mid:
 56                node.left = sort(l, mid)
 57            if mid + 1 != r:
 58                node.right = sort(mid + 1, r)
 59            self._update(node)
 60            return node
 61
 62        self.root = sort(0, len(a))
 63
 64    def _propagate(self, node: Node) -> None:
 65        l, r = node.left, node.right
 66        if node.rev:
 67            node.left, node.right = r, l
 68            if l:
 69                l.rev ^= 1
 70            if r:
 71                r.rev ^= 1
 72            node.rev = 0
 73        if node.lazy != self.id:
 74            lazy = node.lazy
 75            if l:
 76                l.data = self.mapping(lazy, l.data)
 77                l.key = self.mapping(lazy, l.key)
 78                l.lazy = lazy if l.lazy == self.id else self.composition(lazy, l.lazy)
 79            if r:
 80                r.data = self.mapping(lazy, r.data)
 81                r.key = self.mapping(lazy, r.key)
 82                r.lazy = lazy if r.lazy == self.id else self.composition(lazy, r.lazy)
 83            node.lazy = self.id
 84
 85    def _update(self, node: Node) -> None:
 86        node.size = 1
 87        node.data = node.key
 88        node.height = 1
 89        if node.left:
 90            node.size += node.left.size
 91            node.data = self.op(node.left.data, node.data)
 92            node.height = max(node.left.height + 1, 1)
 93        if node.right:
 94            node.size += node.right.size
 95            node.data = self.op(node.data, node.right.data)
 96            node.height = max(node.height, node.right.height + 1)
 97
 98    def _get_balance(self, node: Node) -> int:
 99        return (
100            (0 if node.right is None else -node.right.height)
101            if node.left is None
102            else (
103                node.left.height
104                if node.right is None
105                else node.left.height - node.right.height
106            )
107        )
108
109    def _balance_left(self, node: Node) -> Node:
110        self._propagate(node.left)
111        if node.left.left is None or node.left.left.height + 2 == node.left.height:
112            u = node.left.right
113            self._propagate(u)
114            node.left.right = u.left
115            u.left = node.left
116            node.left = u.right
117            u.right = node
118            self._update(u.left)
119        else:
120            u = node.left
121            node.left = u.right
122            u.right = node
123        self._update(u.right)
124        self._update(u)
125        return u
126
127    def _balance_right(self, node: Node) -> Node:
128        self._propagate(node.right)
129        if node.right.right is None or node.right.right.height + 2 == node.right.height:
130            u = node.right.left
131            self._propagate(u)
132            node.right.left = u.right
133            u.right = node.right
134            node.right = u.left
135            u.left = node
136            self._update(u.right)
137        else:
138            u = node.right
139            node.right = u.left
140            u.left = node
141        self._update(u.left)
142        self._update(u)
143        return u
144
145    def _kth_elm(self, k: int) -> T:
146        if k < 0:
147            k += len(self)
148        node = self.root
149        while True:
150            self._propagate(node)
151            t = 0 if node.left is None else node.left.size
152            if t == k:
153                return node.key
154            elif t < k:
155                k -= t + 1
156                node = node.right
157            else:
158                node = node.left
159
160    def _merge_with_root(self, l: Node, root: Node, r: Node) -> Node:
161        diff = (
162            (0 if r is None else -r.height)
163            if l is None
164            else (l.height if r is None else l.height - r.height)
165        )
166        if diff > 1:
167            self._propagate(l)
168            l.right = self._merge_with_root(l.right, root, r)
169            self._update(l)
170            if (
171                -l.right.height
172                if l.left is None
173                else l.left.height - l.right.height == -2
174            ):
175                return self._balance_right(l)
176            return l
177        elif diff < -1:
178            self._propagate(r)
179            r.left = self._merge_with_root(l, root, r.left)
180            self._update(r)
181            if (
182                r.left.height
183                if r.right is None
184                else r.left.height - r.right.height == 2
185            ):
186                return self._balance_left(r)
187            return r
188        else:
189            root.left = l
190            root.right = r
191            self._update(root)
192            return root
193
194    def _merge_node(self, l: Node, r: Node) -> Node:
195        if l is None:
196            return r
197        if r is None:
198            return l
199        l, tmp = self._pop_max(l)
200        return self._merge_with_root(l, tmp, r)
201
202    def merge(self, other: "LazyAVLTree") -> None:
203        self.root = self._merge_node(self.root, other.root)
204
205    def _pop_max(self, node: Node) -> tuple[Node, Node]:
206        self._propagate(node)
207        path = []
208        mx = node
209        while node.right:
210            path.append(node)
211            mx = node.right
212            node = node.right
213            self._propagate(node)
214        path.append(node.left)
215        for _ in range(len(path) - 1):
216            node = path.pop()
217            if node is None:
218                path[-1].right = None
219                self._update(path[-1])
220                continue
221            b = self._get_balance(node)
222            path[-1].right = (
223                self._balance_left(node)
224                if b == 2
225                else self._balance_right(node) if b == -2 else node
226            )
227            self._update(path[-1])
228        if path[0]:
229            b = self._get_balance(path[0])
230            path[0] = (
231                self._balance_left(path[0])
232                if b == 2
233                else self._balance_right(path[0]) if b == -2 else path[0]
234            )
235        mx.left = None
236        self._update(mx)
237        return path[0], mx
238
239    def _split_node(self, node: Node, k: int) -> tuple[Node, Node]:
240        if not node:
241            return None, None
242        self._propagate(node)
243        tmp = k if node.left is None else k - node.left.size
244        if tmp == 0:
245            return node.left, self._merge_with_root(None, node, node.right)
246        elif tmp < 0:
247            s, t = self._split_node(node.left, k)
248            return s, self._merge_with_root(t, node, node.right)
249        else:
250            s, t = self._split_node(node.right, tmp - 1)
251            return self._merge_with_root(node.left, node, s), t
252
253    def split(self, k: int) -> tuple["LazyAVLTree", "LazyAVLTree"]:
254        l, r = self._split_node(self.root, k)
255        return LazyAVLTree(
256            [], self.op, self.mapping, self.composition, self.e, self.id, l
257        ), LazyAVLTree([], self.op, self.mapping, self.composition, self.e, self.id, r)
258
259    def insert(self, k: int, key: T) -> None:
260        s, t = self._split_node(self.root, k)
261        self.root = self._merge_with_root(s, LazyAVLTree.Node(key, self.id), t)
262
263    def pop(self, k: int) -> T:
264        s, t = self._split_node(self.root, k + 1)
265        s, tmp = self._pop_max(s)
266        self.root = self._merge_node(s, t)
267        return tmp.key
268
269    def apply(self, l: int, r: int, f: F) -> None:
270        if l >= r or (not self.root):
271            return
272        stack = [(self.root), (self.root, 0, len(self))]
273        while stack:
274            if isinstance(stack[-1], tuple):
275                node, left, right = stack.pop()
276                if right <= l or r <= left:
277                    continue
278                self._propagate(node)
279                if l <= left and right < r:
280                    node.key = self.mapping(f, node.key)
281                    node.data = self.mapping(f, node.data)
282                    node.lazy = (
283                        f if node.lazy == self.id else self.composition(f, node.lazy)
284                    )
285                else:
286                    lsize = node.left.size if node.left else 0
287                    stack.append(node)
288                    if node.left:
289                        stack.append((node.left, left, left + lsize))
290                    if l <= left + lsize < r:
291                        node.key = self.mapping(f, node.key)
292                    if node.right:
293                        stack.append((node.right, left + lsize + 1, right))
294            else:
295                self._update(stack.pop())
296
297    def all_apply(self, f: F) -> None:
298        if not self.root:
299            return
300        self.root.key = self.mapping(f, self.root.key)
301        self.root.data = self.mapping(f, self.root.data)
302        self.root.lazy = (
303            f if self.root.lazy == self.id else self.composition(f, self.root.lazy)
304        )
305
306    def reverse(self, l: int, r: int) -> None:
307        if l >= r:
308            return
309        s, t = self._split_node(self.root, r)
310        r, s = self._split_node(s, l)
311        s.rev ^= 1
312        self.root = self._merge_node(self._merge_node(r, s), t)
313
314    def all_reverse(self) -> None:
315        if self.root is None:
316            return
317        self.root.rev ^= 1
318
319    def prod(self, l: int, r: int) -> T:
320        if l >= r or (not self.root):
321            return self.e
322
323        def dfs(node: LazyAVLTree.Node, left: int, right: int) -> T:
324            if right <= l or r <= left:
325                return self.e
326            self._propagate(node)
327            if l <= left and right < r:
328                return node.data
329            lsize = node.left.size if node.left else 0
330            res = self.e
331            if node.left:
332                res = dfs(node.left, left, left + lsize)
333            if l <= left + lsize < r:
334                res = self.op(res, node.key)
335            if node.right:
336                res = self.op(res, dfs(node.right, left + lsize + 1, right))
337            return res
338
339        return dfs(self.root, 0, len(self))
340
341    def all_prod(self) -> T:
342        return self.root.data if self.root else self.e
343
344    def clear(self) -> None:
345        self.root = None
346
347    def tolist(self) -> list[T]:
348        node = self.root
349        stack = []
350        a = []
351        while stack or node:
352            if node:
353                self._propagate(node)
354                stack.append(node)
355                node = node.left
356            else:
357                node = stack.pop()
358                a.append(node.key)
359                node = node.right
360        return a
361
362    def __len__(self):
363        return 0 if self.root is None else self.root.size
364
365    def __iter__(self):
366        self.__iter = 0
367        return self
368
369    def __next__(self):
370        if self.__iter == len(self):
371            raise StopIteration
372        res = self[self.__iter]
373        self.__iter += 1
374        return res
375
376    def __reversed__(self):
377        for i in range(len(self)):
378            yield self[-i - 1]
379
380    def __bool__(self):
381        return self.root is not None
382
383    def __getitem__(self, k: int) -> T:
384        return self._kth_elm(k)
385
386    def __setitem__(self, k, key: T):
387        if k < 0:
388            k += len(self)
389        node = self.root
390        path = []
391        while True:
392            self._propagate(node)
393            path.append(node)
394            t = 0 if node.left is None else node.left.size
395            if t == k:
396                node.key = key
397                break
398            if t < k:
399                k -= t + 1
400                node = node.right
401            else:
402                node = node.left
403        while path:
404            self._update(path.pop())
405
406    def __str__(self):
407        return "[" + ", ".join(map(str, self.tolist())) + "]"
408
409    def __repr__(self):
410        return f"LazyAVLTree({self})"

仕様

class LazyAVLTree(a: Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F, node: Node = None)[source]

Bases: Generic[T, F]

遅延伝播反転可能平衡二分木です。

class Node(key: T, id: F)[source]

Bases: object

all_apply(f: F) → None[source]
all_prod() → T[source]
all_reverse() → None[source]
apply(l: int, r: int, f: F) → None[source]
clear() → None[source]
insert(k: int, key: T) → None[source]
merge(other: LazyAVLTree) → None[source]
pop(k: int) → T[source]
prod(l: int, r: int) → T[source]
reverse(l: int, r: int) → None[source]
split(k: int) → tuple[LazyAVLTree, LazyAVLTree][source]
tolist() → list[T][source]