splay_tree_multiset_sum
ソースコード
from titan_pylib.data_structures.splay_tree.splay_tree_multiset_sum import SplayTreeMultisetSum
展開済みコード
1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset_sum import SplayTreeMultisetSum
2import sys
3from typing import Iterator, Optional, Generic, Iterable, TypeVar
4
T = TypeVar("T")


class SplayTreeMultisetSum(Generic[T]):
    """Ordered multiset on a splay tree that also maintains the sum of its elements.

    Each node stores one distinct key plus its multiplicity ``cnt`` and caches
    three aggregates over its subtree:

    * ``size``    -- number of distinct keys,
    * ``cntsize`` -- number of elements counted with multiplicity,
    * ``data``    -- sum of ``key * cnt``.

    Keys must support ``<`` / ``==`` comparison, ``key * int`` and ``+``.
    Almost every query splays the accessed node to the root, so queries also
    restructure the tree (they are not read-only).
    """

    class Node:
        """Tree node: one distinct key, its multiplicity and subtree aggregates."""

        # Many nodes are allocated; slots drop the per-instance __dict__.
        __slots__ = ("data", "key", "size", "cnt", "cntsize", "left", "right")

        def __init__(self, key, cnt: int) -> None:
            self.data = key * cnt  # subtree sum of key * cnt
            self.key = key
            self.size = 1  # distinct keys in this subtree
            self.cnt = cnt  # multiplicity of this key
            self.cntsize = cnt  # elements (with multiplicity) in this subtree
            self.left = None
            self.right = None

        def __str__(self) -> str:
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size, self.cnt, self.cntsize}\n"
            return f"key:{self.key, self.size, self.cnt, self.cntsize},\n left:{self.left},\n right:{self.right}\n"

        __repr__ = __str__

    def __init__(self, e: T, a: Iterable[T] = [], _node=None) -> None:
        """Create a multiset.

        Args:
            e: Value returned by ``get_sum`` when the multiset is empty.
            a: Initial elements; duplicates are allowed. The default list is
                only read, never mutated.
            _node: Internal: an existing root node to wrap (used by ``split``).
        """
        self.node = _node
        self.e = e
        # Cursor for the legacy ``next(tree)`` protocol; see ``__next__``.
        self.__iter = 0
        # ``_build`` is a no-op for an empty iterable, so ``_node`` is kept in
        # that case. Calling it unconditionally (rather than guarding with
        # ``if a:``) also accepts empty generators and objects that have no
        # plain truth value.
        self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree by a balanced tree holding the elements of ``a``.

        Does nothing when ``a`` is empty.
        """
        keys, counts = self._rle(sorted(a))
        if not keys:
            return
        Node = SplayTreeMultisetSum.Node

        def build(lo: int, hi: int) -> "SplayTreeMultisetSum.Node":
            # The middle of keys[lo:hi] becomes the subtree root, giving
            # logarithmic depth.
            mid = (lo + hi) >> 1
            node = Node(keys[mid], counts[mid])
            if lo != mid:
                node.left = build(lo, mid)
            if mid + 1 != hi:
                node.right = build(mid + 1, hi)
            self._update(node)
            return node

        self.node = build(0, len(keys))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a sorted list into (distinct keys, multiplicities).

        Returns two empty lists for an empty input.
        """
        keys: list[T] = []
        counts: list[int] = []
        for x in a:
            if keys and x == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(x)
                counts.append(1)
        return keys, counts

    def _update(self, node: Node) -> None:
        """Recompute ``size``, ``cntsize`` and ``data`` of ``node`` from its children.

        Written as explicit branches because it runs on every rotation.
        """
        if node.left is None:
            if node.right is None:
                node.size = 1
                node.cntsize = node.cnt
                node.data = node.key * node.cnt
            else:
                node.size = 1 + node.right.size
                node.cntsize = node.cnt + node.right.cntsize
                node.data = node.key * node.cnt + node.right.data
        else:
            if node.right is None:
                node.size = 1 + node.left.size
                node.cntsize = node.cnt + node.left.cntsize
                node.data = node.key * node.cnt + node.left.data
            else:
                node.size = 1 + node.left.size + node.right.size
                node.cntsize = node.cnt + node.left.cntsize + node.right.cntsize
                node.data = node.key * node.cnt + node.left.data + node.right.data

    def _splay(self, path: list[Node], d: int) -> Node:
        """Rotate the target node up to the top of the subtree rooted at ``path[0]``.

        Args:
            path: Ancestors of the target, from the subtree root down to the
                target's parent (non-empty). The list is consumed (popped).
            d: Direction bits. Bit i (counting from the LSB) tells which child
                of ``path[-1 - i]`` the walk went into: 1 = left, 0 = right.

        Returns:
            The target node, now the root of the subtree, with aggregates of
            every rotated node recomputed.
        """
        # Two levels per step: zig-zig when both directions agree, else zig-zag.
        for _ in range(len(path) >> 1):
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            if d & 1 == d >> 1 & 1:
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            # Children first, then the new local root.
            self._update(pnode)
            self._update(node)
            self._update(tmp)
            if not path:
                return tmp
            d >>= 2
            # Re-attach the rotated subtree to the next ancestor.
            if d & 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd path length: one final single rotation (zig).
        gnode = path[0]
        if d & 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        self._update(gnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node to the root.

        If ``key`` is absent, the last node visited by the BST search is
        splayed instead. That node holds the predecessor or the successor of
        ``key``.
        """
        node = self.node
        if node is None or node.key == key:
            return
        path = []
        d = 0
        while True:
            if node.key == key:
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the node holding the k-th element (0-indexed, with multiplicity).

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range (including an empty tree).
        """
        n = self.__len__()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultisetSum index out of range")
        d = 0
        node = self.node
        path = []
        while True:
            # Elements before the end of this node's run of copies.
            t = node.cnt if node.left is None else node.cnt + node.left.cntsize
            if t - node.cnt <= k < t:
                if path:
                    self.node = self._splay(path, d)
                break
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t

    def _set_kth_elm_tree_splay(self, k: int) -> None:
        """Splay the node holding the k-th distinct key (0-indexed).

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultisetSum index out of range")
        d = 0
        node = self.node
        path = []
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                if path:
                    self.node = self._splay(path, d)
                return
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t + 1

    def _get_min_splay(self, node: Node) -> Node:
        """Splay the minimum of the subtree rooted at ``node``; return the new subtree root."""
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        # Every step went left, so all direction bits are 1.
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Node) -> Node:
        """Splay the maximum of the subtree rooted at ``node``; return the new subtree root."""
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        # Every step went right, so all direction bits are 0.
        return self._splay(path, 0)

    def _iter_nodes(self) -> Iterator[Node]:
        """Yield nodes in ascending key order using an explicit stack.

        This is iterative on purpose. A splay tree can degenerate into a path
        as deep as the number of distinct keys; ``add`` with increasing keys
        produces that shape. Walking such a path recursively would exceed
        Python's recursion limit.
        """
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def add(self, key: T, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``."""
        if self.node is None:
            self.node = SplayTreeMultisetSum.Node(key, cnt)
            return
        self._set_search_splay(key)
        if self.node.key == key:
            self.node.cnt += cnt
            self._update(self.node)
            return
        # The root now holds key's neighbour. Place the new node above it.
        node = SplayTreeMultisetSum.Node(key, cnt)
        if key < self.node.key:
            node.left = self.node.left
            node.right = self.node
            self.node.left = None
            self._update(node.right)
        else:
            node.left = self.node
            node.right = self.node.right
            self.node.right = None
            self._update(node.left)
        self._update(node)
        self.node = node
        return

    def discard(self, key: T, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``.

        If ``cnt`` is at least the current multiplicity, the key is removed
        entirely.

        Returns:
            True if ``key`` was present, otherwise False.
        """
        if self.node is None:
            return False
        self._set_search_splay(key)
        if self.node.key != key:
            return False
        if self.node.cnt > cnt:
            self.node.cnt -= cnt
            self._update(self.node)
            return True
        if self.node.left is None:
            self.node = self.node.right
        elif self.node.right is None:
            self.node = self.node.left
        else:
            # The right subtree's minimum has no left child after splaying,
            # so the left subtree can hang there.
            node = self._get_min_splay(self.node.right)
            node.left = self.node.left
            self._update(node)
            self.node = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return True if it was present."""
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        return self.node.cnt if self.node.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None if there is none."""
        node = self.node
        path = []
        d = 0
        res = None
        while node is not None:
            if key <= node.key:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                res = node.key
                node = node.right
        else:
            # The last visited node becomes the splay target. Drop it and its
            # direction bit (which led to None) from the path.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None if there is none."""
        node = self.node
        path = []
        d = 0
        res = None
        while node is not None:
            if key < node.key:
                path.append(node)
                d <<= 1
                d |= 1
                res = node.key
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
        else:
            # The last visited node becomes the splay target. Drop it and its
            # direction bit (which led to None) from the path.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements (with multiplicity) strictly less than ``key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.cntsize
        if self.node.key < key:
            res += self.node.cnt
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements (with multiplicity) less than or equal to ``key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.cntsize
        if self.node.key <= key:
            res += self.node.cnt
        return res

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys strictly less than ``key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.size
        if self.node.key < key:
            res += 1
        return res

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys less than or equal to ``key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.size
        if self.node.key <= key:
            res += 1
        return res

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the k-th element (with multiplicity).

        The default ``k = -1`` pops the largest element.

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        res = self.node.key
        # The root already holds res, so discard's search is a no-op.
        self.discard(res)
        return res

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        return self.pop(0)

    def tolist(self) -> list[T]:
        """Return all elements, repeated by multiplicity, in ascending order."""
        return [node.key for node in self._iter_nodes() for _ in range(node.cnt)]

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        return [(node.key, node.cnt) for node in self._iter_nodes()]

    def get_elm(self, k: int) -> T:
        """Return the k-th smallest distinct key (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_tree_splay(k)
        return self.node.key

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` pairs in ascending key order."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.key, self.node.cnt

    def keys(self) -> Iterator[T]:
        """Yield the distinct keys in ascending order."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.key

    def cntues(self) -> Iterator[int]:
        """Yield the multiplicity of each distinct key, in ascending key order.

        NOTE(review): the name looks like a typo for ``counts``/``values``;
        it is kept as-is for compatibility.
        """
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.cnt

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove every element."""
        self.node = None

    def get_sum(self) -> T:
        """Return the sum of all elements (with multiplicity), or ``e`` if empty."""
        if self.node is None:
            return self.e
        return self.node.data

    def merge(self, other: "SplayTreeMultisetSum") -> None:
        """Append all elements of ``other`` to this multiset.

        Requires ``max(self) <= min(other)``; this is checked with ``assert``.

        NOTE: ``other`` keeps references to nodes that now belong to ``self``,
        so it should not be modified afterwards.
        """
        if self.node is None:
            self.node = other.node
            return
        if other.node is None:
            return
        self.node = self._get_max_splay(self.node)
        other.node = other._get_min_splay(other.node)
        assert self.node.key <= other.node.key
        if self.node.key == other.node.key:
            # Fold the shared boundary key into one node.
            self.node.cnt += other.node.cnt
            self._update(self.node)
            other.discard(other.node.key, other.node.cnt)
        # self's root is its maximum, so its right child is free.
        self.node.right = other.node
        self._update(self.node)

    def split(self, k: int) -> tuple["SplayTreeMultisetSum", "SplayTreeMultisetSum"]:
        """Split into (first ``k`` elements, remaining elements), with multiplicity.

        A negative ``k`` counts from the end, like a slice bound, and is
        clamped at 0. ``self`` is reused as one of the returned parts.
        """
        if k < 0:
            k = max(0, k + len(self))
        if self.node is None:
            return SplayTreeMultisetSum(self.e), self
        if k >= len(self):
            return self, SplayTreeMultisetSum(self.e)
        self._set_kth_elm_splay(k)
        if self.node.left is None:
            left = SplayTreeMultisetSum(self.e)
        else:
            left = SplayTreeMultisetSum(self.e, _node=self.node.left)
        if len(left) != k:
            # The root's run of copies straddles the cut. Move the copies
            # before position k into a new node on the left side.
            moved = k - len(left)
            assert moved > 0
            self.node.cnt -= moved
            self._update(self.node)
            root = SplayTreeMultisetSum.Node(self.node.key, moved)
            root.left = left.node
            left = SplayTreeMultisetSum(self.e, _node=root)
            left._update(left.node)
        self.node.left = None
        self._update(self.node)
        return left, self

    def __iter__(self) -> Iterator[T]:
        """Iterate over all elements (with multiplicity) in ascending order.

        Each call returns a fresh generator rather than ``self``, so nested
        loops over the same multiset work.
        """
        self.__iter = 0  # also reset the legacy next(tree) cursor
        return (self.__getitem__(i) for i in range(self.__len__()))

    def __next__(self) -> T:
        """Legacy cursor: return the next element of a ``next(tree)`` walk."""
        if self.__iter >= self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.node is not None and self.node.key == key

    def __getitem__(self, k: int) -> T:
        """Return the k-th element (with multiplicity); raise IndexError if out of range."""
        self._set_kth_elm_splay(k)
        return self.node.key

    def __len__(self):
        return 0 if self.node is None else self.node.cntsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return "SplayTreeMultisetSum(" + str(self) + ")"
仕様
- class SplayTreeMultisetSum(e: T, a: Iterable[T] = [], _node=None) [source]
  Bases: Generic[T]
  - merge(other: SplayTreeMultisetSum) -> None [source]
  - split(k: int) -> tuple[SplayTreeMultisetSum, SplayTreeMultisetSum] [source]