Source code for titan_pylib.data_structures.splay_tree.splay_tree_set

  1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from array import array
  4from typing import Generic, Iterable, TypeVar, Optional
  5
  6T = TypeVar("T", bound=SupportsLessThan)
  7
  8
class SplayTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys stored in an array-backed splay tree.

    Nodes are integer indices into parallel storage:

    * ``keys[i]`` -- key held by node ``i``;
    * ``size[i]`` -- number of nodes in the subtree rooted at ``i``;
    * ``child[i << 1]`` / ``child[i << 1 | 1]`` -- left / right child of ``i``.

    Index 0 is a sentinel meaning "no node" (size 0, key ``e``).
    ``self.node`` is the current root and is 0 when the set is empty.
    Searches, rank queries and updates splay the node they touch to the root.
    """

    def __init__(self, a: Iterable[T] = (), e: T = 0):
        """Create a set containing the keys of ``a`` (duplicates are dropped).

        Args:
            a: initial keys, in any order.
            e: filler key stored in the sentinel (index 0) and in reserved,
                not-yet-used slots.
        """
        self.keys: list[T] = [e]
        self.size = array("I", bytes(4))  # [0]: the sentinel's size
        self.child = array("I", bytes(8))  # [0, 0]: the sentinel's children
        self.end = 1  # next unused node index
        self.node = 0  # root; 0 means the set is empty
        self.e = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Bulk-load ``a`` as a perfectly balanced tree.

        Sorts and deduplicates ``a`` first unless it is already strictly
        increasing. Writes keys into indices ``1..n``, so it must only be
        called on a freshly constructed set.
        """

        def build(lo: int, hi: int) -> int:
            # Make the middle index the root of the half-open range [lo, hi).
            mid = (lo + hi) >> 1
            if lo != mid:
                child[mid << 1] = build(lo, mid)
            if mid + 1 != hi:
                child[mid << 1 | 1] = build(mid + 1, hi)
            size[mid] = 1 + size[child[mid << 1]] + size[child[mid << 1 | 1]]
            return mid

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        n = len(a)
        key, child, size = self.keys, self.child, self.size
        self.reserve(n - len(key) + 2)
        self.end += n
        key[1 : n + 1] = a
        self.node = build(1, n + 1)

    def _make_node(self, key: T) -> int:
        """Allocate a leaf node holding ``key`` and return its index."""
        end = self.end
        if end >= len(self.keys):
            self.keys.append(key)
            self.size.append(1)
            self.child.append(0)
            self.child.append(0)
        else:
            # The slot may be reserved, or left over from before clear():
            # reset it so no stale size / child links leak into the tree.
            self.keys[end] = key
            self.size[end] = 1
            self.child[end << 1] = 0
            self.child[end << 1 | 1] = 0
        self.end += 1
        return end

    def _update(self, node: int) -> None:
        """Recompute ``size[node]`` from its children's sizes."""
        self.size[node] = (
            1 + self.size[self.child[node << 1]] + self.size[self.child[node << 1 | 1]]
        )

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Fix sizes after a double rotation in which ``z`` replaced ``x``
        as the subtree root: ``z`` inherits ``x``'s old size, then ``x`` and
        ``y`` (now below ``z``) are recomputed."""
        size, child = self.size, self.child
        size[z] = size[x]
        size[x] = 1 + size[child[x << 1]] + size[child[x << 1 | 1]]
        size[y] = 1 + size[child[y << 1]] + size[child[y << 1 | 1]]

    def _splay(self, path: list[int], d: int) -> int:
        """Splay the target node up to ``path[0]``'s position and return it.

        ``path`` lists the target's ancestors from the subtree root down to
        its parent; it must be non-empty and may be modified. Bit ``j`` of
        ``d`` (least significant bit = deepest step) is 1 when that step went
        to a left child and 0 for a right child. The caller must store the
        returned index wherever ``path[0]`` was previously linked.
        """
        child = self.child
        g = d & 1
        while len(path) > 1:
            # Double rotation: zig-zig when g == f, zig-zag otherwise.
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            f = d >> 1 & 1
            tmp = child[node << 1 | g ^ 1]  # the target itself
            child[node << 1 | g ^ 1] = child[tmp << 1 | g]
            child[tmp << 1 | f ^ (g != f)] = node
            child[pnode << 1 | f ^ 1] = child[(node if g == f else tmp) << 1 | f]
            child[(node if g == f else tmp) << 1 | f] = pnode
            self._update_triple(pnode, node, tmp)
            if not path:
                return tmp
            d >>= 2
            g = d & 1
            # Re-link the rotated subtree under the next remaining ancestor.
            child[path[-1] << 1 | g ^ 1] = tmp
        # One ancestor left: finish with a single rotation (zig).
        pnode = path[0]
        node = child[pnode << 1 | g ^ 1]
        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
        child[node << 1 | g] = pnode
        self._update(pnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay the node holding ``key`` to the root.

        If ``key`` is absent, the last node visited by the search is splayed
        instead; that node holds ``key``'s predecessor or successor. This is
        a no-op on an empty tree.
        """
        node = self.node
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        path = []
        d = 0
        while keys[node] != key:
            nxt = child[node << 1 | (key > keys[node])]
            if not nxt:
                break
            path.append(node)
            d, node = d << 1 | (key < keys[node]), nxt
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the ``k``-th smallest node (0-indexed) to the root.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range, including any ``k`` on an
                empty set. Without this check, the walk would run through the
                sentinel and corrupt it, or never terminate.
        """
        size, child = self.size, self.child
        node = self.node
        n = size[node]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeSet index out of range")
        d = 0
        path = []
        while True:
            t = size[child[node << 1]]
            if t == k:
                if path:
                    self.node = self._splay(path, d)
                return
            path.append(node)
            d, node = d << 1 | (t > k), child[node << 1 | (t < k)]
            if t < k:
                k -= t + 1

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root, or 0 if ``node`` is 0.
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1]:
            return node
        path = []
        while child[node << 1]:
            path.append(node)
            node = child[node << 1]
        return self._splay(path, (1 << len(path)) - 1)  # every step went left

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root, or 0 if ``node`` is 0.
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1 | 1]:
            return node
        path = []
        while child[node << 1 | 1]:
            path.append(node)
            node = child[node << 1 | 1]
        return self._splay(path, 0)  # every step went right

    def _remove_root(self) -> None:
        """Unlink the current root node by joining its two subtrees.

        When both subtrees exist, the right subtree's minimum is splayed to
        the top and becomes the new root, adopting the left subtree.
        """
        child = self.child
        root = self.node
        left, right = child[root << 1], child[root << 1 | 1]
        if not left:
            self.node = right
        elif not right:
            self.node = left
        else:
            node = self._get_min_splay(right)  # has no left child afterwards
            child[node << 1] = left
            self._update(node)  # its subtree just grew by the left half
            self.node = node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes."""
        assert n >= 0, f"ValueError: SplayTreeSet.reserve({n})"
        self.keys += [self.e] * n
        self.size += array("I", [1] * n)
        self.child += array("I", bytes(8 * n))

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was added, False if already present."""
        keys, child = self.keys, self.child
        self._set_search_splay(key)
        root = self.node
        # Guard on `root`: on an empty tree keys[0] is the sentinel filler `e`.
        if root and keys[root] == key:
            return False
        node = self._make_node(key)
        if not root:
            self.node = node
            return True
        # After the search splay, the root is key's predecessor or successor,
        # so splitting off its child on key's side keeps the BST order.
        f = key > keys[root]
        child[node << 1 | f ^ 1] = root
        child[node << 1 | f] = child[root << 1 | f]
        child[root << 1 | f] = 0
        self._update(root)
        self._update(node)
        self.node = node
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return True if it was removed."""
        if not self.node:
            return False
        self._set_search_splay(key)
        if self.keys[self.node] != key:
            return False
        self._remove_root()
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = key
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            # Fell off the tree: splay the last visited node instead.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key <= keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] < key:
            res += 1
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] <= key:
            res += 1
        return res

    def get_min(self) -> T:
        """Return the smallest element; raises IndexError if empty."""
        return self.__getitem__(0)

    def get_max(self) -> T:
        """Return the largest element; raises IndexError if empty."""
        return self.__getitem__(-1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the largest).

        Raises:
            AssertionError: if the set is empty.
            IndexError: if ``k`` is out of range.
        """
        assert self.node, f"IndexError: SplayTreeSet.pop({k})"
        keys, child = self.keys, self.child
        if k == -1:
            # Fast path: the maximum has no right child once splayed.
            node = self._get_max_splay(self.node)
            self.node = child[node << 1]
            return keys[node]
        self._set_kth_elm_splay(k)
        res = keys[self.node]
        self._remove_root()
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert self.node, "IndexError: SplayTreeSet.pop_min()"
        node = self._get_min_splay(self.node)
        self.node = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove every element; allocated node slots are reused afterwards."""
        self.node = 0
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in ascending order.

        Uses an iterative in-order walk, so no splaying is performed.
        """
        node = self.node
        child, keys = self.child, self.keys
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def __iter__(self):
        """Iterate over a snapshot of the elements in ascending order.

        A fresh iterator is returned on every call, so nested loops over the
        same set, or changes to the set while looping, do not disturb each
        other.
        """
        self.__iter = 0  # keeps legacy `iter(s); next(s)` usage working
        return iter(self.tolist())

    def __next__(self):
        # Legacy index-based cursor; prefer iter(s), which returns its own iterator.
        if self.__iter == self.__len__():
            raise StopIteration
        self.__iter += 1
        return self.__getitem__(self.__iter - 1)

    def __reversed__(self):
        """Yield the elements in descending order."""
        yield from reversed(self.tolist())

    def __contains__(self, key: T) -> bool:
        # Guard the empty tree: keys[0] is the sentinel filler, not a member.
        if not self.node:
            return False
        self._set_search_splay(key)
        return self.keys[self.node] == key

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from the end."""
        self._set_kth_elm_splay(k)
        return self.keys[self.node]

    def __len__(self) -> int:
        return self.size[self.node]

    def __bool__(self) -> bool:
        return self.node != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"SplayTreeSet({self})"