persistent_avl_tree_list

ソースコード

from titan_pylib.data_structures.avl_tree.persistent_avl_tree_list import PersistentAVLTreeList

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.avl_tree.persistent_avl_tree_list import PersistentAVLTreeList
  2from typing import Generic, Iterable, Optional, TypeVar
  3
  4T = TypeVar("T")
  5
  6
  7class PersistentAVLTreeList(Generic[T]):
  8    """挿入削除が対数時間で行える永続AVL木です。"""
  9
 10    class Node:
 11
 12        def __init__(self, key: T):
 13            self.key: T = key
 14            self.left: Optional[PersistentAVLTreeList.Node] = None
 15            self.right: Optional[PersistentAVLTreeList.Node] = None
 16            self.height: int = 1
 17            self.size: int = 1
 18
 19        def copy(self) -> "PersistentAVLTreeList.Node":
 20            node = PersistentAVLTreeList.Node(self.key)
 21            node.left = self.left
 22            node.right = self.right
 23            node.height = self.height
 24            node.size = self.size
 25            return node
 26
 27        def balance(self) -> int:
 28            return (
 29                (0 if self.right is None else -self.right.height)
 30                if self.left is None
 31                else (
 32                    self.left.height
 33                    if self.right is None
 34                    else self.left.height - self.right.height
 35                )
 36            )
 37
 38        def __str__(self):
 39            if self.left is None and self.right is None:
 40                return f"key={self.key}, height={self.height}, size={self.size}\n"
 41            return f"key={self.key}, height={self.height}, size={self.size},\n left:{self.left},\n right:{self.right}\n"
 42
 43        __repr__ = __str__
 44
 45    def __init__(self, a: Iterable[T] = [], _root: Optional[Node] = None) -> None:
 46        self.root: Optional[PersistentAVLTreeList.Node] = _root
 47        a = list(a)
 48        if a:
 49            self._build(list(a))
 50
 51    def _build(self, a: list[T]) -> None:
 52        Node = PersistentAVLTreeList.Node
 53
 54        def build(l: int, r: int) -> PersistentAVLTreeList.Node:
 55            mid = (l + r) >> 1
 56            node = Node(a[mid])
 57            if l != mid:
 58                node.left = build(l, mid)
 59            if mid + 1 != r:
 60                node.right = build(mid + 1, r)
 61            self._update(node)
 62            return node
 63
 64        self.root = build(0, len(a))
 65
 66    def _update(self, node: Node) -> None:
 67        if node.left:
 68            if node.right:
 69                node.size = 1 + node.left.size + node.right.size
 70                node.height = (
 71                    node.left.height + 1
 72                    if node.left.height > node.right.height
 73                    else node.right.height + 1
 74                )
 75            else:
 76                node.size = 1 + node.left.size
 77                node.height = node.left.height + 1
 78        else:
 79            if node.right:
 80                node.size = 1 + node.right.size
 81                node.height = node.right.height + 1
 82            else:
 83                node.size = 1
 84                node.height = 1
 85
 86    def _rotate_right(self, node: Node) -> Node:
 87        assert node.left
 88        u = node.left.copy()
 89        node.left = u.right
 90        u.right = node
 91        self._update(node)
 92        self._update(u)
 93        return u
 94
 95    def _rotate_left(self, node: Node) -> Node:
 96        assert node.right
 97        u = node.right.copy()
 98        node.right = u.left
 99        u.left = node
