Source code for titan_pylib.data_structures.segment_tree.persistent_segment_tree

  1from typing import Generic, Iterable, Optional, TypeVar, Callable
  2
  # Element type of the sequence stored in PersistentSegmentTree.
T = TypeVar("T")


class PersistentSegmentTree(Generic[T]):
    """Persistent sequence with range products, backed by a balanced BST.

    Each element is the ``key`` of exactly one node, and an in-order walk of
    the tree yields the sequence.  Every node also caches ``data`` (the
    ``op``-fold of its whole subtree, in order) and ``size`` (the number of
    nodes in its subtree).

    Existing nodes are never mutated once a tree has been built: ``set``
    copies only the nodes on the root-to-target path and returns a *new*
    tree sharing all untouched subtrees, so earlier versions remain valid.

    ``op`` is expected to be associative with identity ``e`` (a monoid):
    ``prod`` splits a range into pieces and folds them starting from ``e``.
    """

    class Node:
        """One tree node: a stored element plus cached subtree aggregates."""

        # Many nodes are allocated (one per path step on every ``set``), so
        # drop the per-instance ``__dict__``.
        __slots__ = ("key", "data", "left", "right", "size")

        def __init__(self, key: T) -> None:
            self.key: T = key  # the element stored at this position
            self.data: T = key  # op-fold of the subtree rooted here
            self.left: Optional["PersistentSegmentTree.Node"] = None
            self.right: Optional["PersistentSegmentTree.Node"] = None
            self.size: int = 1  # number of nodes in this subtree

        def copy(self) -> "PersistentSegmentTree.Node":
            """Return a shallow copy: fields are duplicated, children shared."""
            node = PersistentSegmentTree.Node(self.key)
            node.data = self.data
            node.left = self.left
            node.right = self.right
            node.size = self.size
            return node

    def __init__(
        self,
        a: Iterable[T],
        op: Callable[[T, T], T],
        e: T,
        __root: Optional[Node] = None,
    ) -> None:
        """Build a tree holding the elements of ``a`` in order.

        :param a: Initial sequence; ignored when ``__root`` is given.
        :param op: Binary operation folded by ``prod``.
        :param e: Identity of ``op``; returned by ``prod`` for empty ranges.
        :param __root: Internal: wrap an existing node graph instead of
            building from ``a`` (used to create new versions).  The name is
            class-private (mangled), so pass it positionally.
        """
        self.root: Optional[PersistentSegmentTree.Node] = __root
        self.op: Callable[[T, T], T] = op
        self.e: T = e
        if __root is not None:
            return
        a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Build a balanced tree over the non-empty list ``a`` into ``root``."""
        Node = PersistentSegmentTree.Node

        def build(l: int, r: int) -> PersistentSegmentTree.Node:
            # Root the half-open range a[l:r] at its midpoint so the tree
            # height (and thus recursion depth) stays about log2(len(a)).
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.root = build(0, len(a))

    def _update(self, node: Node) -> None:
        """Recompute ``node.size`` and ``node.data`` from its key and children."""
        node.size = 1
        node.data = node.key
        if node.left is not None:
            node.size += node.left.size
            node.data = self.op(node.left.data, node.data)
        if node.right is not None:
            node.size += node.right.size
            node.data = self.op(node.data, node.right.data)

    def _normalize_index(self, k: int) -> int:
        """Map ``k`` (negative counts from the end) into ``[0, len(self))``.

        :raises IndexError: if ``k`` is out of range, including on an empty tree.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("PersistentSegmentTree index out of range")
        return k

    def prod(self, l: int, r: int) -> T:
        """Return the ``op``-fold of the elements with index in ``[l, r)``.

        Returns ``e`` for an empty range or an empty tree.  Positions outside
        ``[0, len(self))`` hold no element and contribute nothing.
        """
        if l >= r or self.root is None:
            return self.e
        op = self.op
        e = self.e

        def dfs(node: "PersistentSegmentTree.Node", left: int, right: int) -> T:
            # ``node``'s subtree holds exactly the indices [left, right).
            if right <= l or r <= left:
                return e  # disjoint from the query range
            if l <= left and right <= r:
                return node.data  # fully covered: use the cached aggregate
            lsize = node.left.size if node.left is not None else 0
            mid = left + lsize  # index of ``node.key`` itself
            res = e
            if node.left is not None:
                res = dfs(node.left, left, mid)
            if l <= mid < r:
                res = op(res, node.key)
            if node.right is not None:
                res = op(res, dfs(node.right, mid + 1, right))
            return res

        return dfs(self.root, 0, len(self))

    def __iter__(self) -> Iterable[T]:
        """Yield the elements in order (iterative in-order traversal, O(n))."""
        node = self.root
        stack = []
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                yield node.key
                node = node.right

    def tolist(self) -> list[T]:
        """Return the elements as a new list, in order."""
        return list(self.__iter__())

    def _new(self, root: Optional["Node"]) -> "PersistentSegmentTree[T]":
        """Wrap ``root`` in a new tree object with the same ``op`` and ``e``."""
        return PersistentSegmentTree([], self.op, self.e, root)

    def copy(self) -> "PersistentSegmentTree[T]":
        """Return a new tree object with the same contents (structure shared)."""
        root = self.root.copy() if self.root is not None else None
        return self._new(root)

    def set(self, k: int, v: T) -> "PersistentSegmentTree[T]":
        """Return a new version whose element at index ``k`` is ``v``.

        ``self`` is not modified: only the nodes on the root-to-``k`` path are
        copied; every other subtree is shared with the new version.  Negative
        ``k`` counts from the end, as for lists.

        :raises IndexError: if ``k`` is out of range, including on an empty tree.
        """
        k = self._normalize_index(k)
        root = self.root.copy()
        path = [root]
        node = root
        while True:
            t = node.left.size if node.left is not None else 0
            if t == k:
                node.key = v  # safe: ``node`` is a fresh copy
                break
            if t < k:
                k -= t + 1
                child = node.right.copy()
                node.right = child
            else:
                child = node.left.copy()
                node.left = child
            path.append(child)
            node = child
        # Recompute cached aggregates bottom-up along the copied path.
        for copied in reversed(path):
            self._update(copied)
        return self._new(root)

    def __getitem__(self, k: int) -> T:
        """Return the element at index ``k`` (negative counts from the end).

        :raises IndexError: if ``k`` is out of range.  Raising ``IndexError``
            (not ``AttributeError``) is also what sequence iteration relies on.
        """
        k = self._normalize_index(k)
        node = self.root
        while True:
            t = node.left.size if node.left is not None else 0
            if t == k:
                return node.key
            if t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def __len__(self):
        return 0 if self.root is None else self.root.size

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"