1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
2from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
4 BSTMultisetNodeBase,
5)
6import math
7from typing import Final, TypeVar, Generic, Iterable, Optional, Iterator
8
# Key type stored in the multiset. Bound to SupportsLessThan because keys are
# ordered with ``<`` / ``<=`` comparisons throughout ScapegoatTreeMultiset.
T = TypeVar("T", bound=SupportsLessThan)
10
11
class ScapegoatTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset backed by a scapegoat tree.

    Each tree node holds one distinct key together with its multiplicity
    (``val``). Every node also caches ``size`` (number of nodes, i.e. distinct
    keys, in its subtree) and ``valsize`` (sum of multiplicities in its
    subtree); the order-statistic queries rely on both.

    Balance is restored lazily: when an insertion creates a node that is too
    deep, the deepest ancestor on the insertion path whose child subtree holds
    more than ``ALPHA`` of its own size is rebuilt into a perfectly balanced
    subtree. Deletions never trigger a rebuild.
    """

    # Weight-balance factor: a node is unbalanced when the child on the
    # insertion path holds more than ALPHA of the node's subtree size.
    ALPHA: Final[float] = 0.75
    # log2(1 / ALPHA). A new node at depth d triggers a rebuild when
    # d * BETA > log2(n), i.e. d > log_{1/ALPHA}(n).
    BETA: Final[float] = math.log2(1 / ALPHA)

    class Node:
        """Tree node storing one distinct key and its multiplicity."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.size: int = 1  # number of nodes (distinct keys) in this subtree
            self.valsize: int = val  # total multiplicity in this subtree
            self.left: Optional[ScapegoatTreeMultiset.Node] = None
            self.right: Optional[ScapegoatTreeMultiset.Node] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Build the multiset from ``a`` (any iterable; duplicates allowed).

        The default is an immutable empty tuple rather than a shared list.
        """
        self.root = None
        if not isinstance(a, list):
            a = list(a)
        self._build(a)

    def _build(self, a: list[T]) -> None:
        """Set ``self.root`` to a perfectly balanced tree holding ``a``.

        ``a`` is sorted first if it is not already non-decreasing, then split
        by ``BSTMultisetNodeBase._rle`` into parallel lists of distinct keys
        ``x`` and their counts ``y``. Leaves ``self.root`` untouched if ``a``
        is empty.
        """
        Node = ScapegoatTreeMultiset.Node

        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
            # Build the subtree for x[l:r], rooted at the middle key.
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
                node.valsize += node.left.valsize
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
                node.valsize += node.right.valsize
            return node

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        if not a:
            return
        x, y = BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node]._rle(a)
        self.root = rec(0, len(x))

    def _rebuild(self, node: Node) -> Node:
        """Rebuild the subtree rooted at ``node`` into a balanced one.

        The existing node objects are reused, not copied. Their child links,
        ``size`` and ``valsize`` are recomputed. Returns the new subtree root.
        """

        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
            mid = (l + r) >> 1
            node = a[mid]
            node.size = 1
            node.valsize = node.val
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
                node.valsize += node.left.valsize
            else:
                node.left = None
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
                node.valsize += node.right.valsize
            else:
                node.right = None
            return node

        # An iterative in-order traversal collects the subtree's nodes in key order.
        a = []
        stack = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node)
                node = node.right
        return rec(0, len(a))

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` for the ``k``-th element, counting multiplicity.

        A negative ``k`` counts from the end. If ``k`` is out of range the loop
        falls through and ``None`` is returned, so callers validate the index
        first.
        """
        if k < 0:
            k += len(self)
        node = self.root
        while node:
            t = (node.val + node.left.valsize) if node.left else node.val
            # This node's copies occupy positions [t - node.val, t) of the subtree.
            if t - node.val <= k and k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` for the ``k``-th distinct key (0-based).

        A negative ``k`` counts from the end. Fails an assertion when ``k`` is
        out of range.
        """
        if k < 0:
            k += self.len_elm()
        node = self.root
        while node:
            t = node.left.size if node.left else 0
            if t == k:
                return node.key, node.val
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
        assert False, "IndexError"

    def _inorder(self, reverse: bool = False) -> Iterator["ScapegoatTreeMultiset.Node"]:
        """Yield the nodes in ascending key order (descending if ``reverse``).

        The traversal is iterative and touches each node once, so a full pass
        is O(n).
        """
        stack = []
        node = self.root
        while stack or node:
            if node:
                stack.append(node)
                node = node.right if reverse else node.left
            else:
                node = stack.pop()
                yield node
                node = node.left if reverse else node.right

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``. Does nothing if ``val <= 0``."""
        if val <= 0:
            return
        if not self.root:
            self.root = ScapegoatTreeMultiset.Node(key, val)
            return
        node = self.root
        path = []
        while node:
            path.append(node)
            if key == node.key:
                # Existing key: only the multiplicities change; the shape is untouched.
                node.val += val
                for p in path:
                    p.valsize += val
                return
            node = node.left if key < node.key else node.right
        if key < path[-1].key:
            path[-1].left = ScapegoatTreeMultiset.Node(key, val)
        else:
            path[-1].right = ScapegoatTreeMultiset.Node(key, val)
        # The new node sits at depth len(path). The ancestors' sizes are not
        # updated yet, so the post-insertion node count is len_elm() + 1.
        # Using log2 on both sides makes this test "depth > log_{1/ALPHA}(n)".
        # Under that condition an ALPHA-unbalanced ancestor normally exists on
        # ``path``; if none is found, the loop below falls through to the root.
        if len(path) * ScapegoatTreeMultiset.BETA > math.log2(self.len_elm() + 1):
            node_size = 1
            while path:
                pnode = path.pop()
                pnode_size = pnode.size + 1  # +1 for the node just inserted
                if ScapegoatTreeMultiset.ALPHA * pnode_size < node_size:
                    break  # pnode is the scapegoat
                node_size = pnode_size
            new_node = self._rebuild(pnode)
            if not path:
                self.root = new_node
                return
            # Every key in the rebuilt subtree lies on the same side of its parent.
            if new_node.key < path[-1].key:
                path[-1].left = new_node
            else:
                path[-1].right = new_node
        for p in path:
            p.size += 1
            p.valsize += val

    def _discard(self, key: T) -> bool:
        """Unlink the node holding ``key``, which must exist with count 1.

        A node with two children takes over the key and count of its in-order
        predecessor (the maximum of its left subtree), and that predecessor
        node is unlinked instead.
        """
        path = []
        node = self.root
        # di: 1/True when the node being unlinked is its parent's left child.
        di, cnt = 1, 0
        while node:
            if key == node.key:
                break
            path.append(node)
            di = key < node.key
            node = node.left if di else node.right
        if node.left and node.right:
            path.append(node)
            lmax = node.left
            di = 0 if lmax.right else 1
            while lmax.right:
                cnt += 1  # nodes strictly between ``node`` and ``lmax``
                path.append(lmax)
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.left if node.left else node.right
        if path:
            if di == 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.root = cnode
            return True
        # Nodes between the replaced node and the predecessor lost the whole
        # predecessor node (count lmax_val).
        for _ in range(cnt):
            p = path.pop()
            p.size -= 1
            p.valsize -= lmax_val
        # The replaced node and its ancestors lost one node and one copy.
        for p in path:
            p.size -= 1
            p.valsize -= 1
        return True

    def discard(self, key: T, val=1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current count, the key is removed entirely.
        Returns False if ``key`` is absent; otherwise True. Also returns True
        without changes if ``val <= 0``.
        """
        if val <= 0:
            return True
        path = []
        node = self.root
        while node:
            path.append(node)
            if key == node.key:
                break
            node = node.left if key < node.key else node.right
        else:
            return False
        if val >= node.val:
            # Removing every copy. First shrink the count to 1, because
            # ``_discard`` assumes it unlinks exactly one copy.
            extra = node.val - 1
            if extra > 0:
                node.val -= extra
                for p in path:
                    p.valsize -= extra
            self._discard(key)
        else:
            node.val -= val
            for p in path:
                p.valsize -= val
        return True

    def remove(self, key: T, val: int = 1) -> None:
        """Remove ``val`` copies of ``key``.

        Raises:
            KeyError: if fewer than ``val`` copies of ``key`` are present.
        """
        if self.count(key) < val:
            raise KeyError(key)
        self.discard(key, val)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.root
        while node:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``. Returns False if ``key`` was absent."""
        cnt = self.count(key)
        if cnt == 0:
            return False
        return self.discard(key, cnt)

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.le`` (presumably the largest key <= ``key``, or None)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.lt`` (presumably the largest key < ``key``, or None)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.ge`` (presumably the smallest key >= ``key``, or None)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.gt`` (presumably the smallest key > ``key``, or None)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].gt(self.root, key)

    def index(self, key: T) -> int:
        """Delegate to ``BSTMultisetNodeBase.index`` (presumably the count of elements < ``key``)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index(self.root, key)

    def index_right(self, key: T) -> int:
        """Delegate to ``BSTMultisetNodeBase.index_right`` (presumably the count of elements <= ``key``)."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index_right(
            self.root, key
        )

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys strictly less than ``key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                # node.key < key: count this node (one key) and its left subtree.
                k += (node.left.size if node.left else 0) + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys less than or equal to ``key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                k += (node.left.size if node.left else 0) + 1
                break
            if key < node.key:
                node = node.left
            else:
                # node.key < key: count this node (one key) and its left subtree.
                k += (node.left.size if node.left else 0) + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the ``k``-th element (by multiplicity)."""
        if k < 0:
            # len(self) is 0 for an empty tree, so the index check in
            # __getitem__ fires instead of an AttributeError on ``None``.
            k += len(self)
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        return self.pop(-1)

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, count)`` pairs in ascending key order (O(n) total)."""
        for node in self._inorder():
            yield node.key, node.val

    def keys(self) -> Iterator[T]:
        """Yield the distinct keys in ascending order."""
        for node in self._inorder():
            yield node.key

    def values(self) -> Iterator[int]:
        """Yield each distinct key's count, in ascending key order."""
        for node in self._inorder():
            yield node.val

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def get_elm(self, k: int) -> T:
        """Return the ``k``-th distinct key (negative ``k`` counts from the end)."""
        assert (
            -self.len_elm() <= k < self.len_elm()
        ), f"IndexError: {self.__class__.__name__}.get_elm({k}), len_elm=({self.len_elm()})"
        return self._kth_elm_tree(k)[0]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self.root.size if self.root else 0

    def tolist(self) -> list[T]:
        """Return the elements as a list, via ``BSTMultisetNodeBase.tolist``."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, count)`` pairs, via ``BSTMultisetNodeBase.tolist_items``."""
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist_items(
            self.root
        )

    def clear(self) -> None:
        """Remove every element."""
        self.root = None

    def get_max(self) -> T:
        """Return the largest key (assertion failure if empty)."""
        return self._kth_elm_tree(-1)[0]

    def get_min(self) -> T:
        """Return the smallest key (assertion failure if empty)."""
        return self._kth_elm_tree(0)[0]

    def __contains__(self, key: T):
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].contains(
            self.root, key
        )

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th element, counting multiplicity."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        return self._kth_elm(k)[0]

    def __iter__(self) -> Iterator[T]:
        """Return a fresh iterator over all elements in ascending order.

        Each key is repeated ``count`` times. A new generator is returned on
        every call, so nested loops over the same multiset no longer share one
        cursor and cut each other short.
        """
        self.__iter = 0  # reset the legacy cursor read by ``__next__``
        return (node.key for node in self._inorder() for _ in range(node.val))

    def __next__(self):
        """Legacy cursor-based stepping, kept for callers that use next(obj)."""
        if self.__iter == len(self):
            raise StopIteration
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        """Yield all elements in descending order, repeated by multiplicity."""
        for node in self._inorder(reverse=True):
            for _ in range(node.val):
                yield node.key

    def __len__(self):
        """Return the total number of elements, counting multiplicity."""
        return self.root.valsize if self.root else 0

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        # Call tolist(); interpolating the bound method printed its repr instead.
        return f"{self.__class__.__name__}({self.tolist()})"