splay_tree_list

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_list import SplayTreeList

View on GitHub

展開済みコード

  # from titan_pylib.data_structures.splay_tree.splay_tree_list import SplayTreeList
from typing import Generic, TypeVar, Iterable, Iterator, Optional

T = TypeVar("T")
F = TypeVar("F")  # not used by the class body; kept so SplayTreeList[T, F] stays valid


class SplayTreeList(Generic[T, F]):
    """A list-like sequence backed by an implicit-key splay tree.

    Element positions are given by subtree sizes, not stored keys. Index
    access, insertion/removal anywhere, split, merge and range reversal are
    all done by splaying (amortized O(log n) per operation, the standard
    splay-tree bound).

    Even read-only access such as ``self[k]`` restructures the tree, because
    the accessed node is splayed to the root.
    """

    class Node:
        """One element of the sequence together with its tree links."""

        __slots__ = ("key", "left", "right", "par", "size", "rev")

        def __init__(self, key: T):
            self.key: T = key
            self.left: Optional["SplayTreeList.Node"] = None
            self.right: Optional["SplayTreeList.Node"] = None
            self.par: Optional["SplayTreeList.Node"] = None
            # Number of nodes in the subtree rooted at this node.
            self.size: int = 1
            # Lazy flag: 1 while this subtree's order still has to be reversed.
            self.rev: int = 0

    def __init__(self, a: Iterable[T] = (), _root: Optional[Node] = None) -> None:
        """Create a list holding the elements of ``a`` in order.

        ``_root`` is internal (used by ``split``): when given, the new list
        adopts that detached subtree and ``a`` is ignored.
        """
        self.root = _root
        if _root is not None:
            return
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a balanced tree holding ``a`` in order."""
        Node = self.Node

        def build(l: int, r: int) -> "SplayTreeList.Node":
            # a[mid] becomes the subtree root; a[l:mid] / a[mid+1:r] its children.
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = build(l, mid)
                node.left.par = node
            if mid + 1 != r:
                node.right = build(mid + 1, r)
                node.right.par = node
            self._update(node)
            return node

        self.root = build(0, len(a))

    def _rotate(self, node: Node) -> None:
        """Rotate ``node`` above its parent, preserving in-order sequence.

        Does not push lazy flags; the caller must already have propagated
        ``node`` and its parent.
        """
        pnode = node.par
        gnode = pnode.par
        if gnode:
            if gnode.left is pnode:
                gnode.left = node
            else:
                gnode.right = node
        node.par = gnode
        if pnode.left is node:
            pnode.left = node.right
            if node.right:
                node.right.par = pnode
            node.right = pnode
        else:
            pnode.right = node.left
            if node.left:
                node.left.par = pnode
            node.left = pnode
        pnode.par = node
        # pnode is now the child, so its size must be recomputed first.
        self._update(pnode)
        self._update(node)

    def _propagate_rev(self, node: Optional[Node]) -> None:
        """Mark ``node``'s subtree as pending reversal (toggle its flag)."""
        if not node:
            return
        node.rev ^= 1

    def _propagate(self, node: Optional[Node]) -> None:
        """Apply a pending reversal at ``node`` and push the flag to its children."""
        if not node or not node.rev:
            return
        node.left, node.right = node.right, node.left
        self._propagate_rev(node.left)
        self._propagate_rev(node.right)
        node.rev = 0

    def _update(self, node: Node) -> None:
        """Recompute ``node.size`` from its children."""
        node.size = 1
        if node.left:
            node.size += node.left.size
        if node.right:
            node.size += node.right.size

    def _splay(self, node: Node) -> None:
        """Rotate ``node`` up until it has no parent.

        Every node on the path from the current root to ``node`` must already
        have been propagated.
        """
        while node.par and node.par.par:
            pnode = node.par
            # Same direction twice (zig-zig): rotate the parent first.
            # Otherwise (zig-zag): rotate node twice.
            self._rotate(
                pnode if (pnode.par.left is pnode) == (pnode.left is node) else node
            )
            self._rotate(node)
        if node.par:
            self._rotate(node)

    def _get_kth_elm_splay(self, node: Optional[Node], k: int) -> Node:
        """Splay the ``k``-th node (0-based, negative from the end) of ``node``'s subtree.

        Returns the splayed node, which is now the root of that subtree.

        Raises:
            IndexError: if ``k`` is out of range for the subtree.
        """
        size = node.size if node else 0
        if k < 0:
            k += size
        if not 0 <= k < size:
            raise IndexError(f"{self.__class__.__name__} index out of range")
        while True:
            self._propagate(node)
            t = node.left.size if node.left else 0
            if t == k:
                break
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
        self._splay(node)
        return node

    def _get_left_splay(self, node: Optional[Node]) -> Optional[Node]:
        """Splay the leftmost node of ``node``'s subtree and return it (None if empty)."""
        self._propagate(node)
        if not node or not node.left:
            return node
        while node.left:
            node = node.left
            self._propagate(node)
        self._splay(node)
        return node

    def _get_right_splay(self, node: Optional[Node]) -> Optional[Node]:
        """Splay the rightmost node of ``node``'s subtree and return it (None if empty)."""
        self._propagate(node)
        if not node or not node.right:
            return node
        while node.right:
            node = node.right
            self._propagate(node)
        self._splay(node)
        return node

    def merge(self, other: "SplayTreeList") -> None:
        """Append every element of ``other`` to the end of this list.

        The nodes of ``other`` are moved, not copied, so ``other`` is left
        empty.

        Raises:
            ValueError: if ``other`` is this same list.
        """
        if other is self:
            raise ValueError("cannot merge a SplayTreeList into itself")
        if not other.root:
            return
        if not self.root:
            self.root = other.root
        else:
            # After splaying the last element to the root it has no right child.
            self.root = self._get_right_splay(self.root)
            self.root.right = other.root
            other.root.par = self.root
            self._update(self.root)
        # Prevent other from aliasing nodes that now belong to self.
        other.root = None

    def split(self, k: int) -> tuple["SplayTreeList", "SplayTreeList"]:
        """Split into (first ``k`` elements, remaining elements).

        ``k`` follows slice semantics: negative values count from the end and
        out-of-range values are clamped. The nodes move into the two returned
        lists, so this list is left empty.
        """
        left, right = self._internal_split(k)
        self.root = None
        return SplayTreeList([], left), SplayTreeList([], right)

    def _internal_split(self, k: int) -> tuple[Optional[Node], Optional[Node]]:
        """Detach the tree into roots holding the first ``k`` elements and the rest.

        ``self.root`` is left stale; the caller must reassign it.
        """
        n = len(self)
        if k < 0:
            k = max(0, k + n)
        if k >= n:
            return self.root, None
        right = self._get_kth_elm_splay(self.root, k)
        left = right.left
        if left:
            left.par = None
        right.left = None
        self._update(right)
        return left, right

    def _internal_merge(
        self, left: Optional[Node], right: Optional[Node]
    ) -> Optional[Node]:
        """Concatenate ``left`` before ``right`` and return the new root.

        Requires ``right`` to be None or to have no left child (e.g. the right
        part returned by ``_internal_split``).
        """
        if not right:
            return left
        assert right.left is None
        right.left = left
        if left:
            left.par = right
        self._update(right)
        return right

    def reverse(self, l: int, r: int) -> None:
        """Reverse the elements in positions ``[l, r)`` in place."""
        assert (
            0 <= l <= r <= len(self)
        ), f"IndexError: {self.__class__.__name__}.reverse({l}, {r}), len={len(self)}"
        left, right = self._internal_split(r)
        if l == 0:
            self._propagate_rev(left)
        else:
            # With element l-1 at the root of `left`, its right subtree is [l, r).
            left = self._get_kth_elm_splay(left, l - 1)
            self._propagate_rev(left.right)
        self.root = self._internal_merge(left, right)

    def all_reverse(self) -> None:
        """Reverse the whole list in place (lazily, O(1))."""
        self._propagate_rev(self.root)

    def insert(self, k: int, key: T) -> None:
        """Insert ``key`` before position ``k``, with ``list.insert`` index semantics."""
        node = self.Node(key)
        n = len(self)
        if n == 0:
            self.root = node
            return
        if k < 0:
            k = max(0, k + n)
        if k >= n:
            # The splayed last element becomes the left child of the new node.
            root = self._get_kth_elm_splay(self.root, n - 1)
            node.left = root
        else:
            # The new node takes root.left as its left subtree and root as its right.
            root = self._get_kth_elm_splay(self.root, k)
            if root.left:
                node.left = root.left
                root.left.par = node
                root.left = None
                self._update(root)
            node.right = root
        root.par = node
        self.root = node
        self._update(self.root)

    def append(self, key: T) -> None:
        """Add ``key`` at the end."""
        node = self._get_right_splay(self.root)
        self.root = self.Node(key)
        self.root.left = node
        if node:
            node.par = self.root
        self._update(self.root)

    def appendleft(self, key: T) -> None:
        """Add ``key`` at the front."""
        node = self._get_left_splay(self.root)
        self.root = self.Node(key)
        self.root.right = node
        if node:
            node.par = self.root
        self._update(self.root)

    def pop(self, k: int = -1) -> T:
        """Remove and return the element at position ``k`` (default: last).

        Raises:
            IndexError: if the list is empty or ``k`` is out of range.
        """
        if not self.root:
            raise IndexError(f"pop from empty {self.__class__.__name__}")
        if k == -1:
            node = self._get_right_splay(self.root)
            self.root = node.left
            if self.root:
                self.root.par = None
            return node.key
        node = self._get_kth_elm_splay(self.root, k)
        left, right = node.left, node.right
        # Detach both subtrees first so later splays stay inside them.
        if left:
            left.par = None
        if right:
            right.par = None
        if left and right:
            # The maximum of `left` has no right child once splayed to its root.
            left = self._get_right_splay(left)
            left.right = right
            right.par = left
            self._update(left)
        self.root = left if left else right
        return node.key

    def popleft(self) -> T:
        """Remove and return the first element.

        Raises:
            IndexError: if the list is empty.
        """
        if not self.root:
            raise IndexError(f"pop from empty {self.__class__.__name__}")
        node = self._get_left_splay(self.root)
        self.root = node.right
        if self.root:
            self.root.par = None
        return node.key

    def copy(self) -> "SplayTreeList":
        """Return a new, independent list with the same elements."""
        return SplayTreeList(self.tolist())

    def clear(self) -> None:
        """Remove all elements."""
        self.root = None

    def tolist(self) -> list[T]:
        """Return the elements in order as a Python list (iterative in-order walk)."""
        node = self.root
        stack = []
        a = []
        while stack or node:
            if node:
                self._propagate(node)
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node.key)
                node = node.right
        return a

    def __setitem__(self, k: int, key: T):
        self.root = self._get_kth_elm_splay(self.root, k)
        self.root.key = key

    def __getitem__(self, k: int) -> T:
        self.root = self._get_kth_elm_splay(self.root, k)
        return self.root.key

    def __iter__(self) -> Iterator[T]:
        # A fresh generator per call, so nested/parallel iteration is independent.
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __reversed__(self):
        for i in range(len(self)):
            yield self[-i - 1]

    def __len__(self):
        return self.root.size if self.root else 0

    def __str__(self):
        return str(self.tolist())

    def __bool__(self):
        return self.root is not None

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"

仕様

class SplayTreeList(a: Iterable[T], _root: Node = None)[source]

Bases: Generic[T, F]

class Node(key: T)[source]

Bases: object

all_reverse() → None [source]
append(key: T) → None [source]
appendleft(key: T) → None [source]
clear() → None [source]
copy() → SplayTreeList [source]
insert(k: int, key: T) → None [source]
merge(other: SplayTreeList) → None [source]
pop(k: int = -1) → T [source]
popleft() → T [source]
reverse(l: int, r: int) → None [source]
split(k: int) → tuple[SplayTreeList, SplayTreeList] [source]
tolist() → list[T] [source]