splay_tree_list_array

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_list_array import SplayTreeListArrayData
from titan_pylib.data_structures.splay_tree.splay_tree_list_array import SplayTreeListArray

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_list_array import SplayTreeListArray
  2from array import array
  3from typing import Generic, TypeVar, Iterable, Union
  4
  5T = TypeVar("T")
  6
  7
  8class SplayTreeListArrayData(Generic[T]):
  9
 10    def __init__(self, e: T = 0):
 11        self.keys: list[T] = [e]
 12        self.e: T = e
 13        self.arr: array[int] = array("I", bytes(16))
 14        # left:  arr[node<<2]
 15        # right: arr[node<<2|1]
 16        # size:  arr[node<<2|2]
 17        # rev:   arr[node<<2|3]
 18        self.end: int = 1
 19
 20    def reserve(self, n: int) -> None:
 21        if n <= 0:
 22            return
 23        self.keys += [self.e] * (2 * n)
 24        self.arr += array("I", bytes(16 * n))
 25
 26    def _update_triple(self, x: int, y: int, z: int) -> None:
 27        arr = self.arr
 28        arr[z << 2 | 2] = arr[x << 2 | 2]
 29        arr[x << 2 | 2] = 1 + arr[arr[x << 2] << 2 | 2] + arr[arr[x << 2 | 1] << 2 | 2]
 30        arr[y << 2 | 2] = 1 + arr[arr[y << 2] << 2 | 2] + arr[arr[y << 2 | 1] << 2 | 2]
 31
 32    def _update_double(self, x: int, y: int) -> None:
 33        arr = self.arr
 34        arr[y << 2 | 2] = arr[x << 2 | 2]
 35        arr[x << 2 | 2] = 1 + arr[arr[x << 2] << 2 | 2] + arr[arr[x << 2 | 1] << 2 | 2]
 36
 37    def _update(self, node: int) -> None:
 38        arr = self.arr
 39        arr[node << 2 | 2] = (
 40            1 + arr[arr[node << 2] << 2 | 2] + arr[arr[node << 2 | 1] << 2 | 2]
 41        )
 42
 43    def _make_node(self, key: T) -> int:
 44        if self.end >= len(self.arr) // 4:
 45            self.keys.append(key)
 46            self.arr.append(0)
 47            self.arr.append(0)
 48            self.arr.append(1)
 49            self.arr.append(0)
 50        else:
 51            self.keys[self.end] = key
 52        self.end += 1
 53        return self.end - 1
 54
 55    def _splay(self, path: list[int], d: int) -> None:
 56        arr = self.arr
 57        g = d & 1
 58        while len(path) > 1:
 59            pnode = path.pop()
 60            gnode = path.pop()
 61            f = d >> 1 & 1
 62            node = arr[pnode << 2 | g ^ 1]
 63            nnode = (pnode if g == f else node) << 2 | f
 64            arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
 65            arr[node << 2 | g] = pnode
 66            arr[gnode << 2 | f ^ 1] = arr[nnode]
 67            arr[nnode] = gnode
 68            self._update_triple(gnode, pnode, node)
 69            if not path:
 70                return
 71            d >>= 2
 72            g = d & 1
 73            arr[path[-1] << 2 | g ^ 1] = node
 74        pnode = path.pop()
 75        node = arr[pnode << 2 | g ^ 1]
 76        arr[pnode << 2 | g ^ 1] = arr[node << 2 | g]
 77        arr[node << 2 | g] = pnode
 78        self._update_double(pnode, node)
 79
 80
 81class SplayTreeListArray(Generic[T]):
 82
 83    def __init__(
 84        self,
 85        data: SplayTreeListArrayData,
 86        n_or_a: Union[int, Iterable[T]] = 0,
 87        _root: int = 0,
 88    ) -> None:
 89        self.data: SplayTreeListArrayData = data
 90        self.root: int = _root
 91        if not n_or_a:
 92            return
 93        if isinstance(n_or_a, int):
 94            a = [data.e for _ in range(n_or_a)]
 95        else:
 96            a = list(n_or_a)
 97        if a:
 98            self._build(a)
 99
