persistent_avl_tree_list¶
ソースコード¶
from titan_pylib.data_structures.avl_tree.persistent_avl_tree_list import PersistentAVLTreeList
展開済みコード¶
1# from titan_pylib.data_structures.avl_tree.persistent_avl_tree_list import PersistentAVLTreeList
2from typing import Generic, Iterable, Optional, TypeVar
3
T = TypeVar("T")


class PersistentAVLTreeList(Generic[T]):
    """Persistent list backed by an AVL tree.

    Insertion and deletion run in logarithmic time.  The modifying
    operations (``insert``, ``pop``, ``split``, ``merge``) never change the
    receiver: they return new list object(s).  Nodes that have to change are
    copied first; untouched subtrees are shared between versions.
    """

    class Node:
        """Tree node holding one element plus cached subtree metadata."""

        def __init__(self, key: T):
            self.key: T = key
            self.left: Optional[PersistentAVLTreeList.Node] = None
            self.right: Optional[PersistentAVLTreeList.Node] = None
            # Height and element count of the subtree rooted at this node.
            self.height: int = 1
            self.size: int = 1

        def copy(self) -> "PersistentAVLTreeList.Node":
            """Return a shallow copy; the children are shared, not cloned."""
            node = PersistentAVLTreeList.Node(self.key)
            node.left = self.left
            node.right = self.right
            node.height = self.height
            node.size = self.size
            return node

        def balance(self) -> int:
            """Return left height minus right height (a missing child is 0)."""
            left_height = 0 if self.left is None else self.left.height
            right_height = 0 if self.right is None else self.right.height
            return left_height - right_height

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key={self.key}, height={self.height}, size={self.size}\n"
            return f"key={self.key}, height={self.height}, size={self.size},\n left:{self.left},\n right:{self.right}\n"

        __repr__ = __str__

    def __init__(self, a: Iterable[T] = (), _root: Optional[Node] = None) -> None:
        """Create a list from the elements of ``a``.

        Args:
            a: Initial elements, in order.
            _root: Existing root node to wrap (internal use).  It is replaced
                by a freshly built tree when ``a`` is non-empty.
        """
        self.root: Optional[PersistentAVLTreeList.Node] = _root
        items = list(a)
        if items:
            self._build(items)

    def _build(self, a: list[T]) -> None:
        """Build a tree from the non-empty list ``a`` and store it in ``self.root``."""
        Node = PersistentAVLTreeList.Node

        def build(l: int, r: int) -> PersistentAVLTreeList.Node:
            # Builds the half-open range a[l:r]; the middle element becomes
            # the subtree root.
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.root = build(0, len(a))

    def _update(self, node: Node) -> None:
        """Recompute ``node.size`` and ``node.height`` from its children."""
        left, right = node.left, node.right
        left_size = left.size if left else 0
        right_size = right.size if right else 0
        left_height = left.height if left else 0
        right_height = right.height if right else 0
        node.size = 1 + left_size + right_size
        node.height = 1 + (left_height if left_height > right_height else right_height)

    def _rotate_right(self, node: Node) -> Node:
        """Rotate right around ``node`` and return the new subtree root.

        ``node`` itself is modified in place, so callers pass a node they
        own; the left child is copied before it is changed.
        """
        assert node.left
        u = node.left.copy()
        node.left = u.right
        u.right = node
        self._update(node)
        self._update(u)
        return u

    def _rotate_left(self, node: Node) -> Node:
        """Rotate left around ``node`` and return the new subtree root.

        ``node`` itself is modified in place, so callers pass a node they
        own; the right child is copied before it is changed.
        """
        assert node.right
        u = node.right.copy()
        node.right = u.left
        u.left = node
        self._update(node)
        self._update(u)
        return u

    def _balance_left(self, node: Node) -> Node:
        """Rebalance a right-heavy ``node`` and return the new subtree root."""
        assert node.right
        node.right = node.right.copy()
        u = node.right
        if u.balance() == 1:
            # Right child leans left: rotate it right first (double rotation).
            assert u.left
            node.right = self._rotate_right(u)
        u = self._rotate_left(node)
        return u

    def _balance_right(self, node: Node) -> Node:
        """Rebalance a left-heavy ``node`` and return the new subtree root."""
        assert node.left
        node.left = node.left.copy()
        u = node.left
        if u.balance() == -1:
            # Left child leans right: rotate it left first (double rotation).
            assert u.right
            node.left = self._rotate_left(u)
        u = self._rotate_right(node)
        return u

    def _merge_with_root(
        self, l: Optional[Node], root: Node, r: Optional[Node]
    ) -> Node:
        """Join ``l``, the single element ``root`` and ``r`` into one tree.

        The result lists the elements of ``l``, then ``root.key``, then the
        elements of ``r``.  Only copies are modified.
        """
        diff = 0
        if l:
            diff += l.height
        if r:
            diff -= r.height
        if diff > 1:
            # Left tree is taller: attach the join result as its right child.
            assert l
            l = l.copy()
            l.right = self._merge_with_root(l.right, root, r)
            self._update(l)
            if l.balance() == -2:
                return self._balance_left(l)
            return l
        if diff < -1:
            # Right tree is taller: attach the join result as its left child.
            assert r
            r = r.copy()
            r.left = self._merge_with_root(l, root, r.left)
            self._update(r)
            if r.balance() == 2:
                return self._balance_right(r)
            return r
        # Heights differ by at most one: root can sit directly on top.
        root = root.copy()
        root.left = l
        root.right = r
        self._update(root)
        return root

    def _merge_node(self, l: Optional[Node], r: Optional[Node]) -> Optional[Node]:
        """Concatenate the trees ``l`` and ``r`` and return the new root."""
        if l is None and r is None:
            return None
        if l is None:
            assert r
            return r.copy()
        if r is None:
            return l.copy()
        l = l.copy()
        r = r.copy()
        # The last element of l becomes the pivot joining the two trees.
        l, root = self._pop_right(l)
        return self._merge_with_root(l, root, r)

    def merge(self, other: "PersistentAVLTreeList") -> "PersistentAVLTreeList":
        """Return a new list: the elements of ``self`` followed by ``other``."""
        root = self._merge_node(self.root, other.root)
        return self._new(root)

    def _pop_right(self, node: Node) -> tuple[Optional[Node], Node]:
        """Detach the rightmost node of the tree rooted at ``node``.

        Returns:
            ``(rest, mx)`` where ``rest`` is the root of the remaining tree
            (``None`` if nothing remains) and ``mx`` is the detached node
            with its children cleared.  Nodes on the right spine are copied.
        """
        path = []
        node = node.copy()
        mx = node
        while node.right is not None:
            path.append(node)
            node = node.right.copy()
            mx = node
        # The rightmost node's left subtree (possibly None) takes its place.
        path.append(node.left.copy() if node.left else None)
        # Walk back up the spine, re-attaching each subtree and rebalancing.
        for _ in range(len(path) - 1):
            node = path.pop()
            if node is None:
                path[-1].right = None
                self._update(path[-1])
                continue
            b = node.balance()
            if b == 2:
                path[-1].right = self._balance_right(node)
            elif b == -2:
                path[-1].right = self._balance_left(node)
            else:
                path[-1].right = node
            self._update(path[-1])
        if path[0] is not None:
            b = path[0].balance()
            if b == 2:
                path[0] = self._balance_right(path[0])
            elif b == -2:
                path[0] = self._balance_left(path[0])
        mx.left = None
        self._update(mx)
        return path[0], mx

    def _split_node(
        self, node: Optional[Node], k: int
    ) -> tuple[Optional[Node], Optional[Node]]:
        """Split the tree at ``node`` into its first ``k`` elements and the rest."""
        if node is None:
            return None, None
        # Position of the cut relative to this node's own element.
        tmp = k if node.left is None else k - node.left.size
        if tmp == 0:
            return node.left, self._merge_with_root(None, node, node.right)
        elif tmp < 0:
            l, r = self._split_node(node.left, k)
            return l, self._merge_with_root(r, node, node.right)
        else:
            l, r = self._split_node(node.right, tmp - 1)
            return self._merge_with_root(node.left, node, l), r

    def split(self, k: int) -> tuple["PersistentAVLTreeList", "PersistentAVLTreeList"]:
        """Return ``(first k elements, remaining elements)`` as two new lists.

        ``k`` is not validated: ``k <= 0`` yields an empty left part and
        ``k >= len(self)`` yields an empty right part.
        """
        l, r = self._split_node(self.root, k)
        return self._new(l), self._new(r)

    def _new(
        self, root: Optional["PersistentAVLTreeList.Node"]
    ) -> "PersistentAVLTreeList":
        """Wrap ``root`` in a new list object."""
        return PersistentAVLTreeList([], root)

    def insert(self, k: int, key: T) -> "PersistentAVLTreeList":
        """Return a new list with ``key`` inserted before position ``k``.

        ``k`` is not validated: ``k <= 0`` inserts at the front and
        ``k >= len(self)`` appends.
        """
        s, t = self._split_node(self.root, k)
        root = self._merge_with_root(s, PersistentAVLTreeList.Node(key), t)
        return self._new(root)

    def pop(self, k: int = -1) -> tuple["PersistentAVLTreeList", T]:
        """Return ``(new list without element k, removed element)``.

        Args:
            k: Index of the element to remove; negative values count from
                the end.  Defaults to the last element.

        Raises:
            IndexError: If the list is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        # After the split, the element to remove is the last one of s.
        s, t = self._split_node(self.root, k + 1)
        assert s
        s, tmp = self._pop_right(s)
        root = self._merge_node(s, t)
        return self._new(root), tmp.key

    def tolist(self) -> list[T]:
        """Return the elements in order as a plain list."""
        node = self.root
        stack = []
        a = []
        # Iterative in-order traversal.
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node.key)
                node = node.right
        return a

    def __getitem__(self, k: int) -> T:
        """Return the element at index ``k`` (negative values count from the end).

        Raises:
            IndexError: If ``k`` is out of range.  Raising ``IndexError`` is
                also what lets ``for x in lst`` / ``list(lst)`` terminate,
                since the class defines no ``__iter__``.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("index out of range")
        node = self.root
        while True:
            assert node
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                # Skip the left subtree and this node.
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def __len__(self):
        return 0 if self.root is None else self.root.size

    def __str__(self):
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self):
        return f"PersistentAVLTreeList({self})"
仕様¶
- class PersistentAVLTreeList(a: Iterable[T] = [], _root: Node | None = None)[source]¶
Bases: Generic[T]
挿入削除が対数時間で行える永続AVL木です。
- insert(k: int, key: T) -> PersistentAVLTreeList [source]¶
- merge(other: PersistentAVLTreeList) -> PersistentAVLTreeList [source]¶
- pop(k: int) -> tuple[PersistentAVLTreeList, T] [source]¶
- split(k: int) -> tuple[PersistentAVLTreeList, PersistentAVLTreeList] [source]¶