splay_tree_set
ソースコード (source code)
from titan_pylib.data_structures.splay_tree.splay_tree_set import SplayTreeSet
展開済みコード (expanded code: the module with its dependencies inlined)
1# from titan_pylib.data_structures.splay_tree.splay_tree_set import SplayTreeSet
2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Any object providing ``__lt__`` matches this protocol.  It serves as the
    bound of the key type variable used by the ordered-set classes.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
10from abc import ABC, abstractmethod
11from typing import Iterable, Optional, Iterator, TypeVar, Generic
12
# Key type variable for the ordered-set classes, bounded by the
# SupportsLessThan protocol.
T = TypeVar("T", bound=SupportsLessThan)
14
15
class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface for ordered sets of mutually comparable keys.

    Every method is abstract; the base implementations exist only to raise
    ``NotImplementedError`` if reached through ``super()``.
    """

    # ----- construction and mutation ---------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Initialise the set from the elements of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; return whether the set changed."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was removed."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Remove ``key``; implementations raise when it is absent."""
        raise NotImplementedError

    # ----- neighbour queries -----------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        raise NotImplementedError

    # ----- extremes ----------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest element."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        raise NotImplementedError

    # ----- whole-set operations ----------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the elements in ascending order as a list."""
        raise NotImplementedError

    # ----- Python protocols --------------------------------------------

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator over the elements."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of an ongoing iteration."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is a member."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the set is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the set."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation of the set."""
        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102from array import array
103from typing import Generic, Iterable, TypeVar, Optional
104
# Key type variable re-declared by the inlined splay_tree_set module; it is
# defined identically to the earlier T (bound to SupportsLessThan).
T = TypeVar("T", bound=SupportsLessThan)
106
107
class SplayTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys implemented as an array-backed splay tree.

    Node ``i`` (an integer index) is stored in parallel containers:

    * ``keys[i]`` -- the key held by node ``i``;
    * ``size[i]`` -- number of nodes in the subtree rooted at ``i``;
    * ``child[2 * i]`` / ``child[2 * i + 1]`` -- left / right child of ``i``
      (``0`` means "no child").

    Index ``0`` is a sentinel null node: ``size[0] == 0`` and
    ``keys[0] == e``; it never represents a member of the set.
    ``self.node`` is the current root (``0`` when the set is empty) and
    ``self.end`` is the next unused node index.  Slots of removed nodes are
    not reused.

    Queries splay the last node they visit to the root, so reads reshape the
    tree as well as writes.

    Splay paths follow one convention throughout: ``path`` lists the
    ancestors of the target node from the root downwards, and bit ``j`` of
    ``d`` (least significant first) is ``1`` when the step out of
    ``path[-1 - j]`` goes to the left child.
    """

    def __init__(self, a: Iterable[T] = [], e: T = 0):
        """Create a set containing the distinct elements of ``a``.

        Args:
            a: Initial elements.  Duplicates are dropped; ``a`` itself is
                not modified.
            e: Filler key stored in the sentinel node and in slots
                pre-allocated by :meth:`reserve`.  It is not a member of the
                set unless explicitly added.
        """
        self.keys: list[T] = [e]
        self.size = array("I", bytes(4))  # sentinel size entry (0)
        self.child = array("I", bytes(8))  # sentinel child entries (0, 0)
        self.end = 1
        self.node = 0
        self.e = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Turn a fresh, empty set into a perfectly balanced tree over ``a``.

        ``a`` is sorted and de-duplicated first unless it is already strictly
        increasing.  Only called from ``__init__``.
        """

        def build(lo: int, hi: int) -> int:
            # Link node indices [lo, hi) into a balanced subtree; return root.
            mid = (lo + hi) >> 1
            if lo != mid:
                child[mid << 1] = build(lo, mid)
            if mid + 1 != hi:
                child[mid << 1 | 1] = build(mid + 1, hi)
            size[mid] = 1 + size[child[mid << 1]] + size[child[mid << 1 | 1]]
            return mid

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        n = len(a)
        key, child, size = self.keys, self.child, self.size
        # On a fresh set len(key) == 1, so this leaves n + 2 slots:
        # the sentinel, the n keys, and one spare.
        self.reserve(n - len(key) + 2)
        self.end += n
        key[1 : n + 1] = a
        self.node = build(1, n + 1)

    def _make_node(self, key: T) -> int:
        """Allocate a detached node holding ``key`` and return its index."""
        end = self.end
        if end >= len(self.keys):
            self.keys.append(key)
            self.size.append(1)
            self.child.append(0)
            self.child.append(0)
        else:
            # Slot pre-allocated by reserve(): size 1 and no children already.
            self.keys[end] = key
        self.end += 1
        return end

    def _update(self, node: int) -> None:
        """Recompute ``size[node]`` from its two children."""
        self.size[node] = (
            1 + self.size[self.child[node << 1]] + self.size[self.child[node << 1 | 1]]
        )

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Repair sizes after a double rotation.

        ``z`` now roots the subtree ``x`` used to root, so it inherits
        ``x``'s old size.  ``x`` is recomputed before ``y`` because ``y`` may
        have become ``x``'s parent.
        """
        size, child = self.size, self.child
        size[z] = size[x]
        size[x] = 1 + size[child[x << 1]] + size[child[x << 1 | 1]]
        size[y] = 1 + size[child[y << 1]] + size[child[y << 1 | 1]]

    def _splay(self, path: list[int], d: int) -> int:
        """Rotate the target described by ``path``/``d`` up to ``path[0]``'s place.

        ``path`` must be non-empty and is modified.  Returns the new root of
        the subtree formerly rooted at ``path[0]``; re-linking it to any
        parent of ``path[0]`` is the caller's job.
        """
        child = self.child
        g = d & 1  # direction from the parent into the target (1 = left)
        while len(path) > 1:
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            f = d >> 1 & 1  # direction from the grandparent into the parent
            tmp = child[node << 1 | g ^ 1]  # the target
            # Zig-zig (g == f) or zig-zag (g != f) rotation.
            child[node << 1 | g ^ 1] = child[tmp << 1 | g]
            child[tmp << 1 | f ^ (g != f)] = node
            child[pnode << 1 | f ^ 1] = child[(node if g == f else tmp) << 1 | f]
            child[(node if g == f else tmp) << 1 | f] = pnode
            self._update_triple(pnode, node, tmp)
            if not path:
                return tmp
            d >>= 2
            g = d & 1
            # Hang the target where the grandparent used to be.
            child[path[-1] << 1 | g ^ 1] = tmp
        # One level left: single (zig) rotation.
        pnode = path[0]
        node = child[pnode << 1 | g ^ 1]
        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
        child[node << 1 | g] = pnode
        self._update(pnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node to the root.

        If ``key`` is absent, the last node on its search path (its
        predecessor or successor) is splayed instead.  Does nothing on an
        empty tree.
        """
        node = self.node
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        path = []
        d = 0
        while keys[node] != key:
            nxt = child[node << 1 | (key > keys[node])]
            if not nxt:
                break
            path.append(node)
            d, node = d << 1 | (key < keys[node]), nxt
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the ``k``-th smallest node (0-indexed) to the root.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is outside ``[-len(self), len(self))``.
        """
        size, child = self.size, self.child
        node = self.node
        n = size[node]
        if k < 0:
            k += n
        if not 0 <= k < n:
            # Without this guard the descent reaches the sentinel node and
            # either loops forever or overwrites the sentinel's child links.
            raise IndexError("SplayTreeSet index out of range")
        d = 0
        path = []
        while True:
            t = size[child[node << 1]]  # elements in the left subtree
            if t == k:
                if path:
                    self.node = self._splay(path, d)
                return
            path.append(node)
            d, node = d << 1 | (t > k), child[node << 1 | (t < k)]
            if t < k:
                k -= t + 1

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the subtree's new root (``0`` for an empty subtree).
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1]:
            return node
        path = []
        while child[node << 1]:
            path.append(node)
            node = child[node << 1]
        return self._splay(path, (1 << len(path)) - 1)  # every step went left

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns the subtree's new root (``0`` for an empty subtree).
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1 | 1]:
            return node
        path = []
        while child[node << 1 | 1]:
            path.append(node)
            node = child[node << 1 | 1]
        return self._splay(path, 0)  # every step went right

    def _remove_root(self) -> None:
        """Unlink the current root (the set must be non-empty)."""
        child = self.child
        root = self.node
        left, right = child[root << 1], child[root << 1 | 1]
        if not left:
            self.node = right
        elif not right:
            self.node = left
        else:
            # The right subtree's minimum, once splayed to its top, has no
            # left child, so the old left subtree can be hung there.
            node = self._get_min_splay(right)
            child[node << 1] = left
            self._update(node)
            self.node = node

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` additional node slots (keys filled with ``e``)."""
        assert n >= 0, f"ValueError: SplayTreeSet.reserve({n})"
        self.keys += [self.e] * n
        self.size += array("I", [1] * n)
        self.child += array("I", bytes(8 * n))

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if added, ``False`` if already present."""
        keys, child = self.keys, self.child
        self._set_search_splay(key)
        # Check self.node first: on an empty tree the "root" is the sentinel,
        # whose key is the filler ``e`` and must not count as a match.
        if self.node and keys[self.node] == key:
            return False
        node = self._make_node(key)
        if not self.node:
            self.node = node
            return True
        # The root is now key's predecessor or successor.  The new node
        # becomes the root: the old root goes on the side opposite ``f`` and
        # the old root's ``f``-side subtree moves under the new node.
        f = key > keys[self.node]
        child[node << 1 | f ^ 1] = self.node
        child[node << 1 | f] = child[self.node << 1 | f]
        child[self.node << 1 | f] = 0
        self._update(child[node << 1 | f ^ 1])
        self._update(node)
        self.node = node
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether it was removed."""
        if not self.node:
            return False
        self._set_search_splay(key)
        if self.keys[self.node] != key:
            return False
        self._remove_root()
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            # Fell off the tree: the last visited node is the splay target,
            # not one of its ancestors.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key <= keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        # The descent always ends by falling off the tree; the last visited
        # node is the splay target, not one of its ancestors.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            # Fell off the tree: the last visited node is the splay target.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        # The descent always ends by falling off the tree; the last visited
        # node is the splay target, not one of its ancestors.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] < key:
            res += 1
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] <= key:
            res += 1
        return res

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if the set is empty."""
        if not self.node:
            return None
        return self.__getitem__(0)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if the set is empty."""
        if not self.node:
            return None
        return self.__getitem__(-1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the largest).

        Negative ``k`` counts from the end.  Asserts that the set is
        non-empty.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        assert self.node, f"IndexError: SplayTreeSet.pop({k})"
        keys, child = self.keys, self.child
        if k == -1:
            # Fast path: the maximum has no right child once splayed.
            node = self._get_max_splay(self.node)
            self.node = child[node << 1]
            return keys[node]
        self._set_kth_elm_splay(k)
        res = keys[self.node]
        self._remove_root()
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element (asserts the set is non-empty)."""
        assert self.node, "IndexError: SplayTreeSet.pop_min()"
        node = self._get_min_splay(self.node)
        # The minimum has no left child once splayed.
        self.node = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove all elements; allocated storage is kept but not reused."""
        self.node = 0

    def tolist(self) -> list[T]:
        """Return the elements in ascending order (iterative in-order walk)."""
        node = self.node
        child, keys = self.child, self.keys
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def __iter__(self):
        """Start iterating in ascending order.

        The cursor is stored on the instance, so two simultaneous iterations
        over the same set interfere with each other.
        """
        self.__iter = 0
        return self

    def __next__(self):
        """Return the next element in ascending order."""
        if self.__iter == self.__len__():
            raise StopIteration
        self.__iter += 1
        return self.__getitem__(self.__iter - 1)

    def __reversed__(self):
        """Yield the elements in descending order."""
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is in the set (splays its search path)."""
        if not self.node:
            # The sentinel holds the filler ``e``; never report it as present.
            return False
        self._set_search_splay(key)
        return self.keys[self.node] == key

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        return self.keys[self.node]

    def __len__(self):
        """Return the number of elements (the root's subtree size)."""
        return self.size[self.node]

    def __bool__(self):
        """Return whether the set is non-empty."""
        return self.node != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeSet({self})"
仕様 (specification)
- class SplayTreeSet(a: Iterable[T] = [], e: T = 0) [source]
  Bases: OrderedSetInterface, Generic[T]