from typing import Union, Callable, TypeVar, Generic, Iterable

T = TypeVar("T")  # element type (stored in ``data``)
F = TypeVar("F")  # operator type (stored in ``lazy``)


class DualSegmentTreeCommutative(Generic[T, F]):
    """Dual segment tree (range update, point query) for commutative operators.

    Supports applying an operator ``f`` to every element of a half-open range
    ``[l, r)`` and reading / writing a single element.  Leaf values live in
    ``data``; pending operators for internal nodes ``1 .. size - 1`` live in
    ``lazy``.

    Range updates never push pending operators down first, and point reads
    combine the operators on the root-to-leaf path in a fixed order, so results
    are only correct when ``composition`` is commutative.  It is also assumed
    (not checked) that ``mapping(composition(f, g), x)`` equals
    ``mapping(f, mapping(g, x))`` and that ``id`` is the identity operator.

    Args:
        n_or_a: Number of elements (each initialised to ``e``) or an iterable
            of initial values.
        mapping: ``mapping(f, x)`` applies operator ``f`` to element ``x``.
        composition: ``composition(f, g)`` combines two operators.
        e: Fill value used when ``n_or_a`` is an int.
        id: Identity operator; also used to clear pending operators.
    """

    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,  # shadows the builtin, but is part of the public signature
    ) -> None:
        self.mapping: Callable[[F, T], T] = mapping
        self.composition: Callable[[F, F], F] = composition
        self.e: T = e
        self.id: F = id
        self.data: list[T] = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
        self.n: int = len(self.data)
        # ``size`` is a power of two >= n; leaf i is addressed as node size + i
        # and its value is stored directly in data[i].
        self.log: int = (self.n - 1).bit_length()
        self.size: int = 1 << self.log
        # lazy[k] holds the pending operator of internal node k (1 <= k < size);
        # index 0 is never used.
        self.lazy: list[F] = [id] * self.size

    def _all_apply(self, k: int, f: F) -> None:
        """Apply ``f`` to node ``k``: compose into ``lazy`` for an internal
        node, map onto ``data`` for a leaf.  Padding leaves (index >= n) are
        ignored."""
        if k < self.size:
            self.lazy[k] = self.composition(f, self.lazy[k])
            return
        k -= self.size
        if k < self.n:
            self.data[k] = self.mapping(f, self.data[k])

    def apply_point(self, k: int, f: F) -> None:
        """Apply ``f`` to element ``k`` (0 <= k < n)."""
        assert (
            0 <= k < self.n
        ), f"IndexError: {self.__class__.__name__}.apply_point({k}, {f}), n={self.n}"
        # Pending ancestor operators still apply on read; with commutative
        # operators the order does not matter.
        self.data[k] = self.mapping(f, self.data[k])

    def _propagate(self, k: int) -> None:
        """Push the pending operator of internal node ``k`` to its children."""
        self._all_apply(k << 1, self.lazy[k])
        self._all_apply(k << 1 | 1, self.lazy[k])
        self.lazy[k] = self.id

    def apply(self, l: int, r: int, f: F) -> None:
        """Apply ``f`` to every element in the half-open range ``[l, r)``."""
        assert (
            0 <= l <= r <= self.n
        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
        if l == r or f == self.id:
            return
        size = self.size
        data, lazy = self.data, self.lazy
        mapping, composition = self.mapping, self.composition
        l += size
        r += size
        # Leaf level: leaves are stored in ``data`` (not ``lazy``), so an
        # unpaired boundary leaf must be updated in place before moving up.
        if l & 1:
            data[l - size] = mapping(f, data[l - size])
            l += 1
        if r & 1:
            r -= 1
            data[r - size] = mapping(f, data[r - size])
        l >>= 1
        r >>= 1
        # Internal levels: standard bottom-up range decomposition.
        while l < r:
            if l & 1:
                lazy[l] = composition(f, lazy[l])
                l += 1
            if r & 1:
                r -= 1
                lazy[r] = composition(f, lazy[r])
            l >>= 1
            r >>= 1

    def all_apply(self, f: F) -> None:
        """Apply ``f`` to every element."""
        if self.size == 1:
            # No internal nodes exist (n == 1): update the leaf directly.
            for i in range(self.n):
                self.data[i] = self.mapping(f, self.data[i])
            return
        self.lazy[1] = self.composition(f, self.lazy[1])

    def all_propagate(self) -> None:
        """Push every pending operator down into ``data`` (top-down order)."""
        for i in range(1, self.size):
            self._propagate(i)

    def tolist(self) -> list[T]:
        """Return a copy of all current element values."""
        self.all_propagate()
        return self.data[:]

    def __getitem__(self, k: int) -> T:
        """Return element ``k`` (negative indices count from the end)."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
        if k < 0:
            k += self.n
        node = k + self.size
        fs = self.id
        # Combine the pending operators of all strict ancestors of the leaf.
        for i in range(self.log, 0, -1):
            fs = self.composition(fs, self.lazy[node >> i])
        return self.mapping(fs, self.data[k])

    def __setitem__(self, k: int, v: T) -> None:
        """Overwrite element ``k`` with ``v`` (negative indices allowed)."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
        if k < 0:
            k += self.n
        node = k + self.size
        # Clear pending operators on the path so they no longer affect the
        # overwritten leaf; they are pushed onto the siblings instead.
        for i in range(self.log, 0, -1):
            self._propagate(node >> i)
        self.data[k] = v

    def __str__(self) -> str:
        return str([self[i] for i in range(self.n)])

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"