1# from titan_pylib.data_structures.avl_tree.avl_tree_dict import AVLTreeDict
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for keys that can be ordered with ``<``."""

    def __lt__(self, other) -> bool: ...


from typing import Callable, Generic, Iterable, TypeVar, Union, Optional

K = TypeVar("K", bound=SupportsLessThan)
V = TypeVar("V")


class AVLTreeDict(Generic[K, V]):
    """Ordered key -> value mapping backed by an AVL tree.

    Keys are compared with ``==`` and ``<`` and are iterated in ascending
    order.  Lookup, insertion and deletion each walk one root-to-leaf path,
    and rotations keep that path O(log n) long.

    Every node stores ``balance = height(left) - height(right)``; the
    rebalancing code keeps it in ``{-1, 0, 1}``.

    Reading a missing key returns ``default()`` when a default factory is
    configured (the key is NOT inserted, unlike ``collections.defaultdict``)
    and raises ``KeyError`` otherwise.
    """

    class Node:
        """A tree node holding one key/value pair."""

        def __init__(self, key: K, val: V):
            self.key: K = key
            self.val: V = val
            self.left: Optional[AVLTreeDict.Node] = None
            self.right: Optional[AVLTreeDict.Node] = None
            # height(left subtree) - height(right subtree)
            self.balance = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val}\n"
            return (
                f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(
        self,
        a: Iterable[K] = (),
        counter: bool = False,
        default: Optional[Callable[[], V]] = None,
    ) -> None:
        """Create the mapping, optionally bulk-loading keys as a counter.

        Args:
            a: Keys to pre-load.  Each occurrence adds 1 to that key's value,
                starting from ``default()``, so a numeric default (or
                ``counter=True``) is required when ``a`` is non-empty.
            counter: If True, use ``int`` as the default factory
                (Counter-like behaviour); overrides ``default``.
            default: Zero-argument factory giving the value returned for
                missing keys.  ``None`` means missing keys raise ``KeyError``.
        """
        self._default = default
        self.node: Optional[AVLTreeDict.Node] = None  # tree root
        self._len = 0
        if counter:
            self._default = int
        self._build(a)

    def _build(self, a: Iterable[K]) -> None:
        """Count occurrences of each key in ``a`` into the mapping."""
        for key in sorted(a):
            self[key] = self[key] + 1

    def _rotate_L(self, node: Node) -> Node:
        """Rotate right about ``node`` (its left child is promoted).

        Used when ``node`` is left-heavy (balance +2) and its left child is
        not right-heavy.  Returns the new subtree root.
        """
        u = node.left
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 (only reachable during deletion): height unchanged.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Rotate left about ``node`` (its right child is promoted).

        Mirror image of ``_rotate_L``; returns the new subtree root.
        """
        u = node.right
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 (only reachable during deletion): height unchanged.
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Fix balance factors after a double rotation with ``node`` as new root.

        ``node.balance`` still holds its pre-rotation value, which determines
        which of the two new children ends up one level shorter.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Left-right double rotation; returns the new subtree root."""
        B = node.left
        E = B.right
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Right-left double rotation; returns the new subtree root."""
        C = node.right
        D = C.left
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def items(self):
        """Yield ``(key, value)`` pairs in ascending key order.

        The pairs are snapshotted when iteration starts.
        """
        yield from self.tolist_items()

    def keys(self):
        """Yield keys in ascending order (snapshotted when iteration starts)."""
        for key, _ in self.tolist_items():
            yield key

    def values(self):
        """Yield values in ascending key order (snapshotted when iteration starts)."""
        for _, val in self.tolist_items():
            yield val

    def __iter__(self):
        """Iterate over keys in ascending order, like ``dict``."""
        return self.keys()

    def _search_node(self, key: K) -> Union[Node, None]:
        """Return the node holding ``key``, or ``None`` if absent."""
        node = self.node
        while node is not None:
            if key == node.key:
                return node
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return None

    def _discard(self, key: K) -> bool:
        """Remove ``key`` if present and rebalance.

        Returns True if a node was removed, False if ``key`` was absent.
        """
        # `path` holds the ancestors of the node that will be unlinked; the
        # low bit of `di` is the direction taken out of path[-1]
        # (1 = went left, 0 = went right), the next bit for path[-2], etc.
        di = 0
        path = []
        node = self.node
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        else:
            return False
        if node.left is not None and node.right is not None:
            # Two children: move the in-order predecessor (max of the left
            # subtree) into this node, then unlink the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node.val = lmax.val  # the value must travel with its key
            node = lmax
        # `node` now has at most one child: splice it out.
        cnode = node.right if node.left is None else node.left
        if path:
            pnode = path[-1]
            if di & 1:
                pnode.left = cnode
            else:
                pnode.right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up while the subtree height keeps shrinking.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # 0 -> +/-1: this subtree kept its height, so stop.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    # Rotation preserved the subtree height, so stop.
                    break
        return True

    def tolist_items(self) -> list[tuple[K, V]]:
        """Return all ``(key, value)`` pairs as a list in ascending key order."""
        res = []
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            res.append((node.key, node.val))
            node = node.right
        return res

    def __setitem__(self, key: K, val: V):
        """Insert or overwrite ``key``.

        Returns True when a new key was inserted and None when an existing
        key's value was overwritten (kept from the original API).
        """
        if self.node is None:
            self.node = AVLTreeDict.Node(key, val)
            self._len += 1
            return True
        pnode = self.node
        path = []
        di = 0  # direction bits, same encoding as in `_discard`
        while pnode is not None:
            if key == pnode.key:
                # Existing key: overwrite in place; size is unchanged.
                pnode.val = val
                return
            path.append(pnode)
            di <<= 1
            if key < pnode.key:
                di |= 1
                pnode = pnode.left
            else:
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeDict.Node(key, val)
        else:
            path[-1].right = AVLTreeDict.Node(key, val)
        self._len += 1
        # Walk back up while the subtree height keeps growing.
        new_node = None
        while path:
            pnode = path.pop()
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                # The shorter side caught up: height unchanged, stop.
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            # Re-attach the rotated subtree to its parent (or make it the root).
            if path:
                gnode = path.pop()
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        return True

    def __delitem__(self, key: K):
        """Remove ``key``; raise ``KeyError`` if it is absent."""
        if self._discard(key):
            self._len -= 1
            return
        raise KeyError(key)

    def __getitem__(self, key: K):
        """Return the value for ``key``, or defer to ``__missing__`` if absent."""
        node = self._search_node(key)
        return self.__missing__(key) if node is None else node.val

    def __contains__(self, key: K):
        return self._search_node(key) is not None

    def __len__(self):
        return self._len

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(f"{k}: {v}" for k, v in self.items()) + "}"

    def __repr__(self):
        return "AVLTreeDict(" + str(self) + ")"

    def __missing__(self, e):
        """Value for a missing key ``e``.

        Returns ``default()`` (without inserting ``e``) when a default factory
        is configured; otherwise raises ``KeyError(e)``.
        """
        if self._default is None:
            raise KeyError(e)
        return self._default()