from typing import Callable, Generic, Iterable, Optional, TypeVar, Union

T = TypeVar("T")
F = TypeVar("F")


class PersistentLazyAVLTree(Generic[T, F]):
    """Persistent AVL tree over a sequence with lazy range actions and reversal.

    Every updating method (``apply``, ``insert``, ``pop``, ``reverse``,
    ``merge``, ``split``) returns new tree object(s) and leaves ``self``
    usable: nodes on the modified paths are copied, untouched subtrees are
    shared between versions.

    The operator parameters follow the usual lazy-segment-tree contract
    (assumed by the algorithms, not checked):

    * ``op`` is associative with identity ``e``;
    * ``mapping(f, x)`` applies action ``f`` to an element or aggregate ``x``
      and must distribute over ``op``, because whole subtrees are updated by
      mapping their cached aggregate;
    * ``composition(f, g)`` is the action "apply ``g``, then ``f``"; ``id``
      is the identity action meaning "nothing pending".

    Read-only methods (``prod``, ``tolist``, ``__getitem__``) push pending
    tags down in place. That rewires a node to freshly copied children but
    never changes the sequence the node represents, so versions sharing the
    node are unaffected.

    NOTE(review): ``reverse`` only swaps children and never recomputes an
    aggregate in reverse order, so ``prod`` over a reversed range is only
    correct when ``op`` is commutative.
    """

    class Node:
        """Tree node.

        ``key`` (the node's own element) and ``data`` (the ``op``-aggregate
        of the whole subtree) already include every action applied to this
        node. ``lazy`` (an action) and ``rev`` (1 = children must still be
        swapped) are pending for the children and are pushed down by
        ``PersistentLazyAVLTree._propagate``.
        """

        def __init__(self, key: T, lazy: F):
            self.key: T = key
            self.data: T = key
            self.left: Optional[PersistentLazyAVLTree.Node] = None
            self.right: Optional[PersistentLazyAVLTree.Node] = None
            self.lazy: F = lazy
            self.rev: int = 0
            self.height: int = 1
            self.size: int = 1

        def copy(self) -> "PersistentLazyAVLTree.Node":
            """Return a shallow copy; the children are shared, not copied."""
            node = PersistentLazyAVLTree.Node(self.key, self.lazy)
            node.data = self.data
            node.left = self.left
            node.right = self.right
            node.rev = self.rev
            node.height = self.height
            node.size = self.size
            return node

        def balance(self) -> int:
            """Return ``height(left) - height(right)``; a missing child has height 0."""
            left_h = 0 if self.left is None else self.left.height
            right_h = 0 if self.right is None else self.right.height
            return left_h - right_h

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key={self.key}, height={self.height}, size={self.size}, data={self.data}, lazy={self.lazy}, rev={self.rev}\n"
            return f"key={self.key}, height={self.height}, size={self.size}, data={self.data}, lazy={self.lazy}, rev={self.rev},\n left:{self.left},\n right:{self.right}\n"

        __repr__ = __str__

    def __init__(
        self,
        a: Iterable[T],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
        _root: Optional[Node] = None,
    ) -> None:
        """Create a tree holding the elements of ``a`` in order.

        Args:
            a: initial elements.
            op: associative aggregate operation.
            mapping: ``mapping(f, x)`` applies action ``f`` to ``x``.
            composition: ``composition(f, g)`` is the action "``g`` then ``f``".
            e: identity of ``op``; returned by ``prod`` for empty ranges.
            id: identity action.
            _root: internal; an existing root to wrap (used by ``_new``).
                It is replaced when ``a`` is non-empty.
        """
        self.root: Optional[PersistentLazyAVLTree.Node] = _root
        self.op: Callable[[T, T], T] = op
        self.mapping: Callable[[F, T], T] = mapping
        self.composition: Callable[[F, F], F] = composition
        self.e: T = e
        self.id: F = id
        items = list(a)
        if items:
            self._build(items)

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced tree from the non-empty list ``a`` as the root."""
        Node = PersistentLazyAVLTree.Node
        lazy_id = self.id

        def build(l: int, r: int) -> PersistentLazyAVLTree.Node:
            # Subtree for a[l:r]; the middle element becomes its root.
            mid = (l + r) >> 1
            node = Node(a[mid], lazy_id)
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.root = build(0, len(a))

    def _propagate(self, node: Node) -> None:
        """Push ``node``'s pending reversal and action down to its children.

        Children are replaced by fresh copies carrying the pushed tags, so
        nodes shared with other versions are never mutated. ``node`` itself
        is mutated, but the sequence it represents stays the same.
        """
        if node.rev:
            node.rev = 0
            l = node.left.copy() if node.left else None
            r = node.right.copy() if node.right else None
            node.left, node.right = r, l
            if l:
                l.rev ^= 1
            if r:
                r.rev ^= 1
        if node.lazy != self.id:
            lazy = node.lazy
            node.lazy = self.id
            if node.left:
                l = node.left.copy()
                l.data = self.mapping(lazy, l.data)
                l.key = self.mapping(lazy, l.key)
                l.lazy = self.composition(lazy, l.lazy)
                node.left = l
            if node.right:
                r = node.right.copy()
                r.data = self.mapping(lazy, r.data)
                r.key = self.mapping(lazy, r.key)
                r.lazy = self.composition(lazy, r.lazy)
                node.right = r

    def _update(self, node: Node) -> None:
        """Recompute ``size``, ``data`` and ``height`` of ``node`` from its children."""
        node.size = 1
        node.data = node.key
        node.height = 1
        if node.left:
            node.size += node.left.size
            node.data = self.op(node.left.data, node.data)
            node.height = node.left.height + 1
        if node.right:
            node.size += node.right.size
            node.data = self.op(node.data, node.right.data)
            node.height = max(node.height, node.right.height + 1)

    def _rotate_right(self, node: Node) -> Node:
        """Rotate right around ``node`` (which must be a private copy) and return the new root."""
        assert node.left
        u = node.left.copy()
        node.left = u.right
        u.right = node
        self._update(node)
        self._update(u)
        return u

    def _rotate_left(self, node: Node) -> Node:
        """Rotate left around ``node`` (which must be a private copy) and return the new root."""
        assert node.right
        u = node.right.copy()
        node.right = u.left
        u.left = node
        self._update(node)
        self._update(u)
        return u

    def _balance_left(self, node: Node) -> Node:
        """Rebalance a right-heavy ``node`` (balance -2) with a single or double rotation.

        Callers pass a node that is already a propagated private copy.
        """
        assert node.right
        self._propagate(node.right)
        node.right = node.right.copy()
        u = node.right
        if u.balance() == 1:
            # Right child leans left: right-left double rotation.
            assert u.left
            self._propagate(u.left)
            node.right = self._rotate_right(u)
        u = self._rotate_left(node)
        return u

    def _balance_right(self, node: Node) -> Node:
        """Rebalance a left-heavy ``node`` (balance +2) with a single or double rotation.

        Callers pass a node that is already a propagated private copy.
        """
        assert node.left
        self._propagate(node.left)
        node.left = node.left.copy()
        u = node.left
        if u.balance() == -1:
            # Left child leans right: left-right double rotation.
            assert u.right
            self._propagate(u.right)
            node.left = self._rotate_left(u)
        u = self._rotate_right(node)
        return u

    def _merge_with_root(
        self, l: Optional[Node], root: Node, r: Optional[Node]
    ) -> Node:
        """Join ``l``, then ``root``, then ``r`` into one balanced subtree.

        ``l`` and ``r`` may differ in height arbitrarily; the join descends
        the taller side. ``root`` is copied and its children replaced, so it
        must carry no pending ``lazy``/``rev`` (every caller propagates it or
        creates it fresh).
        """
        diff = 0
        if l:
            diff += l.height
        if r:
            diff -= r.height
        if diff > 1:
            assert l
            self._propagate(l)
            l = l.copy()
            l.right = self._merge_with_root(l.right, root, r)
            self._update(l)
            if l.balance() == -2:
                return self._balance_left(l)
            return l
        if diff < -1:
            assert r
            self._propagate(r)
            r = r.copy()
            r.left = self._merge_with_root(l, root, r.left)
            self._update(r)
            if r.balance() == 2:
                return self._balance_right(r)
            return r
        root = root.copy()
        root.left = l
        root.right = r
        self._update(root)
        return root

    def _merge_node(self, l: Optional[Node], r: Optional[Node]) -> Optional[Node]:
        """Concatenate two subtrees (either may be ``None``)."""
        if l is None and r is None:
            return None
        if l is None:
            return r.copy()
        if r is None:
            return l.copy()
        l = l.copy()
        r = r.copy()
        # Detach the last element of l and use it as the join root.
        l, root = self._pop_right(l)
        return self._merge_with_root(l, root, r)

    def merge(self, other: "PersistentLazyAVLTree") -> "PersistentLazyAVLTree":
        """Return a new tree holding ``self``'s elements followed by ``other``'s.

        The result uses ``self``'s operators; ``other`` is presumably built
        with the same ones.
        """
        root = self._merge_node(self.root, other.root)
        return self._new(root)

    def _pop_right(self, node: Node) -> tuple[Optional[Node], Node]:
        """Detach the rightmost node of the subtree ``node``.

        Returns ``(rest, mx)``: ``rest`` is the rebalanced remaining subtree
        (``None`` if it became empty) and ``mx`` is the removed node as a
        single-node tree with no pending tags.
        """
        path: list = []
        self._propagate(node)
        node = node.copy()
        mx = node
        while node.right:
            path.append(node)
            self._propagate(node.right)
            node = node.right.copy()
            mx = node
        # mx's left subtree takes mx's place under its parent.
        path.append(node.left.copy() if node.left else None)
        # Rebuild the right spine bottom-up, rebalancing as heights shrink.
        for _ in range(len(path) - 1):
            node = path.pop()
            if node is None:
                path[-1].right = None
                self._update(path[-1])
                continue
            b = node.balance()
            if b == 2:
                path[-1].right = self._balance_right(node)
            elif b == -2:
                path[-1].right = self._balance_left(node)
            else:
                path[-1].right = node
            self._update(path[-1])
        if path[0] is not None:
            b = path[0].balance()
            if b == 2:
                path[0] = self._balance_right(path[0])
            elif b == -2:
                path[0] = self._balance_left(path[0])
        mx.left = None
        self._update(mx)
        return path[0], mx

    def _split_node(
        self, node: Optional[Node], k: int
    ) -> tuple[Optional[Node], Optional[Node]]:
        """Split a subtree into (its first ``k`` elements, the rest)."""
        if node is None:
            return None, None
        self._propagate(node)
        tmp = k if node.left is None else k - node.left.size
        if tmp == 0:
            return node.left, self._merge_with_root(None, node, node.right)
        elif tmp < 0:
            l, r = self._split_node(node.left, k)
            return l, self._merge_with_root(r, node, node.right)
        else:
            l, r = self._split_node(node.right, tmp - 1)
            return self._merge_with_root(node.left, node, l), r

    def split(self, k: int) -> tuple["PersistentLazyAVLTree", "PersistentLazyAVLTree"]:
        """Return two new trees: the first ``k`` elements and the remaining ones."""
        l, r = self._split_node(self.root, k)
        return self._new(l), self._new(r)

    def _new(
        self, root: Optional["PersistentLazyAVLTree.Node"]
    ) -> "PersistentLazyAVLTree":
        """Wrap ``root`` in a new tree object sharing this tree's operators."""
        return PersistentLazyAVLTree(
            [], self.op, self.mapping, self.composition, self.e, self.id, root
        )

    def apply(self, l: int, r: int, f: F) -> "PersistentLazyAVLTree":
        """Return a new tree with ``mapping(f, x)`` applied to every element at
        index ``l <= i < r``; ``self`` is left unchanged.

        Subtrees lying entirely inside ``[l, r)`` are tagged lazily, so only
        the nodes along the range boundaries are visited and copied.
        """
        if l >= r or (not self.root):
            return self._new(self.root.copy() if self.root else None)
        root = self.root.copy()
        # Entries are either (node, left, right) -- "visit node, which covers
        # indices [left, right)" -- or a bare node -- "children are done,
        # recompute node". Only partially covered nodes get a bare entry: a
        # fully covered node already has its data mapped directly while its
        # children are still untouched, so _update on it would clobber data.
        stack: list[
            Union[
                PersistentLazyAVLTree.Node, tuple[PersistentLazyAVLTree.Node, int, int]
            ]
        ] = [(root, 0, len(self))]
        while stack:
            item = stack.pop()
            if not isinstance(item, tuple):
                self._update(item)
                continue
            node, left, right = item
            if right <= l or r <= left:
                continue
            self._propagate(node)
            if l <= left and right <= r:
                # Whole subtree inside the range: tag it and stop descending.
                node.key = self.mapping(f, node.key)
                node.data = self.mapping(f, node.data)
                node.lazy = (
                    f if node.lazy == self.id else self.composition(f, node.lazy)
                )
                continue
            lsize = node.left.size if node.left else 0
            stack.append(node)
            if node.left:
                left_copy = node.left.copy()
                node.left = left_copy
                stack.append((left_copy, left, left + lsize))
            if l <= left + lsize < r:
                node.key = self.mapping(f, node.key)
            if node.right:
                right_copy = node.right.copy()
                node.right = right_copy
                stack.append((right_copy, left + lsize + 1, right))
        return self._new(root)

    def prod(self, l: int, r: int) -> T:
        """Return the ``op``-aggregate of elements at indices ``[l, r)``, or ``e`` if empty."""
        if l >= r or (not self.root):
            return self.e

        def dfs(node: PersistentLazyAVLTree.Node, left: int, right: int) -> T:
            # node covers indices [left, right).
            if right <= l or r <= left:
                return self.e
            self._propagate(node)
            if l <= left and right <= r:
                return node.data
            lsize = node.left.size if node.left else 0
            res = self.e
            if node.left:
                res = dfs(node.left, left, left + lsize)
            if l <= left + lsize < r:
                res = self.op(res, node.key)
            if node.right:
                res = self.op(res, dfs(node.right, left + lsize + 1, right))
            return res

        return dfs(self.root, 0, len(self))

    def insert(self, k: int, key: T) -> "PersistentLazyAVLTree":
        """Return a new tree with ``key`` inserted before index ``k``."""
        s, t = self._split_node(self.root, k)
        root = self._merge_with_root(s, PersistentLazyAVLTree.Node(key, self.id), t)
        return self._new(root)

    def pop(self, k: int) -> tuple["PersistentLazyAVLTree", T]:
        """Return ``(new tree without element k, element k)``.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("PersistentLazyAVLTree index out of range")
        s, t = self._split_node(self.root, k + 1)
        assert s
        s, tmp = self._pop_right(s)
        root = self._merge_node(s, t)
        return self._new(root), tmp.key

    def reverse(self, l: int, r: int) -> "PersistentLazyAVLTree":
        """Return a new tree with the elements at indices ``[l, r)`` reversed.

        Aggregates are not reordered (see the class note); this is only
        sound for ``prod`` when ``op`` is commutative.
        """
        if l >= r:
            return self._new(self.root.copy() if self.root else None)
        s, t = self._split_node(self.root, r)
        u, s = self._split_node(s, l)
        assert s
        # s comes out of _merge_with_root, i.e. a fresh copy, so flipping its
        # flag does not affect other versions.
        s.rev ^= 1
        root = self._merge_node(self._merge_node(u, s), t)
        return self._new(root)

    def tolist(self) -> list[T]:
        """Return all elements in order (iterative in-order traversal)."""
        node = self.root
        stack: list[PersistentLazyAVLTree.Node] = []
        a: list[T] = []
        while stack or node:
            if node:
                self._propagate(node)
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node.key)
                node = node.right
        return a

    def __getitem__(self, k: int) -> T:
        """Return element ``k``; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range. This also lets
            ``for x in tree`` / ``list(tree)`` stop at the end.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("PersistentLazyAVLTree index out of range")
        node = self.root
        while True:
            assert node is not None
            self._propagate(node)
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def __len__(self):
        return self.root.size if self.root else 0

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"PersistentLazyAVLTree({self})"