100        self._update(node)
101        self._update(u)
102        return u
103
104    def _balance_left(self, node: Node) -> Node:
105        assert node.right
106        node.right = node.right.copy()
107        u = node.right
108        if u.balance() == 1:
109            assert u.left
110            node.right = self._rotate_right(u)
111        u = self._rotate_left(node)
112        return u
113
114    def _balance_right(self, node: Node) -> Node:
115        assert node.left
116        node.left = node.left.copy()
117        u = node.left
118        if u.balance() == -1:
119            assert u.right
120            node.left = self._rotate_left(u)
121        u = self._rotate_right(node)
122        return u
123
124    def _merge_with_root(
125        self, l: Optional[Node], root: Node, r: Optional[Node]
126    ) -> Node:
127        diff = 0
128        if l:
129            diff += l.height
130        if r:
131            diff -= r.height
132        if diff > 1:
133            assert l
134            l = l.copy()
135            l.right = self._merge_with_root(l.right, root, r)
136            self._update(l)
137            if l.balance() == -2:
138                return self._balance_left(l)
139            return l
140        if diff < -1:
141            assert r
142            r = r.copy()
143            r.left = self._merge_with_root(l, root, r.left)
144            self._update(r)
145            if r.balance() == 2:
146                return self._balance_right(r)
147            return r
148        root = root.copy()
149        root.left = l
150        root.right = r
151        self._update(root)
152        return root
153
154    def _merge_node(self, l: Optional[Node], r: Optional[Node]) -> Optional[Node]:
155        if l is None and r is None:
156            return None
157        if l is None:
158            assert r
159            return r.copy()
160        if r is None:
161            return l.copy()
162        l = l.copy()
163        r = r.copy()
164        l, root = self._pop_right(l)
165        return self._merge_with_root(l, root, r)
166
167    def merge(self, other: "PersistentAVLTreeList") -> "PersistentAVLTreeList":
168        root = self._merge_node(self.root, other.root)
169        return self._new(root)
170
171    def _pop_right(self, node: Node) -> tuple[Node, Node]:
172        path = []
173        node = node.copy()
174        mx = node
175        while node.right is not None:
176            path.append(node)
177            node = node.right.copy()
178            mx = node
179        path.append(node.left.copy() if node.left else None)
180        for _ in range(len(path) - 1):
181            node = path.pop()
182            if node is None:
183                path[-1].right = None
184                self._update(path[-1])
185                continue
186            b = node.balance()
187            if b == 2:
188                path[-1].right = self._balance_right(node)
189            elif b == -2:
190                path[-1].right = self._balance_left(node)
191            else:
192                path[-1].right = node
193            self._update(path[-1])
194        if path[0] is not None:
195            b = path[0].balance()
196            if b == 2:
197                path[0] = self._balance_right(path[0])
198            elif b == -2:
199                path[0] = self._balance_left(path[0])
200        mx.left = None
201        self._update(mx)
202        return path[0], mx
203
204    def _split_node(
205        self, node: Optional[Node], k: int
206    ) -> tuple[Optional[Node], Optional[Node]]:
207        if node is None:
208            return None, None
209        tmp = k if node.left is None else k - node.left.size
210        l, r = None, None
211        if tmp == 0:
212            return node.left, self._merge_with_root(None, node, node.right)
213        elif tmp < 0:
214            l, r = self._split_node(node.left, k)
215            return l, self._merge_with_root(r, node, node.right)
216        else:
217            l, r = self._split_node(node.right, tmp - 1)
218            return self._merge_with_root(node.left, node, l), r
219
220    def split(self, k: int) -> tuple["PersistentAVLTreeList", "PersistentAVLTreeList"]:
221        l, r = self._split_node(self.root, k)
222        return self._new(l), self._new(r)
223
224    def _new(
225        self, root: Optional["PersistentAVLTreeList.Node"]
226    ) -> "PersistentAVLTreeList":
227        return PersistentAVLTreeList([], root)
228
229    def insert(self, k: int, key: T) -> "PersistentAVLTreeList":
230        s, t = self._split_node(self.root, k)
231        root = self._merge_with_root(s, PersistentAVLTreeList.Node(key), t)
232        return self._new(root)
233
234    def pop(self, k: int) -> tuple["PersistentAVLTreeList", T]:
235        s, t = self._split_node(self.root, k + 1)
236        assert s
237        s, tmp = self._pop_right(s)
238        root = self._merge_node(s, t)
239        return self._new(root), tmp.key
240
241    def tolist(self) -> list[T]:
242        node = self.root
243        stack = []
244        a = []
245        while stack or node:
246            if node:
247                stack.append(node)
248                node = node.left
249            else:
250                node = stack.pop()
251                a.append(node.key)
252                node = node.right
253        return a
254
255    def __getitem__(self, k: int) -> T:
256        if k < 0:
257            k += len(self)
258        node = self.root
259        while True:
260            assert node
261            t = 0 if node.left is None else node.left.size
262            if t == k:
263                return node.key
264            elif t < k:
265                k -= t + 1
266                node = node.right
267            else:
268                node = node.left
269
270    def __len__(self):
271        return 0 if self.root is None else self.root.size
272
273    def __str__(self):
274        return "[" + ", ".join(map(str, self.tolist())) + "]"
275
276    def __repr__(self):
277        return f"PersistentAVLTreeList({self})"

仕様

class PersistentAVLTreeList(a: Iterable[T] = [], _root: Node | None = None)[source]

Bases: Generic[T]

挿入・削除が対数時間で行える永続AVL木です。

class Node(key: T)[source]

Bases: object

balance() → int[source]
copy() → Node[source]
insert(k: int, key: T) → PersistentAVLTreeList[source]
merge(other: PersistentAVLTreeList) → PersistentAVLTreeList[source]
pop(k: int) → tuple[PersistentAVLTreeList, T][source]
split(k: int) → tuple[PersistentAVLTreeList, PersistentAVLTreeList][source]
tolist() → list[T][source]