1from typing import Generic, Union, TypeVar, Callable, Iterable, Optional
2
3T = TypeVar("T")
4F = TypeVar("F")
5
6
[docs]
7class LazySplayTree(Generic[T, F]):
8
9 class _Node:
10
11 def __init__(self, key: T, lazy: F) -> None:
12 self.key: T = key
13 self.data: T = key
14 self.rdata: T = key
15 self.lazy: F = lazy
16 self.left: Optional["LazySplayTree._Node"] = None
17 self.right: Optional["LazySplayTree._Node"] = None
18 self.par: Optional["LazySplayTree._Node"] = None
19 self.size: int = 1
20 self.rev: int = 0
21
22 def __init__(
23 self,
24 n_or_a: Union[int, Iterable[T]],
25 op: Callable[[T, T], T],
26 mapping: Callable[[F, T], T],
27 composition: Callable[[F, F], F],
28 e: T,
29 id: F,
30 _root: Optional[_Node] = None,
31 ) -> None:
32 """構築します。
33 :math:`O(n)` です。
34
35 Args:
36 n_or_a (Union[int, Iterable[T]]): ``n`` のとき、 ``e`` から長さ ``n`` で構築します。
37 ``a`` のとき、 ``a`` から構築します。
38 op (Callable[[T, T], T]): 遅延セグ木のあれです。
39 mapping (Callable[[F, T], T]): 遅延セグ木のあれです。
40 composition (Callable[[F, F], F]): 遅延セグ木のあれです。
41 e (T): 遅延セグ木のあれです。
42 id (F): 遅延セグ木のあれです。
43 """
44 self.op = op
45 self.mapping = mapping
46 self.composition = composition
47 self.e = e
48 self.id = id
49 self.root = _root
50 if _root:
51 return
52 a = n_or_a
53 if isinstance(a, int):
54 a = [e for _ in range(a)]
55 elif not isinstance(a, list):
56 a = list(a)
57 if a:
58 self._build(a)
59
60 def _build(self, a: list[T]) -> None:
61 _Node = LazySplayTree._Node
62 id = self.id
63
64 def build(l: int, r: int) -> LazySplayTree._Node:
65 mid = (l + r) >> 1
66 node = _Node(a[mid], id)
67 if l != mid:
68 node.left = build(l, mid)
69 node.left.par = node
70 if mid + 1 != r:
71 node.right = build(mid + 1, r)
72 node.right.par = node
73 self._update(node)
74 return node
75
76 self.root = build(0, len(a))
77
78 def _rotate(self, node: _Node) -> None:
79 pnode = node.par
80 gnode = pnode.par
81 if gnode:
82 if gnode.left is pnode:
83 gnode.left = node
84 else:
85 gnode.right = node
86 node.par = gnode
87 if pnode.left is node:
88 pnode.left = node.right
89 if node.right:
90 node.right.par = pnode
91 node.right = pnode
92 else:
93 pnode.right = node.left
94 if node.left:
95 node.left.par = pnode
96 node.left = pnode
97 pnode.par = node
98 self._update_double(pnode, node)
99
100 def _propagate_rev(self, node: Optional[_Node]) -> None:
101 if not node:
102 return
103 node.rev ^= 1
104
105 def _propagate_lazy(self, node: Optional[_Node], f: F) -> None:
106 if not node:
107 return
108 node.key = self.mapping(f, node.key)
109 node.data = self.mapping(f, node.data)
110 node.rdata = self.mapping(f, node.rdata)
111 node.lazy = f if node.lazy == self.id else self.composition(f, node.lazy)
112
113 def _propagate(self, node: Optional[_Node]) -> None:
114 if not node:
115 return
116 if node.rev:
117 node.data, node.rdata = node.rdata, node.data
118 node.left, node.right = node.right, node.left
119 self._propagate_rev(node.left)
120 self._propagate_rev(node.right)
121 node.rev = 0
122 if node.lazy != self.id:
123 self._propagate_lazy(node.left, node.lazy)
124 self._propagate_lazy(node.right, node.lazy)
125 node.lazy = self.id
126
127 def _update_double(self, pnode: _Node, node: _Node) -> None:
128 node.data = pnode.data
129 node.rdata = pnode.rdata
130 node.size = pnode.size
131 self._update(pnode)
132
133 def _update(self, node: _Node) -> None:
134 node.data = node.key
135 node.rdata = node.key
136 node.size = 1
137 if node.left:
138 node.data = self.op(node.left.data, node.data)
139 node.rdata = self.op(node.rdata, node.left.rdata)
140 node.size += node.left.size
141 if node.right:
142 node.data = self.op(node.data, node.right.data)
143 node.rdata = self.op(node.right.rdata, node.rdata)
144 node.size += node.right.size
145
146 def _splay(self, node: _Node) -> None:
147 # while node.par and node.par.par:
148 # pnode = node.par
149 # self._rotate(pnode if (pnode.par.left is pnode) == (pnode.left is node) else node)
150 # self._rotate(node)
151 # if node.par:
152 # self._rotate(node)
153 while node.par:
154 pnode = node.par
155 if pnode:
156 self._rotate(
157 pnode if (pnode.par.left is pnode) == (pnode.left is node) else node
158 )
159 self._rotate(node)
160
[docs]
161 def kth_splay(self, node: Optional[_Node], k: int) -> None:
162 if k < 0:
163 k += len(self)
164 while True:
165 self._propagate(node)
166 t = node.left.size if node.left else 0
167 if t == k:
168 break
169 if t > k:
170 node = node.left
171 else:
172 node = node.right
173 k -= t + 1
174 self._splay(node)
175 return node
176
177 def _left_splay(self, node: Optional[_Node]) -> Optional[_Node]:
178 self._propagate(node)
179 if not node or not node.left:
180 return node
181 while node.left:
182 node = node.left
183 self._propagate(node)
184 self._splay(node)
185 return node
186
187 def _right_splay(self, node: Optional[_Node]) -> Optional[_Node]:
188 self._propagate(node)
189 if not node or not node.right:
190 return node
191 while node.right:
192 node = node.right
193 self._propagate(node)
194 self._splay(node)
195 return node
196
[docs]
197 def merge(self, other: "LazySplayTree") -> None:
198 """``other`` を後ろに連結します。
199 償却 :math:`O(\\log{n})` です。
200
201 Args:
202 other (LazySplayTree):
203 """
204 if not self.root:
205 self.root = other.root
206 return
207 if not other.root:
208 return
209 self.root = self._right_splay(self.root)
210 self.root.right = other.root
211 other.root.par = self.root
212 self._update(self.root)
213
[docs]
214 def split(self, k: int) -> tuple["LazySplayTree", "LazySplayTree"]:
215 """位置 ``k`` で split します。
216 償却 :math:`O(\\log{n})` です。
217
218 Returns:
219 tuple['LazySplayTree', 'LazySplayTree']:
220 """
221 left, right = self._internal_split(self.root, k)
222 left_splay = LazySplayTree(
223 0, self.op, self.mapping, self.composition, self.e, self.id, left
224 )
225 right_splay = LazySplayTree(
226 0, self.op, self.mapping, self.composition, self.e, self.id, right
227 )
228 return left_splay, right_splay
229
230 def _internal_split(self, k: int) -> tuple[_Node, _Node]:
231 if k == len(self):
232 return self.root, None
233 right = self.kth_splay(self.root, k)
234 left = right.left
235 if left:
236 left.par = None
237 right.left = None
238 self._update(right)
239 return left, right
240
241 def _internal_merge(
242 self, left: Optional[_Node], right: Optional[_Node]
243 ) -> Optional[_Node]:
244 # need (not right) or (not right.left)
245 if not right:
246 return left
247 assert right.left is None
248 right.left = left
249 if left:
250 left.par = right
251 self._update(right)
252 return right
253
[docs]
254 def reverse(self, l: int, r: int) -> None:
255 """区間 ``[l, r)`` を反転します。
256 償却 :math:`O(\\log{n})` です。
257
258 Args:
259 l (int):
260 r (int):
261 """
262 assert (
263 0 <= l <= r <= len(self)
264 ), f"IndexError: {self.__class__.__name__}.reverse({l}, {r}), len={len(self)}"
265 left, right = self._internal_split(r)
266 if l == 0:
267 self._propagate_rev(left)
268 else:
269 left = self.kth_splay(left, l - 1)
270 self._propagate_rev(left.right)
271 self.root = self._internal_merge(left, right)
272
[docs]
273 def all_reverse(self) -> None:
274 """区間 ``[0, n)`` を反転します。
275 :math:`O(1)` です。
276 """
277 self._propagate_rev(self.root)
278
[docs]
279 def apply(self, l: int, r: int, f: F) -> None:
280 """区間 ``[l, r)`` に ``f`` を作用します。
281 償却 :math:`O(\\log{n})` です。
282
283 Args:
284 l (int):
285 r (int):
286 f (F): 作用素です。
287 """
288 assert (
289 0 <= l <= r <= len(self)
290 ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), len={len(self)}"
291 left, right = self._internal_split(r)
292 if l == 0:
293 self._propagate_lazy(left, f)
294 else:
295 left = self.kth_splay(left, l - 1)
296 self._propagate_lazy(left.right, f)
297 self._update(left)
298 self.root = self._internal_merge(left, right)
299
[docs]
300 def all_apply(self, f: F) -> None:
301 """区間 ``[0, n)`` に ``f`` を作用します。
302 :math:`O(1)` です。
303 """
304 self._propagate_lazy(self.root, f)
305
[docs]
306 def prod(self, l: int, r: int) -> T:
307 """区間 ``[l, r)`` の総積を求めます。
308 償却 :math:`O(\\log{n})` です。
309 """
310 assert (
311 0 <= l <= r <= len(self)
312 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), len={len(self)}"
313 if l == r:
314 return self.e
315 left, right = self._internal_split(r)
316 if l == 0:
317 res = left.data
318 else:
319 left = self.kth_splay(left, l - 1)
320 res = left.right.data
321 self.root = self._internal_merge(left, right)
322 return res
323
[docs]
324 def all_prod(self) -> T:
325 """区間 ``[0, n)`` の総積を求めます。
326 :math:`O(1)` です。
327 """
328 self._propagate(self.root)
329 return self.root.data if self.root else self.e
330
[docs]
331 def insert(self, k: int, key: T) -> None:
332 """位置 ``k`` に ``key`` を挿入します。
333 償却 :math:`O(\\log{n})` です。
334
335 Args:
336 k (int):
337 key (T):
338 """
339 assert 0 <= k <= len(self)
340 node = self._Node(key, self.id)
341 if not self.root:
342 self.root = node
343 return
344 if k >= len(self):
345 root = self.kth_splay(self.root, len(self) - 1)
346 node.left = root
347 else:
348 root = self.kth_splay(self.root, k)
349 if root.left:
350 node.left = root.left
351 root.left.par = node
352 root.left = None
353 self._update(root)
354 node.right = root
355 root.par = node
356 self.root = node
357 self._update(self.root)
358
[docs]
359 def append(self, key: T) -> None:
360 """末尾に ``key`` を追加します。
361 償却 :math:`O(\\log{n})` です。
362
363 Args:
364 key (T):
365 """
366 node = self._right_splay(self.root)
367 self.root = self._Node(key, self.id)
368 self.root.left = node
369 if node:
370 node.par = self.root
371 self._update(self.root)
372
[docs]
373 def appendleft(self, key: T) -> None:
374 """先頭に ``key`` を追加します。
375 償却 :math:`O(\\log{n})` です。
376
377 Args:
378 key (T):
379 """
380 node = self._left_splay(self.root)
381 self.root = self._Node(key, self.id)
382 self.root.right = node
383 if node:
384 node.par = self.root
385 self._update(self.root)
386
[docs]
387 def pop(self, k: int = -1) -> T:
388 """位置 ``k`` の要素を削除し、その値を返します。
389 償却 :math:`O(\\log{n})` です。
390
391 Args:
392 k (int, optional): 指定するインデックスです。 Defaults to -1.
393 """
394 if k == -1:
395 node = self._right_splay(self.root)
396 if node.left:
397 node.left.par = None
398 self.root = node.left
399 return node.key
400 root = self.kth_splay(self.root, k)
401 res = root.key
402 if root.left and root.right:
403 node = self._right_splay(root.left)
404 node.par = None
405 node.right = root.right
406 if node.right:
407 node.right.par = node
408 self._update(node)
409 self.root = node
410 else:
411 self.root = root.right if root.right else root.left
412 if self.root:
413 self.root.par = None
414 return res
415
[docs]
416 def popleft(self) -> T:
417 """先頭の要素を削除し、その値を返します。
418 償却 :math:`O(\\log{n})` です。
419
420 Returns:
421 T:
422 """
423 node = self._left_splay(self.root)
424 self.root = node.right
425 if node.right:
426 node.right.par = None
427 return node.key
428
[docs]
429 def copy(self) -> "LazySplayTree":
430 """コピーします。
431
432 Note:
433 償却 :math:`O(n)` です。
434
435 Returns:
436 LazySplayTree:
437 """
438 return LazySplayTree(
439 self.tolist(), self.op, self.mapping, self.composition, self.e, self.id
440 )
441
[docs]
442 def clear(self) -> None:
443 """全ての要素を削除します。
444 :math:`O(1)` です。
445 """
446 self.root = None
447
[docs]
448 def tolist(self) -> list[T]:
449 """``list`` にして返します。
450 :math:`O(n)` です。非再帰です。
451
452 Returns:
453 list[T]:
454 """
455 node = self.root
456 stack = []
457 a = []
458 while stack or node:
459 if node:
460 self._propagate(node)
461 stack.append(node)
462 node = node.left
463 else:
464 node = stack.pop()
465 a.append(node.key)
466 node = node.right
467 return a
468
[docs]
469 def __setitem__(self, k: int, key: T) -> None:
470 """位置 ``k`` の要素を値 ``key`` で更新します。
471 償却 :math:`O(\\log{n})` です。
472
473 Args:
474 k (int):
475 key (T):
476 """
477 self.root = self.kth_splay(self.root, k)
478 self.root.key = key
479 self._update(self.root)
480
[docs]
481 def __getitem__(self, k: int) -> T:
482 """位置 ``k`` の値を返します。
483 償却 :math:`O(\\log{n})` です。
484
485 Args:
486 k (int):
487 key (T):
488 """
489 self.root = self.kth_splay(self.root, k)
490 return self.root.key
491
492 def __iter__(self):
493 self.__iter = 0
494 return self
495
496 def __next__(self):
497 if self.__iter == len(self):
498 raise StopIteration
499 res = self[self.__iter]
500 self.__iter += 1
501 return res
502
503 def __reversed__(self):
504 for i in range(len(self)):
505 yield self[-i - 1]
506
[docs]
507 def __len__(self):
508 """要素数を返します。
509 :math:`O(1)` です。
510
511 Returns:
512 int:
513 """
514 return self.root.size if self.root else 0
515
516 def __str__(self):
517 return str(self.tolist())
518
519 def __bool__(self):
520 return self.root is not None
521
522 def __repr__(self):
523 return f"{self.__class__.__name__}({self})"