splay_tree_multiset_top_down¶
ソースコード¶
from titan_pylib.data_structures.splay_tree.splay_tree_multiset_top_down import SplayTreeMultisetTopDown
展開済みコード¶
1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset_top_down import SplayTreeMultisetTopDown
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for any object that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
        ...
9# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
11from abc import ABC, abstractmethod
12from typing import Iterable, Optional, Iterator, TypeVar, Generic
13
14T = TypeVar("T", bound=SupportsLessThan)
15
16
class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract interface shared by ordered multiset implementations.

    Every method is abstract; a concrete subclass must override all of them
    before it can be instantiated.
    """

    # ----- construction ---------------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the multiset from the elements of ``a``."""
        raise NotImplementedError

    # ----- modification ---------------------------------------------------

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        """Insert ``cnt`` copies of ``key``."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        """Remove up to ``cnt`` copies of ``key``; report whether ``key`` was present."""
        raise NotImplementedError

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; report whether ``key`` was present."""
        raise NotImplementedError

    @abstractmethod
    def count(self, key: T) -> int:
        """Return how many copies of ``key`` are stored."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        """Remove ``cnt`` copies of ``key``."""
        raise NotImplementedError

    # ----- ordered queries ------------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key that is ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key that is ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key that is ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key that is ``> key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest stored key."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest stored key."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it."""
        raise NotImplementedError

    # ----- bulk operations ------------------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the stored elements as a list."""
        raise NotImplementedError

    # ----- dunder protocol ------------------------------------------------

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator over the stored elements."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the current iteration."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether at least one copy of ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the multiset is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the contents."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError
110from array import array
111from typing import Optional, Generic, Iterable, Sequence, TypeVar
112
113T = TypeVar("T", bound=SupportsLessThan)
114
115
class SplayTreeMultisetTopDown(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset implemented as a top-down splay tree in flat arrays.

    Node ``i`` stores its key in ``keys[i]``, its multiplicity in ``vals[i]``,
    its left child in ``child[i << 1]`` and its right child in
    ``child[i << 1 | 1]``.  Index ``0`` is the "no node" sentinel; while a
    splay is in progress its two child slots are used as scratch space:
    ``child[1]`` heads the left (smaller keys) tree and ``child[0]`` heads
    the right (larger keys) tree.
    """

    def __init__(self, a: Iterable[T] = [], e: T = 0):
        """Create a multiset holding the elements of ``a``.

        Args:
            a: Initial elements.  Sorted internally if not already
                non-decreasing.  (The ``[]`` default is kept for interface
                compatibility; it is never mutated.)
            e: Filler key stored in the sentinel slot and in reserved,
                not-yet-used slots.
        """
        self.keys: list[T] = [e]
        self.vals: list[int] = [0]
        # Two zeroed slots for the sentinel node, independent of the
        # platform's item size for type code "I".
        self.child = array("I", [0, 0])
        self.end: int = 1  # index of the next unused node slot
        self.root: int = 0
        self.len: int = 0  # number of elements, counting multiplicity
        self.e: T = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _rle(self, a: Sequence[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a`` into ``(distinct keys, run lengths)``.

        Consecutive equal elements are merged; an empty input yields two
        empty lists.
        """
        x: list[T] = []
        y: list[int] = []
        for e in a:
            if x and e == x[-1]:
                y[-1] += 1
            else:
                x.append(e)
                y.append(1)
        return x, y

    def _build(self, a: Sequence[T]) -> None:
        """Build a balanced tree from the non-empty sequence ``a``.

        Intended to be called on a freshly initialised instance (as
        ``__init__`` does): distinct keys are written to slots ``1..n``.
        """

        def rec(l: int, r: int) -> int:
            # Make the midpoint of the half-open slot range [l, r) the
            # subtree root and recurse on both halves.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = rec(l, mid)
            if mid + 1 != r:
                child[mid << 1 | 1] = rec(mid + 1, r)
            return mid

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        x, y = self._rle(a)
        n = len(x)
        self.reserve(n - len(self.keys) + 2)
        keys, vals, child = self.keys, self.vals, self.child
        self.end += n
        keys[1 : n + 1] = x
        vals[1 : n + 1] = y
        self.root = rec(1, n + 1)
        self.len = len(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a node for ``key`` with multiplicity ``val``; return its index."""
        if self.end >= len(self.keys):
            self.keys.append(key)
            self.vals.append(val)
            self.child.append(0)
            self.child.append(0)
        else:
            # Reuse a slot created by ``reserve``; its child slots are zero.
            self.keys[self.end] = key
            self.vals[self.end] = val
        self.end += 1
        return self.end - 1

    def _rotate_left(self, node: int) -> int:
        """Lift the left child of ``node`` above it and return that child."""
        child = self.child
        u = child[node << 1]
        child[node << 1] = child[u << 1 | 1]
        child[u << 1 | 1] = node
        return u

    def _rotate_right(self, node: int) -> int:
        """Lift the right child of ``node`` above it and return that child."""
        child = self.child
        u = child[node << 1 | 1]
        child[node << 1 | 1] = child[u << 1]
        child[u << 1] = node
        return u

    def _assemble(self, node: int, left: int, right: int) -> None:
        """Finish a top-down splay by making ``node`` the new root.

        ``left`` / ``right`` are the last nodes linked into the left and
        right scratch trees (``0`` if nothing was linked on that side).
        """
        child = self.child
        # Hand node's subtrees to the tails of the scratch trees first ...
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        # ... then hang the completed scratch trees under node.
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key`` to the root, or the last node on its search path if absent."""
        node = self.root
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        left, right = 0, 0
        while keys[node] != key:
            f = key > keys[node]
            if not child[node << 1 | f]:
                break
            if f:
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
            else:
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
            node = child[node << 1 | f]
        self._assemble(node, left, right)

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node``; return the new subtree root."""
        child = self.child
        if (not node) or (not child[node << 1]):
            return node
        right = 0
        while child[node << 1]:
            node = self._rotate_left(node)
            if not child[node << 1]:
                break
            child[right << 1] = node
            right = node
            node = child[node << 1]
        child[right << 1] = child[node << 1 | 1]
        child[1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node``; return the new subtree root."""
        child = self.child
        if (not node) or (not child[node << 1 | 1]):
            return node
        left = 0
        while child[node << 1 | 1]:
            node = self._rotate_right(node)
            if not child[node << 1 | 1]:
                break
            child[left << 1 | 1] = node
            left = node
            node = child[node << 1 | 1]
        child[0] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes."""
        assert n >= 0, "ValueError"
        self.keys += [self.e] * n
        self.vals += [0] * n
        # Two zeroed child slots per node, sized by the real item width.
        self.child += array("I", bytes(2 * n * self.child.itemsize))

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self.len += val
        if not self.root:
            self.root = self._make_node(key, val)
            return
        keys, vals, child = self.keys, self.vals, self.child
        self._set_search_splay(key)
        if keys[self.root] == key:
            vals[self.root] += val
            return
        node = self._make_node(key, val)
        # The old root is key's neighbour: split it around the new node.
        f = key > keys[self.root]
        child[node << 1 | f] = child[self.root << 1 | f]
        child[node << 1 | f ^ 1] = self.root
        child[self.root << 1 | f] = 0
        self.root = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        if not self.root:
            return False
        self._set_search_splay(key)
        keys, vals, child = self.keys, self.vals, self.child
        if keys[self.root] != key:
            return False
        if vals[self.root] > val:
            vals[self.root] -= val
            self.len -= val
            return True
        # All copies go: unlink the root node.
        self.len -= vals[self.root]
        if not child[self.root << 1]:
            self.root = child[self.root << 1 | 1]
        elif not child[self.root << 1 | 1]:
            self.root = child[self.root << 1]
        else:
            # The successor becomes the root and adopts the old left subtree.
            node = self._get_min_splay(child[self.root << 1 | 1])
            child[node << 1] = child[self.root << 1]
            self.root = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Remove ``val`` copies of ``key``.

        Raises:
            KeyError: If fewer than ``val`` copies of ``key`` are stored.
        """
        c = self.count(key)
        if c < val:
            raise KeyError(key)
        self.discard(key, val)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (``0`` if it is absent)."""
        if not self.root:
            return 0
        self._set_search_splay(key)
        # When key is absent the splay leaves a neighbour at the root, so
        # the root's key must be checked before its count is reported.
        if self.keys[self.root] != key:
            return 0
        return self.vals[self.root]

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None`` if there is none."""
        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        res = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                res = key
                break
            if key < keys[node]:
                res = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    res = keys[node]
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        self._assemble(node, left, right)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None`` if there is none."""
        node = self.root
        if not node:
            return None
        res = None
        keys, child = self.keys, self.child
        left, right = 0, 0
        while True:
            if key < keys[node]:
                res = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    res = keys[node]
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        self._assemble(node, left, right)
        return res

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None`` if there is none."""
        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        res = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                res = key
                break
            if key < keys[node]:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                res = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    res = keys[node]
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        self._assemble(node, left, right)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None`` if there is none."""
        node = self.root
        if not node:
            return None
        res = None
        keys, child = self.keys, self.child
        left, right = 0, 0
        while True:
            # A node whose key is >= key cannot be the answer: go left.
            if not keys[node] < key:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                # keys[node] < key: a candidate; look right for a larger one.
                res = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    res = keys[node]
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        self._assemble(node, left, right)
        return res

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        node = self.root
        child, vals, keys = self.child, self.vals, self.keys
        stack, res = [], []
        # Iterative in-order traversal.
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res += [keys[node]] * vals[node]
                node = child[node << 1 | 1]
        return res

    def get_max(self) -> T:
        """Return the largest key; the multiset must be non-empty."""
        assert self.root, "IndexError: get_max() from empty SplayTreeMultisetTopDown"
        self.root = self._get_max_splay(self.root)
        return self.keys[self.root]

    def get_min(self) -> T:
        """Return the smallest key; the multiset must be non-empty."""
        assert self.root, "IndexError: get_min() from empty SplayTreeMultisetTopDown"
        self.root = self._get_min_splay(self.root)
        return self.keys[self.root]

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it; must be non-empty."""
        assert self.root, "IndexError: pop_max() from empty SplayTreeMultisetTopDown"
        node = self._get_max_splay(self.root)
        self.len -= 1
        if self.vals[node] > 1:
            self.vals[node] -= 1
            self.root = node
        else:
            # The maximum has no right child, so its left subtree takes over.
            self.root = self.child[node << 1]
        return self.keys[node]

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it; must be non-empty."""
        assert self.root, "IndexError: pop_min() from empty SplayTreeMultisetTopDown"
        node = self._get_min_splay(self.root)
        self.len -= 1
        if self.vals[node] > 1:
            self.vals[node] -= 1
            self.root = node
        else:
            # The minimum has no left child, so its right subtree takes over.
            self.root = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove every element (already allocated node slots are not reused)."""
        self.root = 0
        self.len = 0

    def __contains__(self, key: T):
        """Return whether at least one copy of ``key`` is stored."""
        if not self.root:
            # Without this guard the sentinel key ``e`` would be compared
            # against ``key`` and could report a false positive.
            return False
        self._set_search_splay(key)
        return self.keys[self.root] == key

    def __iter__(self):
        """Iteration is not implemented; use ``tolist()`` instead."""
        raise NotImplementedError

    def __next__(self):
        """Iteration is not implemented; use ``tolist()`` instead."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of elements, counting multiplicity."""
        return self.len

    def __bool__(self):
        """Return whether the tree has a root node."""
        return self.root != 0

    def __str__(self):
        """Return the elements in ascending order as ``{a, b, ...}``."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        """Return ``SplayTreeMultisetTopDown({...})``."""
        return f"SplayTreeMultisetTopDown({self})"
仕様¶
- class SplayTreeMultisetTopDown(a: Iterable[T] = [], e: T = 0)[source]¶
Bases: OrderedMultisetInterface, Generic[T]