Source code for titan_pylib.data_structures.segment_tree.persistent_segment_tree
1from typing import Generic, Iterable, Optional, TypeVar, Callable
2
T = TypeVar("T")


class PersistentSegmentTree(Generic[T]):
    """Persistent (immutable) sequence supporting point update and range product.

    Elements are stored in a balanced binary tree whose in-order traversal is
    the sequence. Every node caches ``data``, the ``op``-product of its whole
    subtree in sequence order, so ``prod`` only needs O(log n) nodes.

    ``set`` never mutates an existing tree: it copies the nodes on the
    root-to-target path and returns a new ``PersistentSegmentTree`` that shares
    every other node with the old version.

    ``op`` is assumed to be associative with identity ``e``. Values are always
    combined left-to-right in sequence order, so ``op`` need not be commutative.
    """

    class Node:
        """Tree node: one sequence element plus cached subtree aggregates."""

        # Many nodes accumulate across versions; slots keep each one small.
        __slots__ = ("key", "data", "left", "right", "size")

        def __init__(self, key: T) -> None:
            self.key: T = key  # the element stored at this position
            self.data: T = key  # op-product of this node's whole subtree
            self.left: Optional[PersistentSegmentTree.Node] = None
            self.right: Optional[PersistentSegmentTree.Node] = None
            self.size: int = 1  # number of nodes in this subtree

        def copy(self) -> "PersistentSegmentTree.Node":
            """Return a shallow copy of this node (children are shared)."""
            node = PersistentSegmentTree.Node(self.key)
            node.data = self.data
            node.left = self.left
            node.right = self.right
            node.size = self.size
            return node

    def __init__(
        self,
        a: Iterable[T],
        op: Callable[[T, T], T],
        e: T,
        __root: Optional[Node] = None,
    ) -> None:
        """Build a tree holding the elements of ``a`` in order.

        Args:
            a: Initial elements. Ignored when ``__root`` is given.
            op: Binary operation used to combine values.
            e: Identity element; returned for empty ranges.
            __root: Internal use: wrap an existing node as the root instead of
                building from ``a`` (this is how new versions are created).
        """
        self.root: Optional[PersistentSegmentTree.Node] = __root
        self.op: Callable[[T, T], T] = op
        self.e: T = e
        if __root is not None:
            return
        a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Build a balanced tree over ``a`` by always rooting at the midpoint."""
        Node = PersistentSegmentTree.Node

        def build(l: int, r: int) -> PersistentSegmentTree.Node:
            # Builds the subtree for a[l:r]; callers guarantee l < r.
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.root = build(0, len(a))

    def _update(self, node: Node) -> None:
        """Recompute ``size`` and ``data`` of ``node`` from its children."""
        node.size = 1
        node.data = node.key
        if node.left:
            node.size += node.left.size
            node.data = self.op(node.left.data, node.data)
        if node.right:
            node.size += node.right.size
            node.data = self.op(node.data, node.right.data)

    def prod(self, l: int, r: int) -> T:
        """Return ``op(a[l], a[l+1], ..., a[r-1])``, or ``e`` for an empty range.

        Indices are not wrapped: positions outside ``[0, len(self))`` simply
        contribute nothing.
        """
        if l >= r or (not self.root):
            return self.e

        def dfs(node: PersistentSegmentTree.Node, left: int, right: int) -> T:
            # ``node``'s subtree covers sequence positions [left, right).
            if right <= l or r <= left:
                return self.e
            if l <= left and right <= r:
                # Subtree lies entirely inside [l, r): use the cached product.
                return node.data
            lsize = node.left.size if node.left else 0
            res = self.e
            if node.left:
                res = dfs(node.left, left, left + lsize)
            # This node's own element sits at position left + lsize.
            if l <= left + lsize < r:
                res = self.op(res, node.key)
            if node.right:
                res = self.op(res, dfs(node.right, left + lsize + 1, right))
            return res

        return dfs(self.root, 0, len(self))

    def tolist(self) -> list[T]:
        """Return all elements in order (iterative in-order traversal)."""
        node = self.root
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node.key)
                node = node.right
        return a

    def _new(self, root: Optional["Node"]) -> "PersistentSegmentTree[T]":
        """Wrap ``root`` as a new tree sharing this tree's ``op`` and ``e``."""
        return PersistentSegmentTree([], self.op, self.e, root)

    def copy(self) -> "PersistentSegmentTree[T]":
        """Return a new handle whose root is a shallow copy of this root."""
        root = self.root.copy() if self.root else None
        return self._new(root)

    def set(self, k: int, v: T) -> "PersistentSegmentTree[T]":
        """Return a new version with position ``k`` replaced by ``v``.

        ``self`` is left unchanged. Only the nodes on the path from the root to
        position ``k`` are copied; every other node is shared with ``self``.
        A negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range (always, for an empty tree).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("PersistentSegmentTree index out of range")
        root = node = self.root.copy()
        path = [node]
        while True:
            t = 0 if node.left is None else node.left.size
            if k == t:
                break
            # Copy exactly one child per level and hook it into the copied parent.
            if k < t:
                child = node.left.copy()
                node.left = child
            else:
                k -= t + 1
                child = node.right.copy()
                node.right = child
            node = child
            path.append(node)
        node.key = v
        # Refresh cached aggregates bottom-up along the copied path only.
        for p in reversed(path):
            self._update(p)
        return self._new(root)

    def __getitem__(self, k: int) -> T:
        """Return the element at position ``k`` (negative counts from the end).

        Raises:
            IndexError: if ``k`` is out of range. This also lets the implicit
                sequence-iteration protocol (``for x in tree``) stop cleanly.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("PersistentSegmentTree index out of range")
        node = self.root
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            if t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def __len__(self):
        return 0 if self.root is None else self.root.size

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"