import math
from typing import Final, Generic, Iterable, Iterator, Optional, TypeVar

from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan

# Key type of ScapegoatTreeSet, bounded by SupportsLessThan; the tree code
# compares keys with ``<`` and ``==``.
T = TypeVar("T", bound=SupportsLessThan)


class ScapegoatTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys stored in a scapegoat tree.

    Each node caches the size of its subtree, which backs the positional
    operations (``self[k]``, ``pop(k)``, ``index``). Insertion keeps the tree
    shallow by rebuilding the subtree under a "scapegoat" ancestor whenever a
    new leaf lands deeper than ``log_{1/ALPHA}(n)``. Deletion only unlinks
    nodes; it never triggers a rebuild.
    """

    # A node counts as unbalanced when its child on the insertion path holds
    # more than ALPHA of the node's keys (see ``add``).
    ALPHA: Final[float] = 0.75
    # log2(1 / ALPHA), so ``depth * BETA > log2(n)`` <=> ``depth > log_{1/ALPHA}(n)``.
    BETA: Final[float] = math.log2(1 / ALPHA)

    class Node:
        """A tree node: key, child links and cached subtree size."""

        def __init__(self, key: T):
            self.key: T = key
            self.left: Optional["ScapegoatTreeSet.Node"] = None
            self.right: Optional["ScapegoatTreeSet.Node"] = None
            # Number of nodes in the subtree rooted here, itself included.
            self.size: int = 1

        def __str__(self):
            # Debug representation: prints (key, size) as a tuple and
            # recurses into the children.
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()):
        """Create a set holding the keys of ``a`` (empty by default).

        The default is an immutable empty tuple rather than a shared ``[]``.
        """
        self.root: Optional["ScapegoatTreeSet.Node"] = None
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a balanced tree over the keys of ``a``.

        ``a`` is first passed through ``BSTSetNodeBase.sort_unique``; the
        recursion below relies on the result being sorted with no duplicates.
        """
        Node = ScapegoatTreeSet.Node

        def rec(l: int, r: int) -> "ScapegoatTreeSet.Node":
            # Root the subtree for a[l:r] (non-empty) at its middle element.
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
            return node

        a = BSTSetNodeBase[T, ScapegoatTreeSet.Node].sort_unique(a)
        self.root = rec(0, len(a))

    def _rebuild(self, node: "ScapegoatTreeSet.Node") -> "ScapegoatTreeSet.Node":
        """Rebalance the subtree rooted at ``node`` and return its new root.

        The existing node objects are reused: they are gathered in key order
        by an iterative in-order walk, then relinked into a balanced shape
        with their ``size`` fields recomputed.
        """
        nodes = []
        stack = []
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                nodes.append(node)
                node = node.right

        def rec(l: int, r: int) -> "ScapegoatTreeSet.Node":
            # Relink nodes[l:r] (non-empty) around its middle node.
            mid = (l + r) >> 1
            sub = nodes[mid]
            sub.size = 1
            if l != mid:
                sub.left = rec(l, mid)
                sub.size += sub.left.size
            else:
                sub.left = None
            if mid + 1 != r:
                sub.right = rec(mid + 1, r)
                sub.size += sub.right.size
            else:
                sub.right = None
            return sub

        return rec(0, len(nodes))

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns True if the key was inserted, False if it was already present.

        After the new leaf is linked, if its depth exceeds
        ``log_{1/ALPHA}(n)`` (n = size after insertion), the deepest ancestor
        whose child on the path holds more than ``ALPHA`` of its keys is
        rebuilt. If no ancestor qualifies, the whole tree is rebuilt as a
        fallback.
        """
        Node = ScapegoatTreeSet.Node
        node = self.root
        if node is None:
            self.root = Node(key)
            return True
        # Nodes from the root down to the new leaf's parent; the leaf ends up
        # at depth len(path).
        path = []
        while node is not None:
            path.append(node)
            if key == node.key:
                return False
            node = node.left if key < node.key else node.right
        parent = path[-1]
        if key < parent.key:
            parent.left = Node(key)
        else:
            parent.right = Node(key)
        # Sizes on ``path`` are not incremented yet, so the tree now holds
        # root.size + 1 keys. BETA is a base-2 quantity, hence log2.
        if len(path) * ScapegoatTreeSet.BETA > math.log2(self.root.size + 1):
            child_size = 1  # size of the new leaf's subtree
            while path:
                scapegoat = path.pop()
                scapegoat_size = scapegoat.size + 1  # size including new leaf
                if ScapegoatTreeSet.ALPHA * scapegoat_size < child_size:
                    break
                child_size = scapegoat_size
            # Without a break, ``scapegoat`` is the root and ``path`` is empty.
            new_node = self._rebuild(scapegoat)
            if not path:
                self.root = new_node
                return True
            if new_node.key < path[-1].key:
                path[-1].left = new_node
            else:
                path[-1].right = new_node
        # Ancestors outside any rebuilt subtree gain one key.
        for p in path:
            p.size += 1
        return True

    def _unlink(self, node: "ScapegoatTreeSet.Node", path: list, is_left) -> None:
        """Remove ``node`` from the tree and update the cached sizes.

        ``path`` lists the ancestors of ``node`` from the root down to its
        parent (empty when ``node`` is the root). ``is_left`` is truthy when
        ``node`` is its parent's left child. ``path`` is modified in place.
        """
        if node.left is not None and node.right is not None:
            # Two children: copy the in-order predecessor's key into ``node``
            # and unlink the predecessor instead, which has no right child.
            path.append(node)
            lmax = node.left
            is_left = lmax.right is None
            while lmax.right is not None:
                path.append(lmax)
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        # ``node`` now has at most one child, which takes its place.
        child = node.right if node.left is None else node.left
        if path:
            if is_left:
                path[-1].left = child
            else:
                path[-1].right = child
        else:
            self.root = child
        for p in path:
            p.size -= 1

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return True if it was removed."""
        node = self.root
        path = []
        is_left = True
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            is_left = key < node.key
            node = node.left if is_left else node.right
        else:
            return False
        self._unlink(node, path, is_left)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError(key)`` if it is absent."""
        if not self.discard(key):
            raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.le`` (presumably the largest key <= ``key``, else None)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.lt`` (presumably the largest key < ``key``, else None)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.ge`` (presumably the smallest key >= ``key``, else None)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.gt`` (presumably the smallest key > ``key``, else None)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].gt(self.root, key)

    def index(self, key: T) -> int:
        """Delegate to ``BSTSetNodeBase.index`` (presumably the count of keys < ``key``)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].index(self.root, key)

    def index_right(self, key: T) -> int:
        """Delegate to ``BSTSetNodeBase.index_right`` (presumably the count of keys <= ``key``)."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].index_right(self.root, key)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest key (negative ``k`` counts from the end).

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.root
        path = []
        is_left = True
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                break
            path.append(node)
            if t < k:
                # Skip the left subtree and this node.
                node = node.right
                k -= t + 1
                is_left = False
            else:
                node = node.left
                is_left = True
        res = node.key
        self._unlink(node, path, is_left)
        return res

    def pop_min(self) -> T:
        """Remove and return the smallest key (IndexError when empty)."""
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest key (IndexError when empty)."""
        return self.pop(-1)

    def clear(self) -> None:
        """Remove every key."""
        self.root = None

    def tolist(self) -> list[T]:
        """Return the keys as a list via ``BSTSetNodeBase.tolist``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].tolist(self.root)

    def get_min(self) -> T:
        """Return the smallest key (``self[0]``)."""
        return self[0]

    def get_max(self) -> T:
        """Return the largest key (``self[-1]``)."""
        return self[-1]

    def __contains__(self, key: T):
        node = self.root
        while node is not None:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __getitem__(self, k: int) -> T:
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].kth_elm(self.root, k, len(self))

    def __iter__(self) -> Iterator[T]:
        # Hand out an independent iterator so nested or interleaved loops
        # over the same set do not share a cursor. The legacy cursor is still
        # reset so that ``iter(s); next(s)`` keeps working via ``__next__``.
        self.__iter = 0

        def gen() -> Iterator[T]:
            i = 0
            while i < len(self):
                yield self[i]
                i += 1

        return gen()

    def __next__(self) -> T:
        # Legacy cursor-based protocol; the cursor is reset by ``__iter__``.
        if self.__iter == self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self[-i - 1]

    def __len__(self):
        return 0 if self.root is None else self.root.size

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"