100    def _build(self, a: list[T]) -> None:
101        def rec(l: int, r: int) -> int:
102            mid = (l + r) >> 1
103            if l != mid:
104                arr[mid << 2] = rec(l, mid)
105            if mid + 1 != r:
106                arr[mid << 2 | 1] = rec(mid + 1, r)
107            self.data._update(mid)
108            return mid
109
110        n = len(a)
111        keys, arr = self.data.keys, self.data.arr
112        end = self.data.end
113        self.data.reserve(n + end - len(keys) // 2 + 1)
114        self.data.end += n
115        keys[end : end + n] = a
116        self.root = rec(end, n + end)
117
118    def _kth_elm_splay(self, node: int, k: int) -> int:
119        arr = self.data.arr
120        if k < 0:
121            k += arr[node << 2 | 2]
122        d = 0
123        path = []
124        while True:
125            t = arr[arr[node << 2] << 2 | 2]
126            if t == k:
127                if path:
128                    self.data._splay(path, d)
129                return node
130            d = d << 1 | (t > k)
131            path.append(node)
132            node = arr[node << 2 | (t < k)]
133            if t < k:
134                k -= t + 1
135
136    def _left_splay(self, node: int) -> int:
137        if not node:
138            return 0
139        arr = self.data.arr
140        if not arr[node << 2]:
141            return node
142        path = []
143        while arr[node << 2]:
144            path.append(node)
145            node = arr[node << 2]
146        self.data._splay(path, (1 << len(path)) - 1)
147        return node
148
149    def _right_splay(self, node: int) -> int:
150        if not node:
151            return 0
152        arr = self.data.arr
153        if not arr[node << 2 | 1]:
154            return node
155        path = []
156        while arr[node << 2 | 1]:
157            path.append(node)
158            node = arr[node << 2 | 1]
159        self.data._splay(path, 0)
160        return node
161
162    def reserve(self, n: int) -> None:
163        self.data.reserve(n)
164
165    def merge(self, other: "SplayTreeListArray") -> None:
166        assert self.data is other.data
167        if not other.root:
168            return
169        if not self.root:
170            self.root = other.root
171            return
172        self.root = self._right_splay(self.root)
173        self.data.arr[self.root << 2 | 1] = other.root
174        self.data._update(self.root)
175
176    def split(self, k: int) -> tuple["SplayTreeListArray", "SplayTreeListArray"]:
177        assert (
178            -len(self) < k <= len(self)
179        ), f"IndexError: SplayTreeListArray.split({k}), len={len(self)}"
180        if k < 0:
181            k += len(self)
182        if k >= self.data.arr[self.root << 2 | 2]:
183            return self, SplayTreeListArray(self.data, _root=0)
184        self.root = self._kth_elm_splay(self.root, k)
185        left = SplayTreeListArray(self.data, _root=self.data.arr[self.root << 2])
186        self.data.arr[self.root << 2] = 0
187        self.data._update(self.root)
188        return left, self
189
190    def _internal_split(self, k: int) -> tuple[int, int]:
191        if k >= self.data.arr[self.root << 2 | 2]:
192            return self.root, 0
193        self.root = self._kth_elm_splay(self.root, k)
194        left = self.data.arr[self.root << 2]
195        self.data.arr[self.root << 2] = 0
196        self.data._update(self.root)
197        return left, self.root
198
199    def insert(self, k: int, key: T) -> None:
200        assert (
201            -len(self) <= k <= len(self)
202        ), f"IndexError: SplayTreeListArray.insert({k}, {key}), len={len(self)}"
203        if k < 0:
204            k += len(self)
205        data = self.data
206        node = self.data._make_node(key)
207        if not self.root:
208            self.data._update(node)
209            self.root = node
210            return
211        arr = data.arr
212        if k == data.arr[self.root << 2 | 2]:
213            arr[node << 2] = self._right_splay(self.root)
214        else:
215            node_ = self._kth_elm_splay(self.root, k)
216            if arr[node_ << 2]:
217                arr[node << 2] = arr[node_ << 2]
218                arr[node_ << 2] = 0
219                self.data._update(node_)
220            arr[node << 2 | 1] = node_
221        self.data._update(node)
222        self.root = node
223
224    def append(self, key: T) -> None:
225        data = self.data
226        node = self._right_splay(self.root)
227        self.root = self.data._make_node(key)
228        data.arr[self.root << 2] = node
229        self.data._update(self.root)
230
231    def appendleft(self, key: T) -> None:
232        node = self._left_splay(self.root)
233        self.root = self.data._make_node(key)
234        self.data.arr[self.root << 2 | 1] = node
235        self.data._update(self.root)
236
237    def pop(self, k: int = -1) -> T:
238        assert -len(self) <= k < len(self), f"IndexError: SplayTreeListArray.pop({k})"
239        data = self.data
240        if k == -1:
241            node = self._right_splay(self.root)
242            self.root = data.arr[node << 2]
243            return data.keys[node]
244        self.root = self._kth_elm_splay(self.root, k)
245        res = data.keys[self.root]
246        if not data.arr[self.root << 2]:
247            self.root = data.arr[self.root << 2 | 1]
248        elif not data.arr[self.root << 2 | 1]:
249            self.root = data.arr[self.root << 2]
250        else:
251            node = self._right_splay(data.arr[self.root << 2])
252            data.arr[node << 2 | 1] = data.arr[self.root << 2 | 1]
253            self.root = node
254            self.data._update(self.root)
255        return res
256
257    def popleft(self) -> T:
258        assert self, "IndexError: SplayTreeListArray.popleft()"
259        node = self._left_splay(self.root)
260        self.root = self.data.arr[node << 2 | 1]
261        return self.data.keys[node]
262
263    def rotate(self, x: int) -> None:
264        # 「末尾をを削除し先頭に挿入」をx回
265        n = self.data.arr[self.root << 2 | 2]
266        l, self = self.split(n - (x % n))
267        self.merge(l)
268
269    def tolist(self) -> list[T]:
270        node = self.root
271        arr, keys = self.data.arr, self.data.keys
272        stack = []
273        res = []
274        while stack or node:
275            if node:
276                stack.append(node)
277                node = arr[node << 2]
278            else:
279                node = stack.pop()
280                res.append(keys[node])
281                node = arr[node << 2 | 1]
282        return res
283
284    def clear(self) -> None:
285        self.root = 0
286
287    def __setitem__(self, k: int, key: T):
288        assert (
289            -len(self) <= k < len(self)
290        ), f"IndexError: SplayTreeListArray.__setitem__({k})"
291        self.root = self._kth_elm_splay(self.root, k)
292        self.data.keys[self.root] = key
293        self.data._update(self.root)
294
295    def __getitem__(self, k: int) -> T:
296        assert (
297            -len(self) <= k < len(self)
298        ), f"IndexError: SplayTreeListArray.__getitem__({k})"
299        self.root = self._kth_elm_splay(self.root, k)
300        return self.data.keys[self.root]
301
302    def __iter__(self):
303        self.__iter = 0
304        return self
305
306    def __next__(self):
307        if self.__iter == self.data.arr[self.root << 2 | 2]:
308            raise StopIteration
309        res = self.__getitem__(self.__iter)
310        self.__iter += 1
311        return res
312
313    def __reversed__(self):
314        for i in range(len(self)):
315            yield self.__getitem__(-i - 1)
316
317    def __len__(self):
318        return self.data.arr[self.root << 2 | 2]
319
320    def __str__(self):
321        return str(self.tolist())
322
323    def __bool__(self):
324        return self.root != 0
325
326    def __repr__(self):
327        return f"SplayTreeListArray({self})"

仕様

class SplayTreeListArray(data: SplayTreeListArrayData, n_or_a: int | Iterable[T] = 0, _root: int = 0)[source]

Bases: Generic[T]

append(key: T) → None[source]
appendleft(key: T) → None[source]
clear() → None[source]
insert(k: int, key: T) → None[source]
merge(other: SplayTreeListArray) → None[source]
pop(k: int = -1) → T[source]
popleft() → T[source]
reserve(n: int) → None[source]
rotate(x: int) → None[source]
split(k: int) → tuple[SplayTreeListArray, SplayTreeListArray][source]
tolist() → list[T][source]
class SplayTreeListArrayData(e: T = 0)[source]

Bases: Generic[T]

reserve(n: int) → None[source]