Source code for titan_pylib.data_structures.splay_tree.lazy_splay_tree_array

  1from array import array
  2from typing import (
  3    Generic,
  4    TypeVar,
  5    Callable,
  6    Iterable,
  7    Optional,
  8    Union,
  9    Sequence,
 10)
 11
 12T = TypeVar("T")
 13F = TypeVar("F")
 14
 15
class LazySplayTreeArrayData(Generic[T, F]):
    """Shared node storage and monoid/action definition for ``LazySplayTreeArray``.

    Several trees may share one instance; nodes are identified by integer
    indices into the flat buffers below.  Index 0 is a sentinel "null" node
    whose size is 0 and whose product is ``e``.

    Buffer layout for a node index ``node``:

    * ``keydata[node << 1]``      -- the node's own value
    * ``keydata[node << 1 | 1]``  -- product of the node's subtree
    * ``lazy[node]``              -- pending action for the node's children
    * ``arr[node << 2]``          -- left child index
    * ``arr[node << 2 | 1]``      -- right child index
    * ``arr[node << 2 | 2]``      -- subtree size
    * ``arr[node << 2 | 3]``      -- pending-reverse flag
    """

    def __init__(
        self,
        op: Optional[Callable[[T, T], T]] = None,
        mapping: Optional[Callable[[F, T], T]] = None,
        composition: Optional[Callable[[F, F], F]] = None,
        e: T = None,
        id: F = None,
    ):
        """Create an empty store.

        Args:
            op: Binary operation combining two values.  Defaults to a
                function that always returns ``e``.
            mapping: Applies an action ``f`` to a value.  Defaults to a
                function that always returns ``e``.
            composition: Composes two actions.  Defaults to a function that
                always returns ``id``.
            e: Identity element of ``op``.
            id: Identity action.
        """
        # Each callable falls back to its default independently.  Previously
        # ``mapping`` and ``composition`` were selected by testing ``op``, so
        # passing only ``op`` left them as ``None`` and passing them without
        # ``op`` silently discarded them.
        self.op: Callable[[T, T], T] = (lambda s, t: e) if op is None else op
        self.mapping: Callable[[F, T], T] = (
            (lambda f, s: e) if mapping is None else mapping
        )
        self.composition: Callable[[F, F], F] = (
            (lambda f, g: id) if composition is None else composition
        )
        self.e: T = e
        self.id: F = id
        # Slot 0 is the sentinel node: value e, product e, no pending action.
        self.keydata: list[T] = [e, e]
        self.lazy: list[F] = [id]
        # Four unsigned ints per node (left, right, size, rev).  Built from an
        # element count rather than a byte count so the layout does not
        # depend on the platform's item size for typecode "I".
        self.arr: array = array("I", [0, 0, 0, 0])
        # Index of the next unused node slot.
        self.end: int = 1

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        Does nothing when ``n <= 0``.  New slots are zero-initialised
        (no children, size 0, no reverse flag) with value ``e`` and
        action ``id``.
        """
        if n <= 0:
            return
        self.keydata += [self.e] * (2 * n)
        self.lazy += [self.id] * n
        self.arr += array("I", [0]) * (4 * n)
class LazySplayTreeArray(Generic[T, F]):
    """Sequence backed by a splay tree with lazy range actions and reversal.

    All nodes live in a shared ``LazySplayTreeArrayData`` store; this object
    only holds the index of its root node (0 means "empty").  Node fields are
    addressed as:

    * ``arr[node << 2]`` / ``arr[node << 2 | 1]`` -- left / right child
    * ``arr[node << 2 | 2]``                      -- subtree size
    * ``arr[node << 2 | 3]``                      -- pending-reverse flag
    * ``keydata[node << 1]``                      -- node value
    * ``keydata[node << 1 | 1]``                  -- subtree product
    * ``lazy[node]``                              -- action pending for children

    Node 0 is a sentinel whose size is 0 and whose product is ``data.e``;
    its value/product/lazy slots must never be modified.

    Index arguments are validated with ``assert`` statements, so they raise
    ``AssertionError`` (and are not checked at all under ``python -O``).
    """

    def __init__(
        self,
        data: "LazySplayTreeArrayData[T, F]",
        n_or_a: Union[int, Iterable[T]] = 0,
        _root: int = 0,
    ) -> None:
        """Create a tree on top of ``data``.

        Args:
            data: Shared node store.
            n_or_a: Either a length (the tree is filled with ``data.e``) or
                an iterable of initial values.  Falsy values give an empty
                tree.
            _root: Internal; index of an already-built root node.
        """
        self.data = data
        self.root = _root
        if not n_or_a:
            return
        if isinstance(n_or_a, int):
            a = [data.e] * n_or_a
        elif not isinstance(n_or_a, Sequence):
            a = list(n_or_a)
        else:
            a = n_or_a
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Build a balanced tree from ``a`` using fresh consecutive node slots."""

        def rec(l: int, r: int) -> int:
            # Builds the subtree over node indices [l, r) and returns its root.
            mid = (l + r) >> 1
            if l != mid:
                arr[mid << 2] = rec(l, mid)
            if mid + 1 != r:
                arr[mid << 2 | 1] = rec(mid + 1, r)
            self._update(mid)
            return mid

        n = len(a)
        # reserve() grows these buffers in place, so the references taken
        # here stay valid afterwards.
        keydata, arr = self.data.keydata, self.data.arr
        end = self.data.end
        self.data.reserve(n + end - len(keydata) // 2 + 1)
        self.data.end += n
        for i, e in enumerate(a):
            # `+` binds tighter than `<<`: the index is (end + i) << 1.
            keydata[end + i << 1] = e
            keydata[end + i << 1 | 1] = e
        self.root = rec(end, n + end)

    def _make_node(self, key: T) -> int:
        """Allocate a node holding ``key`` and return its index.

        A pre-reserved slot is used when one is available; otherwise the
        buffers are extended by one node.  Callers are expected to call
        ``_update`` on the node once it is linked.
        """
        data = self.data
        if data.end >= len(data.arr) // 4:
            data.keydata.append(key)
            data.keydata.append(key)
            data.lazy.append(data.id)
            data.arr.append(0)
            data.arr.append(0)
            data.arr.append(1)
            data.arr.append(0)
        else:
            data.keydata[data.end << 1] = key
            data.keydata[data.end << 1 | 1] = key
        data.end += 1
        return data.end - 1

    def _propagate(self, node: int) -> None:
        """Push the pending reverse flag and pending action of ``node`` down."""
        data = self.data
        arr = data.arr
        if arr[node << 2 | 3]:
            # Swap the children and hand the flag to both of them.  A missing
            # child is index 0, so the sentinel's flag may get toggled; that
            # flag is never read through a real traversal.
            arr[node << 2], arr[node << 2 | 1] = arr[node << 2 | 1], arr[node << 2]
            arr[node << 2 | 3] = 0
            arr[arr[node << 2] << 2 | 3] ^= 1
            arr[arr[node << 2 | 1] << 2 | 3] ^= 1
        nlazy = data.lazy[node]
        if nlazy == data.id:
            return
        lnode, rnode = arr[node << 2], arr[node << 2 | 1]
        keydata, lazy = data.keydata, data.lazy
        lazy[node] = data.id
        if lnode:
            lazy[lnode] = data.composition(nlazy, lazy[lnode])
            lnode <<= 1
            keydata[lnode] = data.mapping(nlazy, keydata[lnode])
            keydata[lnode | 1] = data.mapping(nlazy, keydata[lnode | 1])
        if rnode:
            lazy[rnode] = data.composition(nlazy, lazy[rnode])
            rnode <<= 1
            keydata[rnode] = data.mapping(nlazy, keydata[rnode])
            keydata[rnode | 1] = data.mapping(nlazy, keydata[rnode | 1])

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Recompute size/product after a double rotation.

        ``z`` takes over the aggregate previously held by ``x`` (it now roots
        the same set of nodes); ``x`` and ``y`` are recomputed from their
        current children.
        """
        data = self.data
        keydata, arr = data.keydata, data.arr
        lx, rx = arr[x << 2], arr[x << 2 | 1]
        ly, ry = arr[y << 2], arr[y << 2 | 1]
        arr[z << 2 | 2] = arr[x << 2 | 2]
        arr[x << 2 | 2] = 1 + arr[lx << 2 | 2] + arr[rx << 2 | 2]
        arr[y << 2 | 2] = 1 + arr[ly << 2 | 2] + arr[ry << 2 | 2]
        keydata[z << 1 | 1] = keydata[x << 1 | 1]
        keydata[x << 1 | 1] = data.op(
            data.op(keydata[lx << 1 | 1], keydata[x << 1]), keydata[rx << 1 | 1]
        )
        keydata[y << 1 | 1] = data.op(
            data.op(keydata[ly << 1 | 1], keydata[y << 1]), keydata[ry << 1 | 1]
        )

    def _update_double(self, x: int, y: int) -> None:
        """Recompute size/product after a single rotation.

        ``y`` takes over the aggregate previously held by ``x``; ``x`` is
        recomputed from its current children.
        """
        data = self.data
        keydata, arr = data.keydata, data.arr
        lx, rx = arr[x << 2], arr[x << 2 | 1]
        arr[y << 2 | 2] = arr[x << 2 | 2]
        arr[x << 2 | 2] = 1 + arr[lx << 2 | 2] + arr[rx << 2 | 2]
        keydata[y << 1 | 1] = keydata[x << 1 | 1]
        keydata[x << 1 | 1] = data.op(
            data.op(keydata[lx << 1 | 1], keydata[x << 1]), keydata[rx << 1 | 1]
        )

    def _update(self, node: int) -> None:
        """Recompute the size and product of ``node`` from its children.

        Must not be called with ``node == 0``: that would overwrite the
        sentinel's size and product.
        """
        data = self.data
        keydata, arr = data.keydata, data.arr
        lnode, rnode = arr[node << 2], arr[node << 2 | 1]
        arr[node << 2 | 2] = 1 + arr[lnode << 2 | 2] + arr[rnode << 2 | 2]
        keydata[node << 1 | 1] = data.op(
            data.op(keydata[lnode << 1 | 1], keydata[node << 1]),
            keydata[rnode << 1 | 1],
        )

    def _splay(self, path: list[int], d: int) -> None:
        """Rotate the node reached below ``path[-1]`` up to the top of ``path``.

        Args:
            path: Ancestors from the subtree root down to the parent of the
                target node; must be non-empty.  It is consumed.
            d: Bit-packed directions, one bit per entry of ``path`` with the
                deepest step in the lowest bit.  A set bit means the walk
                went to the left child at that step.
        """
        arr = self.data.arr
        g = d & 1
        # Operator precedence throughout: `a << 2 | g ^ 1` is
        # `(a << 2) | (g ^ 1)` and `d >> 1 & 1` is `(d >> 1) & 1`.
        while len(path) > 1:
            pnode = path.pop()
            gnode = path.pop()
            f = d >> 1 & 1
            node = arr[pnode << 2 | g ^ 1]
            nnode = (pnode if g == f else node) << 2 | f
            arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
            arr[node << 2 | g] = pnode
            arr[gnode << 2 | f ^ 1] = arr[nnode]
            arr[nnode] = gnode
            self._update_triple(gnode, pnode, node)
            if not path:
                return
            d >>= 2
            g = d & 1
            arr[path[-1] << 2 | g ^ 1] = node
        # Exactly one ancestor is left: finish with a single rotation.
        pnode = path.pop()
        node = arr[pnode << 2 | g ^ 1]
        arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
        arr[node << 2 | g] = pnode
        self._update_double(pnode, node)

    def _kth_elm_splay(self, node: int, k: int) -> int:
        """Splay the ``k``-th element of the subtree rooted at ``node``.

        Negative ``k`` counts from the end of that subtree.  ``k`` must be a
        valid index; the loop has no exit for an out-of-range position.

        Returns:
            The index of the node that is now the subtree root.
        """
        arr = self.data.arr
        if k < 0:
            k += arr[node << 2 | 2]
        d = 0
        path = []
        while True:
            self._propagate(node)
            t = arr[arr[node << 2] << 2 | 2]  # size of the left subtree
            if t == k:
                if path:
                    self._splay(path, d)
                return node
            d = d << 1 | (t > k)
            path.append(node)
            node = arr[node << 2 | (t < k)]
            if t < k:
                k -= t + 1

    def _left_splay(self, node: int) -> int:
        """Splay the leftmost node of the subtree at ``node``; return it (0 if empty)."""
        if not node:
            return 0
        self._propagate(node)
        arr = self.data.arr
        if not arr[node << 2]:
            return node
        path = []
        while arr[node << 2]:
            path.append(node)
            node = arr[node << 2]
            self._propagate(node)
        # Every step went left, so every direction bit is set.
        self._splay(path, (1 << len(path)) - 1)
        return node

    def _right_splay(self, node: int) -> int:
        """Splay the rightmost node of the subtree at ``node``; return it (0 if empty)."""
        if not node:
            return 0
        self._propagate(node)
        arr = self.data.arr
        if not arr[node << 2 | 1]:
            return node
        path = []
        while arr[node << 2 | 1]:
            path.append(node)
            node = arr[node << 2 | 1]
            self._propagate(node)
        # Every step went right, so every direction bit is clear.
        self._splay(path, 0)
        return node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes in the shared store."""
        self.data.reserve(n)

    def merge(self, other: "LazySplayTreeArray[T, F]") -> None:
        """Append all elements of ``other`` to the end of this tree.

        Both trees must share the same data store.  ``other`` is not reset:
        its ``root`` keeps pointing at nodes that now belong to this tree.
        """
        assert self.data is other.data
        if not other.root:
            return
        if not self.root:
            self.root = other.root
            return
        self.root = self._right_splay(self.root)
        self.data.arr[self.root << 2 | 1] = other.root
        self._update(self.root)

    def split(
        self, k: int
    ) -> tuple["LazySplayTreeArray[T, F]", "LazySplayTreeArray[T, F]"]:
        """Split into ``[0, k)`` and ``[k, len)``.

        Negative ``k`` counts from the end.  One of the two returned trees
        is ``self``: the left one when ``k == len(self)``, otherwise the
        right one.

        Returns:
            ``(left, right)``.
        """
        # The lower bound is inclusive so that k == -len(self), and
        # split(0) on an empty tree, are accepted (both mean "split at 0").
        assert (
            -len(self) <= k <= len(self)
        ), f"IndexError: LazySplayTreeArray.split({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        if k >= self.data.arr[self.root << 2 | 2]:
            return self, LazySplayTreeArray(self.data, _root=0)
        self.root = self._kth_elm_splay(self.root, k)
        left = LazySplayTreeArray(self.data, _root=self.data.arr[self.root << 2])
        self.data.arr[self.root << 2] = 0
        self._update(self.root)
        return left, self

    def _internal_split(self, k: int) -> tuple[int, int]:
        """Split at ``k`` and return the two root indices ``(left, right)``.

        After the call ``self.root`` is the right part, unless
        ``k >= len(self)``, in which case nothing changes and ``right`` is 0.
        """
        if k >= self.data.arr[self.root << 2 | 2]:
            return self.root, 0
        self.root = self._kth_elm_splay(self.root, k)
        left = self.data.arr[self.root << 2]
        self.data.arr[self.root << 2] = 0
        self._update(self.root)
        return left, self.root

    def reverse(self, l: int, r: int) -> None:
        """Reverse the order of the elements in ``[l, r)``.

        NOTE(review): stored subtree products are not recomputed on
        reversal, so this looks correct only when ``op`` is commutative --
        confirm against intended use.
        """
        assert (
            0 <= l <= r <= len(self)
        ), f"IndexError: LazySplayTreeArray.reverse({l}, {r}), len={len(self)}"
        if l == r:
            return
        data = self.data
        left, right = self._internal_split(r)
        if l:
            left = self._kth_elm_splay(left, l - 1)
        # With element l-1 at the root of `left`, its right child is [l, r).
        data.arr[(data.arr[left << 2 | 1] if l else left) << 2 | 3] ^= 1
        if right:
            data.arr[right << 2] = left
            self._update(right)
        self.root = right if right else left

    def all_reverse(self) -> None:
        """Reverse the whole sequence.  Does nothing on an empty tree."""
        if not self.root:
            return
        self.data.arr[self.root << 2 | 3] ^= 1

    def apply(self, l: int, r: int, f: F) -> None:
        """Apply the action ``f`` to every element in ``[l, r)``."""
        assert (
            0 <= l <= r <= len(self)
        ), f"IndexError: LazySplayTreeArray.apply({l}, {r}), len={len(self)}"
        if l == r:
            # An empty range resolves to the sentinel node 0 below; applying
            # `f` there would overwrite the identity stored in keydata[1]
            # that every _update relies on.
            return
        data = self.data
        left, right = self._internal_split(r)
        keydata, lazy = data.keydata, data.lazy
        if l:
            left = self._kth_elm_splay(left, l - 1)
        # With element l-1 at the root of `left`, its right child is [l, r).
        node = data.arr[left << 2 | 1] if l else left
        keydata[node << 1] = data.mapping(f, keydata[node << 1])
        keydata[node << 1 | 1] = data.mapping(f, keydata[node << 1 | 1])
        lazy[node] = data.composition(f, lazy[node])
        if l:
            self._update(left)
        if right:
            data.arr[right << 2] = left
            self._update(right)
        self.root = right if right else left

    def all_apply(self, f: F) -> None:
        """Apply the action ``f`` to every element.  Does nothing when empty."""
        if not self.root:
            return
        data, node = self.data, self.root
        data.keydata[node << 1] = data.mapping(f, data.keydata[node << 1])
        data.keydata[node << 1 | 1] = data.mapping(f, data.keydata[node << 1 | 1])
        data.lazy[node] = data.composition(f, data.lazy[node])

    def prod(self, l: int, r: int) -> T:
        """Return the product of the elements in ``[l, r)``.

        For an empty range this reads the sentinel's product, ``data.e``.
        """
        assert (
            0 <= l <= r <= len(self)
        ), f"IndexError: LazySplayTreeArray.prod({l}, {r}), len={len(self)}"
        data = self.data
        left, right = self._internal_split(r)
        if l:
            left = self._kth_elm_splay(left, l - 1)
        res = data.keydata[(data.arr[left << 2 | 1] if l else left) << 1 | 1]
        if right:
            data.arr[right << 2] = left
            self._update(right)
        self.root = right if right else left
        return res

    def all_prod(self) -> T:
        """Return the product of all elements (``data.e`` when empty)."""
        return self.data.keydata[self.root << 1 | 1]

    def insert(self, k: int, key: T) -> None:
        """Insert ``key`` before position ``k``; negative ``k`` counts from the end."""
        assert (
            -len(self) <= k <= len(self)
        ), f"IndexError: LazySplayTreeArray.insert({k}, {key}), len={len(self)}"
        if k < 0:
            k += len(self)
        data = self.data
        node = self._make_node(key)
        if not self.root:
            self._update(node)
            self.root = node
            return
        arr = data.arr
        if k == data.arr[self.root << 2 | 2]:
            # Inserting at the end: the whole tree becomes the left child.
            arr[node << 2] = self._right_splay(self.root)
        else:
            node_ = self._kth_elm_splay(self.root, k)
            if arr[node_ << 2]:
                arr[node << 2] = arr[node_ << 2]
                arr[node_ << 2] = 0
                self._update(node_)
            arr[node << 2 | 1] = node_
        self._update(node)
        self.root = node

    def append(self, key: T) -> None:
        """Append ``key`` at the end."""
        data = self.data
        node = self._right_splay(self.root)
        self.root = self._make_node(key)
        data.arr[self.root << 2] = node
        self._update(self.root)

    def appendleft(self, key: T) -> None:
        """Insert ``key`` at the front."""
        node = self._left_splay(self.root)
        self.root = self._make_node(key)
        self.data.arr[self.root << 2 | 1] = node
        self._update(self.root)

    def pop(self, k: int = -1) -> T:
        """Remove and return the element at position ``k`` (default: last)."""
        assert -len(self) <= k < len(self), f"IndexError: LazySplayTreeArray.pop({k})"
        data = self.data
        if k == -1:
            node = self._right_splay(self.root)
            self._propagate(node)
            self.root = data.arr[node << 2]
            return data.keydata[node << 1]
        self.root = self._kth_elm_splay(self.root, k)
        res = data.keydata[self.root << 1]
        if not data.arr[self.root << 2]:
            self.root = data.arr[self.root << 2 | 1]
        elif not data.arr[self.root << 2 | 1]:
            self.root = data.arr[self.root << 2]
        else:
            node = self._right_splay(data.arr[self.root << 2])
            data.arr[node << 2 | 1] = data.arr[self.root << 2 | 1]
            self.root = node
            # Only this branch relinks nodes.  In the other two the surviving
            # child is already consistent, and the new root may be 0, which
            # must never be passed to _update.
            self._update(self.root)
        return res

    def popleft(self) -> T:
        """Remove and return the first element."""
        assert self, "IndexError: LazySplayTreeArray.popleft()"
        node = self._left_splay(self.root)
        self.root = self.data.arr[node << 2 | 1]
        return self.data.keydata[node << 1]

    def rotate(self, x: int) -> None:
        """Rotate right by ``x``.

        Equivalent to repeating "remove the last element and insert it at
        the front" ``x`` times.  An empty tree is left unchanged.
        """
        n = self.data.arr[self.root << 2 | 2]
        if n == 0:
            # Nothing to rotate; also avoids the modulo by zero below.
            return
        x %= n
        if x == 0:
            return
        # head = first n - x elements, tail = last x elements.
        head, tail = self._internal_split(n - x)
        # Re-join as tail + head: hang `head` off the rightmost node of `tail`.
        tail = self._right_splay(tail)
        self.data.arr[tail << 2 | 1] = head
        self._update(tail)
        self.root = tail

    def tolist(self) -> list[T]:
        """Return all elements in order as a list."""
        node = self.root
        arr, keydata = self.data.arr, self.data.keydata
        stack = []
        res = []
        # Iterative in-order traversal, pushing pending work down on the way.
        while stack or node:
            if node:
                self._propagate(node)
                stack.append(node)
                node = arr[node << 2]
            else:
                node = stack.pop()
                res.append(keydata[node << 1])
                node = arr[node << 2 | 1]
        return res

    def clear(self) -> None:
        """Make this tree empty.  Node slots in the shared store are not reclaimed."""
        self.root = 0

    def __setitem__(self, k: int, key: T):
        """Set the element at position ``k`` to ``key``."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: LazySplayTreeArray.__setitem__({k})"
        self.root = self._kth_elm_splay(self.root, k)
        self.data.keydata[self.root << 1] = key
        self._update(self.root)

    def __getitem__(self, k: int) -> T:
        """Return the element at position ``k`` (splays it to the root)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: LazySplayTreeArray.__getitem__({k})"
        self.root = self._kth_elm_splay(self.root, k)
        return self.data.keydata[self.root << 1]

    def __iter__(self):
        # The cursor lives on the instance, so two simultaneous iterations
        # over the same tree share (and disturb) one position.
        self.__iter = 0
        return self

    def __next__(self):
        if self.__iter == self.data.arr[self.root << 2 | 2]:
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self[-i - 1]

    def __len__(self):
        return self.data.arr[self.root << 2 | 2]

    def __str__(self):
        return str(self.tolist())

    def __bool__(self):
        return self.root != 0

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"