persistent_segment_tree

ソースコード

from titan_pylib.data_structures.segment_tree.persistent_segment_tree import PersistentSegmentTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.persistent_segment_tree import PersistentSegmentTree
  2from typing import Generic, Iterable, Optional, TypeVar, Callable
  3
  T = TypeVar("T")


  class PersistentSegmentTree(Generic[T]):
      """Persistent sequence supporting range folds and point updates.

      Elements are stored one per node in a binary tree whose in-order
      traversal is the sequence.  Every node caches the fold (``op``) of its
      whole subtree together with the subtree size, so positions are located
      by subtree sizes rather than by stored indices.

      ``set`` never mutates an existing tree: it copies the nodes on the path
      from the root to the target position and returns a new tree that shares
      every other node with the old one, so older versions stay usable.
      """

      class Node:
          """Tree node holding one element plus cached subtree aggregates."""

          def __init__(self, key: T) -> None:
              self.key: T = key  # element stored at this position
              self.data: T = key  # fold of the whole subtree, in sequence order
              self.left: Optional[PersistentSegmentTree.Node] = None
              self.right: Optional[PersistentSegmentTree.Node] = None
              self.size: int = 1  # number of nodes in this subtree

          def copy(self) -> "PersistentSegmentTree.Node":
              """Return a shallow copy; the children are shared, not cloned."""
              node = PersistentSegmentTree.Node(self.key)
              node.data = self.data
              node.left = self.left
              node.right = self.right
              node.size = self.size
              return node

      def __init__(
          self,
          a: Iterable[T],
          op: Callable[[T, T], T],
          e: T,
          __root: Optional[Node] = None,
      ) -> None:
          """Build a tree over the elements of ``a``.

          Args:
              a: Initial elements, in sequence order.
              op: Binary operation used to fold ranges.
              e: Value returned for empty ranges; it is also the left operand
                  of the first ``op`` call in a fold, so it should be an
                  identity for ``op``.
              __root: Internal use only.  When given, ``a`` is ignored and the
                  new tree wraps this existing root node.
          """
          self.root: Optional[PersistentSegmentTree.Node] = __root
          self.op: Callable[[T, T], T] = op
          self.e: T = e
          if __root is not None:
              return
          a = list(a)
          if a:
              self._build(a)

      def _build(self, a: list[T]) -> None:
          """Build a tree from the non-empty list ``a`` and store its root."""
          Node = PersistentSegmentTree.Node

          def build(l: int, r: int) -> PersistentSegmentTree.Node:
              # Builds the subtree for the non-empty half-open range [l, r);
              # splitting at the midpoint keeps the two sides balanced.
              mid = (l + r) >> 1
              node = Node(a[mid])
              if l != mid:
                  node.left = build(l, mid)
              if mid + 1 != r:
                  node.right = build(mid + 1, r)
              self._update(node)
              return node

          self.root = build(0, len(a))

      def _update(self, node: Node) -> None:
          """Recompute ``node.size`` and ``node.data`` from its key and children."""
          node.size = 1
          node.data = node.key
          if node.left:
              node.size += node.left.size
              node.data = self.op(node.left.data, node.data)
          if node.right:
              node.size += node.right.size
              node.data = self.op(node.data, node.right.data)

      def prod(self, l: int, r: int) -> T:
          """Return the fold of the elements at positions ``[l, r)``.

          Returns ``e`` when the range is empty or the tree has no elements.
          Operands are combined left to right, so ``op`` need not be
          commutative.
          """
          if l >= r or (not self.root):
              return self.e

          def dfs(node: PersistentSegmentTree.Node, left: int, right: int) -> T:
              # ``node`` covers the positions [left, right).
              if right <= l or r <= left:
                  return self.e
              if l <= left and right <= r:
                  # Fully covered (including when the subtree ends exactly at
                  # ``r``): the cached fold is the answer, no need to descend.
                  return node.data
              lsize = node.left.size if node.left else 0
              res = self.e
              if node.left:
                  res = dfs(node.left, left, left + lsize)
              if l <= left + lsize < r:
                  # The node's own element sits at position ``left + lsize``.
                  res = self.op(res, node.key)
              if node.right:
                  res = self.op(res, dfs(node.right, left + lsize + 1, right))
              return res

          return dfs(self.root, 0, len(self))

      def tolist(self) -> list[T]:
          """Return the elements in sequence order (iterative in-order walk)."""
          node = self.root
          stack = []
          a = []
          while stack or node:
              if node:
                  stack.append(node)
                  node = node.left
              else:
                  node = stack.pop()
                  a.append(node.key)
                  node = node.right
          return a

      def _new(self, root: Optional["Node"]) -> "PersistentSegmentTree[T]":
          """Wrap ``root`` in a new tree object that reuses ``op`` and ``e``."""
          return PersistentSegmentTree([], self.op, self.e, root)

      def copy(self) -> "PersistentSegmentTree[T]":
          """Return a new tree with a copied root node; the rest is shared."""
          root = self.root.copy() if self.root else None
          return self._new(root)

      def set(self, k: int, v: T) -> "PersistentSegmentTree[T]":
          """Return a new tree in which position ``k`` holds ``v``.

          This tree is left unchanged.  Negative ``k`` counts from the end.

          Raises:
              IndexError: If ``k`` is out of range (always, for an empty tree).
          """
          n = len(self)
          index = k + n if k < 0 else k
          if not 0 <= index < n:
              raise IndexError(f"index {k} out of range for length {n}")
          root = self.root.copy()
          node = root
          # Every node on ``path`` is a fresh copy, so it can be modified
          # without affecting the tree this method was called on.
          path = [node]
          while True:
              t = node.left.size if node.left else 0
              if t == index:
                  node.key = v
                  break
              if t < index:
                  # Skip the left subtree and this node's own element.
                  index -= t + 1
                  node.right = node.right.copy()
                  node = node.right
              else:
                  node.left = node.left.copy()
                  node = node.left
              path.append(node)
          # Refresh the cached aggregates bottom-up along the copied path.
          while path:
              self._update(path.pop())
          return self._new(root)

      def __getitem__(self, k: int) -> T:
          """Return the element at position ``k`` (negative counts from the end).

          Raises:
              IndexError: If ``k`` is out of range.  This also lets Python's
                  index-based iteration fallback (``for x in tree``) stop
                  cleanly at the end of the sequence.
          """
          n = len(self)
          index = k + n if k < 0 else k
          if not 0 <= index < n:
              raise IndexError(f"index {k} out of range for length {n}")
          node = self.root
          while True:
              t = node.left.size if node.left else 0
              if t == index:
                  return node.key
              if t < index:
                  index -= t + 1
                  node = node.right
              else:
                  node = node.left

      def __len__(self):
          """Return the number of elements."""
          return 0 if self.root is None else self.root.size

      def __str__(self):
          """Return the string form of ``tolist()``."""
          return str(self.tolist())

      def __repr__(self):
          """Return ``ClassName([...])`` built from the string form."""
          return f"{self.__class__.__name__}({self})"

仕様

class PersistentSegmentTree(a: Iterable[T], op: Callable[[T, T], T], e: T, _PersistentSegmentTree__root: Node | None = None)[source]

Bases: Generic[T]

class Node(key: T)[source]

Bases: object

copy() → Node[source]
copy() → PersistentSegmentTree[T][source]
prod(l: int, r: int) → T[source]
set(k: int, v: T) → PersistentSegmentTree[T][source]
tolist() → list[T][source]