persistent_array

ソースコード

from titan_pylib.data_structures.array.persistent_array import PersistentArray

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.array.persistent_array import PersistentArray
  2from typing import Iterable, TypeVar, Generic, Optional
  3
  4T = TypeVar("T")
  5
  6
  7class PersistentArray(Generic[T]):
  8
  9    class _Node:
 10
 11        def __init__(self, key: T):
 12            self.key: T = key
 13            self.left: Optional[PersistentArray._Node] = None
 14            self.right: Optional[PersistentArray._Node] = None
 15
 16        def copy(self) -> "PersistentArray._Node":
 17            node = PersistentArray._Node(self.key)
 18            node.left = self.left
 19            node.right = self.right
 20            return node
 21
 22    def __init__(
 23        self, a: Iterable[T] = [], _root: Optional["PersistentArray._Node"] = None
 24    ) -> None:
 25        self.root = self._build(a) if _root is None else _root
 26
 27    def _build(self, a: Iterable[T]) -> Optional["PersistentArray._Node"]:
 28        pool = [PersistentArray._Node(e) for e in a]
 29        self.n = len(pool)
 30        if not pool:
 31            return None
 32        n = len(pool)
 33        for i in range(1, n + 1):
 34            if 2 * i - 1 < n:
 35                pool[i - 1].left = pool[2 * i - 1]
 36            if 2 * i < n:
 37                pool[i - 1].right = pool[2 * i]
 38        return pool[0]
 39
 40    def _new(self, root: Optional["PersistentArray._Node"]) -> "PersistentArray[T]":
 41        res = PersistentArray(_root=root)
 42        res.n = self.n
 43        return res
 44
 45    def set(self, k: int, v: T) -> "PersistentArray[T]":
 46        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.set({k})"
 47        assert self.root
 48        node = self.root
 49        new_node = node.copy()
 50        res = self._new(new_node)
 51        k += 1
 52        b = k.bit_length()
 53        for i in range(b - 2, -1, -1):
 54            if k >> i & 1:
 55                node = node.right
 56                new_node.right = node.copy()
 57                new_node = new_node.right
 58            else:
 59                node = node.left
 60                new_node.left = node.copy()
 61                new_node = new_node.left
 62        new_node.key = v
 63        return res
 64
 65    def get(self, k: int) -> T:
 66        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.get({k})"
 67        node = self.root
 68        k += 1
 69        b = k.bit_length()
 70        for i in range(b - 2, -1, -1):
 71            if k >> i & 1:
 72                node = node.right
 73            else:
 74                node = node.left
 75        return node.key
 76
 77    __getitem__ = get
 78
 79    def copy(self) -> "PersistentArray[T]":
 80        return self._new(None if self.root is None else self.root.copy())
 81
 82    def tolist(self) -> list[T]:
 83        node = self.root
 84        a: list[T] = []
 85        if not node:
 86            return a
 87        q = [node]
 88        for node in q:
 89            a.append(node.key)
 90            if node.left:
 91                q.append(node.left)
 92            if node.right:
 93                q.append(node.right)
 94        return a
 95
 96    def __len__(self):
 97        return self.n
 98
 99    def __str__(self):
100        return str(self.tolist())
101
102    def __repr__(self):
103        return f"{self.__class__.__name__}({self})"

仕様

class PersistentArray(a: Iterable[T] = [], _root: _Node | None = None) [source]

Bases: Generic[T]

copy() → PersistentArray[T] [source]
get(k: int) → T [source]
set(k: int, v: T) → PersistentArray[T] [source]
tolist() → list[T] [source]