# Source code for titan_pylib.data_structures.avl_tree.avl_tree_dict
from typing import Callable, Generic, Iterable, Iterator, Optional, TypeVar, Union

from titan_pylib.my_class.supports_less_than import SupportsLessThan
3
# Key type of AVLTreeDict: keys are compared with ``==`` and ``<`` to keep
# the tree ordered, hence the SupportsLessThan bound.
K = TypeVar("K", bound=SupportsLessThan)
# Value type of AVLTreeDict: stored as-is, unconstrained.
V = TypeVar("V")
6
7
[docs]
class AVLTreeDict(Generic[K, V]):
    """Ordered mapping backed by an AVL tree.

    Keys are compared with ``==`` and ``<``.  ``items``, ``keys``, ``values``
    and ``__iter__`` produce entries in ascending key order.  Every node
    stores a balance factor ``height(left) - height(right)`` which the
    insertion and deletion paths keep within ``{-1, 0, 1}`` using rotations.

    Looking up a missing key calls ``__missing__``.  With a ``default``
    factory (or ``counter=True``, which uses ``int``), it returns
    ``default()`` without inserting it.  Otherwise it raises ``KeyError``.
    """

    class Node:
        """Tree node holding one key/value pair, its children and balance."""

        # Many nodes are allocated; slots avoid a per-node ``__dict__``.
        __slots__ = ("key", "val", "left", "right", "balance")

        def __init__(self, key: K, val: V):
            self.key: K = key
            self.val: V = val
            self.left: Optional[AVLTreeDict.Node] = None
            self.right: Optional[AVLTreeDict.Node] = None
            # height(left) - height(right); kept within {-1, 0, 1}.
            self.balance = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val}\n"
            return (
                f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(
        self,
        a: Iterable[K] = (),
        counter: bool = False,
        default: Optional[Callable[[], V]] = None,
    ) -> None:
        """Create the dictionary.

        Args:
            a: Keys to count.  Only consumed when ``counter`` is True, where
                each occurrence adds 1 to that key's value; ignored otherwise.
            counter: If True, use ``int`` as the default factory (overriding
                ``default``) and build the counts from ``a``.
            default: Zero-argument factory whose result is returned by
                ``d[key]`` for a missing key.  If None, missing keys raise
                ``KeyError``.
        """
        self._default = default
        self.node: Optional[AVLTreeDict.Node] = None
        self._len = 0
        if counter:
            self._default = int
            self._build(a)

    def _build(self, a: Iterable[K]) -> None:
        """Add 1 to the value of each key occurrence in ``a`` (counter mode)."""
        for key in sorted(a):
            self[key] = self[key] + 1

    def _rotate_L(self, node: Node) -> Node:
        """Rotate the left-heavy ``node`` (balance 2) to the right.

        Used when the left child's balance is 1 or 0.  With 0, the resulting
        root keeps a non-zero balance, meaning the subtree height did not
        shrink; ``_discard`` relies on that signal.

        Returns:
            The new root of the subtree.
        """
        u = node.left
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Rotate the right-heavy ``node`` (balance -2) to the left.

        Mirror image of ``_rotate_L``: used when the right child's balance
        is -1 or 0.

        Returns:
            The new root of the subtree.
        """
        u = node.right
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Recompute balances after a double rotation rooted at ``node``.

        ``node.balance`` still holds its pre-rotation value.  That value
        determines which of its two new children inherits the shorter subtree.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy.

        Returns:
            The new subtree root (the left child's right child).
        """
        B = node.left
        E = B.right
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy.

        Returns:
            The new subtree root (the right child's left child).
        """
        C = node.right
        D = C.left
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def items(self) -> Iterator[tuple[K, V]]:
        """Yield ``(key, value)`` pairs in ascending key order.

        The list of pairs is built when iteration starts.  Insertions and
        deletions made while iterating do not change what is yielded.
        """
        yield from self.tolist_items()

    def keys(self) -> Iterator[K]:
        """Yield keys in ascending order (built when iteration starts)."""
        for key, _ in self.tolist_items():
            yield key

    def values(self) -> Iterator[V]:
        """Yield values in ascending key order (built when iteration starts)."""
        for _, val in self.tolist_items():
            yield val

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys in ascending order, as ``dict`` does.

        Without this method, Python would fall back to calling
        ``self[0], self[1], ...``, which is wrong for a key-ordered mapping.
        """
        return self.keys()

    def _search_node(self, key: K) -> Optional[Node]:
        """Return the node whose key equals ``key``, or None if absent."""
        node = self.node
        while node is not None:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def _discard(self, key: K) -> bool:
        """Remove ``key`` from the tree and rebalance.

        Returns:
            True if a node was removed, False if ``key`` was not present.
        """
        # ``path`` holds the ancestors visited.  Bit i of ``di`` (LSB = most
        # recent step) records the direction taken from path[-1 - i]:
        # 1 = went left, 0 = went right.
        di = 0
        path = []
        node = self.node
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        else:
            return False
        if node.left is not None and node.right is not None:
            # Two children: move the in-order predecessor (max of the left
            # subtree) into this node, then unlink the predecessor's node.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node.val = lmax.val  # the value must travel with its key
            node = lmax
        # ``node`` now has at most one child; splice it out of the tree.
        cnode = node.right if node.left is None else node.left
        if not path:
            self.node = cnode
            return True
        if di & 1:
            path[-1].left = cnode
        else:
            path[-1].right = cnode
        # Walk back up while the subtree height keeps shrinking.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Balance went 0 -> +/-1: height unchanged, ancestors unaffected.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    # Rotation preserved the subtree height; stop here.
                    break
        return True

    def tolist_items(self) -> list[tuple[K, V]]:
        """Return all ``(key, value)`` pairs in ascending key order."""
        out = []
        stack = []
        node = self.node
        # Iterative in-order traversal: descend left, emit, then go right.
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            out.append((node.key, node.val))
            node = node.right
        return out

    def __setitem__(self, key: K, val: V):
        """Insert ``key`` with ``val``, or overwrite the value of an existing key.

        Returns:
            True when a new key was inserted, None when an existing key's
            value was overwritten (kept for backward compatibility).
        """
        if self.node is None:
            self.node = AVLTreeDict.Node(key, val)
            self._len += 1
            return True
        # Same ``path``/``di`` encoding as in ``_discard``.
        pnode = self.node
        path = []
        di = 0
        while pnode is not None:
            if key == pnode.key:
                pnode.val = val  # existing key: size unchanged
                return
            path.append(pnode)
            di <<= 1
            if key < pnode.key:
                di |= 1
                pnode = pnode.left
            else:
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeDict.Node(key, val)
        else:
            path[-1].right = AVLTreeDict.Node(key, val)
        # Count only genuine insertions, never overwrites.
        self._len += 1
        # Walk back up while the subtree height keeps growing; a single
        # (possibly double) rotation restores balance after an insertion.
        new_node = None
        while path:
            pnode = path.pop()
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            if path:
                gnode = path.pop()
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        return True

    def __delitem__(self, key: K):
        """Remove ``key``; raise KeyError if it is not present."""
        if self._discard(key):
            self._len -= 1
            return
        raise KeyError(key)

    def __getitem__(self, key: K):
        """Return the value for ``key``, delegating misses to ``__missing__``."""
        node = self._search_node(key)
        return self.__missing__(key) if node is None else node.val

    def __contains__(self, key: K):
        return self._search_node(key) is not None

    def __len__(self):
        return self._len

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(f"{k}: {v}" for k, v in self.items()) + "}"

    def __repr__(self):
        return "AVLTreeDict(" + str(self) + ")"

    def __missing__(self, e):
        """Handle a lookup of the absent key ``e``.

        Returns ``default()`` (without inserting it) when a default factory
        is configured; otherwise raises ``KeyError(e)``.
        """
        if self._default is None:
            raise KeyError(e)
        return self._default()