Source code for titan_pylib.data_structures.splay_tree.lazy_splay_tree

  1from typing import Generic, Union, TypeVar, Callable, Iterable, Optional
  2
  3T = TypeVar("T")
  4F = TypeVar("F")
  5
  6
[docs] 7class LazySplayTree(Generic[T, F]): 8 9 class _Node: 10 11 def __init__(self, key: T, lazy: F) -> None: 12 self.key: T = key 13 self.data: T = key 14 self.rdata: T = key 15 self.lazy: F = lazy 16 self.left: Optional["LazySplayTree._Node"] = None 17 self.right: Optional["LazySplayTree._Node"] = None 18 self.par: Optional["LazySplayTree._Node"] = None 19 self.size: int = 1 20 self.rev: int = 0 21 22 def __init__( 23 self, 24 n_or_a: Union[int, Iterable[T]], 25 op: Callable[[T, T], T], 26 mapping: Callable[[F, T], T], 27 composition: Callable[[F, F], F], 28 e: T, 29 id: F, 30 _root: Optional[_Node] = None, 31 ) -> None: 32 """構築します。 33 :math:`O(n)` です。 34 35 Args: 36 n_or_a (Union[int, Iterable[T]]): ``n`` のとき、 ``e`` から長さ ``n`` で構築します。 37 ``a`` のとき、 ``a`` から構築します。 38 op (Callable[[T, T], T]): 遅延セグ木のあれです。 39 mapping (Callable[[F, T], T]): 遅延セグ木のあれです。 40 composition (Callable[[F, F], F]): 遅延セグ木のあれです。 41 e (T): 遅延セグ木のあれです。 42 id (F): 遅延セグ木のあれです。 43 """ 44 self.op = op 45 self.mapping = mapping 46 self.composition = composition 47 self.e = e 48 self.id = id 49 self.root = _root 50 if _root: 51 return 52 a = n_or_a 53 if isinstance(a, int): 54 a = [e for _ in range(a)] 55 elif not isinstance(a, list): 56 a = list(a) 57 if a: 58 self._build(a) 59 60 def _build(self, a: list[T]) -> None: 61 _Node = LazySplayTree._Node 62 id = self.id 63 64 def build(l: int, r: int) -> LazySplayTree._Node: 65 mid = (l + r) >> 1 66 node = _Node(a[mid], id) 67 if l != mid: 68 node.left = build(l, mid) 69 node.left.par = node 70 if mid + 1 != r: 71 node.right = build(mid + 1, r) 72 node.right.par = node 73 self._update(node) 74 return node 75 76 self.root = build(0, len(a)) 77 78 def _rotate(self, node: _Node) -> None: 79 pnode = node.par 80 gnode = pnode.par 81 if gnode: 82 if gnode.left is pnode: 83 gnode.left = node 84 else: 85 gnode.right = node 86 node.par = gnode 87 if pnode.left is node: 88 pnode.left = node.right 89 if node.right: 90 node.right.par = pnode 91 node.right = pnode 92 else: 93 pnode.right = node.left 94 if node.left: 95 node.left.par = pnode 96 node.left = pnode 97 pnode.par = node 98 self._update_double(pnode, node) 99 100 def _propagate_rev(self, node: Optional[_Node]) -> None: 101 if not node: 102 return 103 node.rev ^= 1 104 105 def _propagate_lazy(self, node: Optional[_Node], f: F) -> None: 106 if not node: 107 return 108 node.key = self.mapping(f, node.key) 109 node.data = self.mapping(f, node.data) 110 node.rdata = self.mapping(f, node.rdata) 111 node.lazy = f if node.lazy == self.id else self.composition(f, node.lazy) 112 113 def _propagate(self, node: Optional[_Node]) -> None: 114 if not node: 115 return 116 if node.rev: 117 node.data, node.rdata = node.rdata, node.data 118 node.left, node.right = node.right, node.left 119 self._propagate_rev(node.left) 120 self._propagate_rev(node.right) 121 node.rev = 0 122 if node.lazy != self.id: 123 self._propagate_lazy(node.left, node.lazy) 124 self._propagate_lazy(node.right, node.lazy) 125 node.lazy = self.id 126 127 def _update_double(self, pnode: _Node, node: _Node) -> None: 128 node.data = pnode.data 129 node.rdata = pnode.rdata 130 node.size = pnode.size 131 self._update(pnode) 132 133 def _update(self, node: _Node) -> None: 134 node.data = node.key 135 node.rdata = node.key 136 node.size = 1 137 if node.left: 138 node.data = self.op(node.left.data, node.data) 139 node.rdata = self.op(node.rdata, node.left.rdata) 140 node.size += node.left.size 141 if node.right: 142 node.data = self.op(node.data, node.right.data) 143 node.rdata = self.op(node.right.rdata, node.rdata) 144 node.size += node.right.size 145 146 def _splay(self, node: _Node) -> None: 147 # while node.par and node.par.par: 148 # pnode = node.par 149 # self._rotate(pnode if (pnode.par.left is pnode) == (pnode.left is node) else node) 150 # self._rotate(node) 151 # if node.par: 152 # self._rotate(node) 153 while node.par: 154 pnode = node.par 155 if pnode: 156 self._rotate( 157 pnode if (pnode.par.left is pnode) == (pnode.left is node) else node 158 ) 159 self._rotate(node) 160
[docs] 161 def kth_splay(self, node: Optional[_Node], k: int) -> None: 162 if k < 0: 163 k += len(self) 164 while True: 165 self._propagate(node) 166 t = node.left.size if node.left else 0 167 if t == k: 168 break 169 if t > k: 170 node = node.left 171 else: 172 node = node.right 173 k -= t + 1 174 self._splay(node) 175 return node
176 177 def _left_splay(self, node: Optional[_Node]) -> Optional[_Node]: 178 self._propagate(node) 179 if not node or not node.left: 180 return node 181 while node.left: 182 node = node.left 183 self._propagate(node) 184 self._splay(node) 185 return node 186 187 def _right_splay(self, node: Optional[_Node]) -> Optional[_Node]: 188 self._propagate(node) 189 if not node or not node.right: 190 return node 191 while node.right: 192 node = node.right 193 self._propagate(node) 194 self._splay(node) 195 return node 196
[docs] 197 def merge(self, other: "LazySplayTree") -> None: 198 """``other`` を後ろに連結します。 199 償却 :math:`O(\\log{n})` です。 200 201 Args: 202 other (LazySplayTree): 203 """ 204 if not self.root: 205 self.root = other.root 206 return 207 if not other.root: 208 return 209 self.root = self._right_splay(self.root) 210 self.root.right = other.root 211 other.root.par = self.root 212 self._update(self.root)
213
[docs] 214 def split(self, k: int) -> tuple["LazySplayTree", "LazySplayTree"]: 215 """位置 ``k`` で split します。 216 償却 :math:`O(\\log{n})` です。 217 218 Returns: 219 tuple['LazySplayTree', 'LazySplayTree']: 220 """ 221 left, right = self._internal_split(self.root, k) 222 left_splay = LazySplayTree( 223 0, self.op, self.mapping, self.composition, self.e, self.id, left 224 ) 225 right_splay = LazySplayTree( 226 0, self.op, self.mapping, self.composition, self.e, self.id, right 227 ) 228 return left_splay, right_splay
229 230 def _internal_split(self, k: int) -> tuple[_Node, _Node]: 231 if k == len(self): 232 return self.root, None 233 right = self.kth_splay(self.root, k) 234 left = right.left 235 if left: 236 left.par = None 237 right.left = None 238 self._update(right) 239 return left, right 240 241 def _internal_merge( 242 self, left: Optional[_Node], right: Optional[_Node] 243 ) -> Optional[_Node]: 244 # need (not right) or (not right.left) 245 if not right: 246 return left 247 assert right.left is None 248 right.left = left 249 if left: 250 left.par = right 251 self._update(right) 252 return right 253
[docs] 254 def reverse(self, l: int, r: int) -> None: 255 """区間 ``[l, r)`` を反転します。 256 償却 :math:`O(\\log{n})` です。 257 258 Args: 259 l (int): 260 r (int): 261 """ 262 assert ( 263 0 <= l <= r <= len(self) 264 ), f"IndexError: {self.__class__.__name__}.reverse({l}, {r}), len={len(self)}" 265 left, right = self._internal_split(r) 266 if l == 0: 267 self._propagate_rev(left) 268 else: 269 left = self.kth_splay(left, l - 1) 270 self._propagate_rev(left.right) 271 self.root = self._internal_merge(left, right)
272
[docs] 273 def all_reverse(self) -> None: 274 """区間 ``[0, n)`` を反転します。 275 :math:`O(1)` です。 276 """ 277 self._propagate_rev(self.root)
278
[docs] 279 def apply(self, l: int, r: int, f: F) -> None: 280 """区間 ``[l, r)`` に ``f`` を作用します。 281 償却 :math:`O(\\log{n})` です。 282 283 Args: 284 l (int): 285 r (int): 286 f (F): 作用素です。 287 """ 288 assert ( 289 0 <= l <= r <= len(self) 290 ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), len={len(self)}" 291 left, right = self._internal_split(r) 292 if l == 0: 293 self._propagate_lazy(left, f) 294 else: 295 left = self.kth_splay(left, l - 1) 296 self._propagate_lazy(left.right, f) 297 self._update(left) 298 self.root = self._internal_merge(left, right)
299
[docs] 300 def all_apply(self, f: F) -> None: 301 """区間 ``[0, n)`` に ``f`` を作用します。 302 :math:`O(1)` です。 303 """ 304 self._propagate_lazy(self.root, f)
305
[docs] 306 def prod(self, l: int, r: int) -> T: 307 """区間 ``[l, r)`` の総積を求めます。 308 償却 :math:`O(\\log{n})` です。 309 """ 310 assert ( 311 0 <= l <= r <= len(self) 312 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), len={len(self)}" 313 if l == r: 314 return self.e 315 left, right = self._internal_split(r) 316 if l == 0: 317 res = left.data 318 else: 319 left = self.kth_splay(left, l - 1) 320 res = left.right.data 321 self.root = self._internal_merge(left, right) 322 return res
323
[docs] 324 def all_prod(self) -> T: 325 """区間 ``[0, n)`` の総積を求めます。 326 :math:`O(1)` です。 327 """ 328 self._propagate(self.root) 329 return self.root.data if self.root else self.e
330
[docs] 331 def insert(self, k: int, key: T) -> None: 332 """位置 ``k`` に ``key`` を挿入します。 333 償却 :math:`O(\\log{n})` です。 334 335 Args: 336 k (int): 337 key (T): 338 """ 339 assert 0 <= k <= len(self) 340 node = self._Node(key, self.id) 341 if not self.root: 342 self.root = node 343 return 344 if k >= len(self): 345 root = self.kth_splay(self.root, len(self) - 1) 346 node.left = root 347 else: 348 root = self.kth_splay(self.root, k) 349 if root.left: 350 node.left = root.left 351 root.left.par = node 352 root.left = None 353 self._update(root) 354 node.right = root 355 root.par = node 356 self.root = node 357 self._update(self.root)
358
[docs] 359 def append(self, key: T) -> None: 360 """末尾に ``key`` を追加します。 361 償却 :math:`O(\\log{n})` です。 362 363 Args: 364 key (T): 365 """ 366 node = self._right_splay(self.root) 367 self.root = self._Node(key, self.id) 368 self.root.left = node 369 if node: 370 node.par = self.root 371 self._update(self.root)
372
[docs] 373 def appendleft(self, key: T) -> None: 374 """先頭に ``key`` を追加します。 375 償却 :math:`O(\\log{n})` です。 376 377 Args: 378 key (T): 379 """ 380 node = self._left_splay(self.root) 381 self.root = self._Node(key, self.id) 382 self.root.right = node 383 if node: 384 node.par = self.root 385 self._update(self.root)
386
[docs] 387 def pop(self, k: int = -1) -> T: 388 """位置 ``k`` の要素を削除し、その値を返します。 389 償却 :math:`O(\\log{n})` です。 390 391 Args: 392 k (int, optional): 指定するインデックスです。 Defaults to -1. 393 """ 394 if k == -1: 395 node = self._right_splay(self.root) 396 if node.left: 397 node.left.par = None 398 self.root = node.left 399 return node.key 400 root = self.kth_splay(self.root, k) 401 res = root.key 402 if root.left and root.right: 403 node = self._right_splay(root.left) 404 node.par = None 405 node.right = root.right 406 if node.right: 407 node.right.par = node 408 self._update(node) 409 self.root = node 410 else: 411 self.root = root.right if root.right else root.left 412 if self.root: 413 self.root.par = None 414 return res
415
[docs] 416 def popleft(self) -> T: 417 """先頭の要素を削除し、その値を返します。 418 償却 :math:`O(\\log{n})` です。 419 420 Returns: 421 T: 422 """ 423 node = self._left_splay(self.root) 424 self.root = node.right 425 if node.right: 426 node.right.par = None 427 return node.key
428
[docs] 429 def copy(self) -> "LazySplayTree": 430 """コピーします。 431 432 Note: 433 償却 :math:`O(n)` です。 434 435 Returns: 436 LazySplayTree: 437 """ 438 return LazySplayTree( 439 self.tolist(), self.op, self.mapping, self.composition, self.e, self.id 440 )
441
[docs] 442 def clear(self) -> None: 443 """全ての要素を削除します。 444 :math:`O(1)` です。 445 """ 446 self.root = None
447
[docs] 448 def tolist(self) -> list[T]: 449 """``list`` にして返します。 450 :math:`O(n)` です。非再帰です。 451 452 Returns: 453 list[T]: 454 """ 455 node = self.root 456 stack = [] 457 a = [] 458 while stack or node: 459 if node: 460 self._propagate(node) 461 stack.append(node) 462 node = node.left 463 else: 464 node = stack.pop() 465 a.append(node.key) 466 node = node.right 467 return a
468
[docs] 469 def __setitem__(self, k: int, key: T) -> None: 470 """位置 ``k`` の要素を値 ``key`` で更新します。 471 償却 :math:`O(\\log{n})` です。 472 473 Args: 474 k (int): 475 key (T): 476 """ 477 self.root = self.kth_splay(self.root, k) 478 self.root.key = key 479 self._update(self.root)
480
[docs] 481 def __getitem__(self, k: int) -> T: 482 """位置 ``k`` の値を返します。 483 償却 :math:`O(\\log{n})` です。 484 485 Args: 486 k (int): 487 key (T): 488 """ 489 self.root = self.kth_splay(self.root, k) 490 return self.root.key
491 492 def __iter__(self): 493 self.__iter = 0 494 return self 495 496 def __next__(self): 497 if self.__iter == len(self): 498 raise StopIteration 499 res = self[self.__iter] 500 self.__iter += 1 501 return res 502 503 def __reversed__(self): 504 for i in range(len(self)): 505 yield self[-i - 1] 506
[docs] 507 def __len__(self): 508 """要素数を返します。 509 :math:`O(1)` です。 510 511 Returns: 512 int: 513 """ 514 return self.root.size if self.root else 0
515 516 def __str__(self): 517 return str(self.tolist()) 518 519 def __bool__(self): 520 return self.root is not None 521 522 def __repr__(self): 523 return f"{self.__class__.__name__}({self})"