hld_lazy_segment_tree

ソースコード

from titan_pylib.graph.hld.hld_lazy_segment_tree import HLDLazySegmentTree

view on github

展開済みコード

  1# from titan_pylib.graph.hld.hld_lazy_segment_tree import HLDLazySegmentTree
  2# from titan_pylib.data_structures.segment_tree.lazy_segment_tree import LazySegmentTree
  3from typing import Union, Callable, TypeVar, Generic, Iterable
  4
  5T = TypeVar("T")
  6F = TypeVar("F")
  7
  8
  9class LazySegmentTree(Generic[T, F]):
 10    """遅延セグ木です。"""
 11
 12    def __init__(
 13        self,
 14        n_or_a: Union[int, Iterable[T]],
 15        op: Callable[[T, T], T],
 16        mapping: Callable[[F, T], T],
 17        composition: Callable[[F, F], F],
 18        e: T,
 19        id: F,
 20    ) -> None:
 21        self.op: Callable[[T, T], T] = op
 22        self.mapping: Callable[[F, T], T] = mapping
 23        self.composition: Callable[[F, F], F] = composition
 24        self.e: T = e
 25        self.id: F = id
 26        if isinstance(n_or_a, int):
 27            self.n = n_or_a
 28            self.log = (self.n - 1).bit_length()
 29            self.size = 1 << self.log
 30            self.data = [e] * (self.size << 1)
 31        else:
 32            a = list(n_or_a)
 33            self.n = len(a)
 34            self.log = (self.n - 1).bit_length()
 35            self.size = 1 << self.log
 36            data = [e] * (self.size << 1)
 37            data[self.size : self.size + self.n] = a
 38            for i in range(self.size - 1, 0, -1):
 39                data[i] = op(data[i << 1], data[i << 1 | 1])
 40            self.data = data
 41        self.lazy = [id] * self.size
 42
 43    def _update(self, k: int) -> None:
 44        self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])
 45
 46    def _all_apply(self, k: int, f: F) -> None:
 47        self.data[k] = self.mapping(f, self.data[k])
 48        if k >= self.size:
 49            return
 50        self.lazy[k] = self.composition(f, self.lazy[k])
 51
 52    def _propagate(self, k: int) -> None:
 53        if self.lazy[k] == self.id:
 54            return
 55        self._all_apply(k << 1, self.lazy[k])
 56        self._all_apply(k << 1 | 1, self.lazy[k])
 57        self.lazy[k] = self.id
 58
 59    def apply_point(self, k: int, f: F) -> None:
 60        k += self.size
 61        for i in range(self.log, 0, -1):
 62            self._propagate(k >> i)
 63        self.data[k] = self.mapping(f, self.data[k])
 64        for i in range(1, self.log + 1):
 65            self._update(k >> i)
 66
 67    def _upper_propagate(self, l: int, r: int) -> None:
 68        for i in range(self.log, 0, -1):
 69            if l >> i << i != l:
 70                self._propagate(l >> i)
 71            if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l):
 72                self._propagate((r - 1) >> i)
 73
 74    def apply(self, l: int, r: int, f: F) -> None:
 75        assert (
 76            0 <= l <= r <= self.n
 77        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
 78        if l == r:
 79            return
 80        if f == self.id:
 81            return
 82        l += self.size
 83        r += self.size
 84        self._upper_propagate(l, r)
 85        l2, r2 = l, r
 86        while l < r:
 87            if l & 1:
 88                self._all_apply(l, f)
 89                l += 1
 90            if r & 1:
 91                self._all_apply(r ^ 1, f)
 92            l >>= 1
 93            r >>= 1
 94        ll, rr = l2, r2 - 1
 95        for i in range(1, self.log + 1):
 96            ll >>= 1
 97            rr >>= 1
 98            if ll << i != l2:
 99                self._update(ll)
100            if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2):
101                self._update(rr)
102
103    def all_apply(self, f: F) -> None:
104        self.lazy[1] = self.composition(f, self.lazy[1])
105
106    def prod(self, l: int, r: int) -> T:
107        assert (
108            0 <= l <= r <= self.n
109        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}"
110        if l == r:
111            return self.e
112        l += self.size
113        r += self.size
114        self._upper_propagate(l, r)
115        lres = self.e
116        rres = self.e
117        while l < r:
118            if l & 1:
119                lres = self.op(lres, self.data[l])
120                l += 1
121            if r & 1:
122                rres = self.op(self.data[r ^ 1], rres)
123            l >>= 1
124            r >>= 1
125        return self.op(lres, rres)
126
127    def all_prod(self) -> T:
128        return self.data[1]
129
130    def all_propagate(self) -> None:
131        for i in range(self.size):
132            self._propagate(i)
133
134    def tolist(self) -> list[T]:
135        self.all_propagate()
136        return self.data[self.size : self.size + self.n]
137
138    def max_right(self, l, f) -> int:
139        assert 0 <= l <= self.n
140        # assert f(self.e)
141        if l == self.size:
142            return self.n
143        l += self.size
144        for i in range(self.log, 0, -1):
145            self._propagate(l >> i)
146        s = self.e
147        while True:
148            while l & 1 == 0:
149                l >>= 1
150            if not f(self.op(s, self.data[l])):
151                while l < self.size:
152                    self._propagate(l)
153                    l <<= 1
154                    if f(self.op(s, self.data[l])):
155                        s = self.op(s, self.data[l])
156                        l |= 1
157                return l - self.size
158            s = self.op(s, self.data[l])
159            l += 1
160            if l & -l == l:
161                break
162        return self.n
163
164    def min_left(self, r: int, f) -> int:
165        assert 0 <= r <= self.n
166        # assert f(self.e)
167        if r == 0:
168            return 0
169        r += self.size
170        for i in range(self.log, 0, -1):
171            self._propagate((r - 1) >> i)
172        s = self.e
173        while True:
174            r -= 1
175            while r > 1 and r & 1:
176                r >>= 1
177            if not f(self.op(self.data[r], s)):
178                while r < self.size:
179                    self._propagate(r)
180                    r = r << 1 | 1
181                    if f(self.op(self.data[r], s)):
182                        s = self.op(self.data[r], s)
183                        r ^= 1
184                return r + 1 - self.size
185            s = self.op(self.data[r], s)
186            if r & -r == r:
187                break
188        return 0
189
190    def __getitem__(self, k: int) -> T:
191        assert (
192            -self.n <= k < self.n
193        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
194        if k < 0:
195            k += self.n
196        k += self.size
197        for i in range(self.log, 0, -1):
198            self._propagate(k >> i)
199        return self.data[k]
200
201    def __setitem__(self, k: int, v: T):
202        assert (
203            -self.n <= k < self.n
204        ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
205        if k < 0:
206            k += self.n
207        k += self.size
208        for i in range(self.log, 0, -1):
209            self._propagate(k >> i)
210        self.data[k] = v
211        for i in range(1, self.log + 1):
212            self._update(k >> i)
213
214    def __str__(self) -> str:
215        return str(self.tolist())
216
217    def __repr__(self):
218        return f"{self.__class__.__name__}({self})"
219# from titan_pylib.graph.hld.hld import HLD
220from typing import Any, Iterator
221
222
223class HLD:
224
225    def __init__(self, G: list[list[int]], root: int):
226        """``root`` を根とする木 ``G`` を HLD します。
227        :math:`O(n)` です。
228
229        Args:
230          G (list[list[int]]): 木を表す隣接リストです。
231          root (int): 根です。
232        """
233        n = len(G)
234        self.n: int = n
235        self.G: list[list[int]] = G
236        self.size: list[int] = [1] * n
237        self.par: list[int] = [-1] * n
238        self.dep: list[int] = [-1] * n
239        self.nodein: list[int] = [0] * n
240        self.nodeout: list[int] = [0] * n
241        self.head: list[int] = [0] * n
242        self.hld: list[int] = [0] * n
243        self._dfs(root)
244
245    def _dfs(self, root: int) -> None:
246        dep, par, size, G = self.dep, self.par, self.size, self.G
247        dep[root] = 0
248        stack = [~root, root]
249        while stack:
250            v = stack.pop()
251            if v >= 0:
252                dep_nxt = dep[v] + 1
253                for x in G[v]:
254                    if dep[x] != -1:
255                        continue
256                    dep[x] = dep_nxt
257                    stack.append(~x)
258                    stack.append(x)
259            else:
260                v = ~v
261                G_v, dep_v = G[v], dep[v]
262                for i, x in enumerate(G_v):
263                    if dep[x] < dep_v:
264                        par[v] = x
265                        continue
266                    size[v] += size[x]
267                    if size[x] > size[G_v[0]]:
268                        G_v[0], G_v[i] = G_v[i], G_v[0]
269
270        head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
271        curtime = 0
272        stack = [~root, root]
273        while stack:
274            v = stack.pop()
275            if v >= 0:
276                if par[v] == -1:
277                    head[v] = v
278                nodein[v] = curtime
279                hld[curtime] = v
280                curtime += 1
281                if not G[v]:
282                    continue
283                G_v0 = G[v][0]
284                for x in reversed(G[v]):
285                    if x == par[v]:
286                        continue
287                    head[x] = head[v] if x == G_v0 else x
288                    stack.append(~x)
289                    stack.append(x)
290            else:
291                nodeout[~v] = curtime
292
293    def build_list(self, a: list[Any]) -> list[Any]:
294        """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
295        :math:`O(n)` です。
296
297        Args:
298            a (list[Any]): 元の配列です。
299
300        Returns:
301            list[Any]: 振りなおし後の配列です。
302        """
303        return [a[e] for e in self.hld]
304
305    def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
306        """``u-v`` パスに対応する区間のインデックスを返します。
307        :math:`O(\\log{n})` です。
308        """
309        head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
310        while head[u] != head[v]:
311            if dep[head[u]] < dep[head[v]]:
312                u, v = v, u
313            yield nodein[head[u]], nodein[u] + 1
314            u = par[head[u]]
315        if dep[u] < dep[v]:
316            u, v = v, u
317        yield nodein[v], nodein[u] + 1
318
319    def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
320        """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
321        :math:`O(1)` です。
322        """
323        yield self.nodein[v], self.nodeout[v]
324
325    def path_kth_elm(self, s: int, t: int, k: int) -> int:
326        """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
327        存在しないときは ``-1`` を返します。
328        :math:`O(\\log{n})` です。
329        """
330        head, dep, par = self.head, self.dep, self.par
331        lca = self.lca(s, t)
332        d = dep[s] + dep[t] - 2 * dep[lca]
333        if d < k:
334            return -1
335        if dep[s] - dep[lca] < k:
336            s = t
337            k = d - k
338        hs = head[s]
339        while dep[s] - dep[hs] < k:
340            k -= dep[s] - dep[hs] + 1
341            s = par[hs]
342            hs = head[s]
343        return self.hld[self.nodein[s] - k]
344
345    def lca(self, u: int, v: int) -> int:
346        """``u``, ``v`` の LCA を返します。
347        :math:`O(\\log{n})` です。
348        """
349        nodein, head, par = self.nodein, self.head, self.par
350        while True:
351            if nodein[u] > nodein[v]:
352                u, v = v, u
353            if head[u] == head[v]:
354                return u
355            v = par[head[v]]
356
357    def dist(self, u: int, v: int) -> int:
358        return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
359
360    def is_on_path(self, u: int, v: int, a: int) -> bool:
361        """Return True if (a is on path(u - v)) else False. / O(logN)"""
362        return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)
363from typing import Union, Iterable, Callable, TypeVar, Generic
364
365T = TypeVar("T")
366F = TypeVar("F")
367
368
369class HLDLazySegmentTree(Generic[T, F]):
370    """遅延セグ木搭載HLDです。
371
372    Note:
373        **非可換に対応してます。**
374    """
375
376    def __init__(
377        self,
378        hld: HLD,
379        n_or_a: Union[int, Iterable[T]],
380        op: Callable[[T, T], T],
381        mapping: Callable[[F, T], T],
382        composition: Callable[[F, F], F],
383        e: T,
384        id: F,
385    ) -> None:
386        self.hld: HLD = hld
387        a = (
388            [e] * n_or_a
389            if isinstance(n_or_a, int)
390            else self.hld.build_list(list(n_or_a))
391        )
392        self.seg: LazySegmentTree[T, F] = LazySegmentTree(
393            a, op, mapping, composition, e, id
394        )
395        self.rseg: LazySegmentTree[T, F] = LazySegmentTree(
396            a[::-1], op, mapping, composition, e, id
397        )
398        self.op: Callable[[T, T], T] = op
399        self.e: T = e
400
401    def path_prod(self, u: int, v: int) -> T:
402        """頂点 ``u`` から頂点 ``v`` へのパスの集約値を返します。
403        :math:`O(\\log^2{n})` です。
404
405        Args:
406            u (int): パスの **始点** です。
407            v (int): パスの **終点** です。
408
409        Returns:
410            T: 求める集約値です。
411        """
412        head, nodein, dep, par, n = (
413            self.hld.head,
414            self.hld.nodein,
415            self.hld.dep,
416            self.hld.par,
417            self.hld.n,
418        )
419        lres, rres = self.e, self.e
420        seg, rseg = self.seg, self.rseg
421        while head[u] != head[v]:
422            if dep[head[u]] > dep[head[v]]:
423                lres = self.op(lres, rseg.prod(n - nodein[u] - 1, n - nodein[head[u]]))
424                u = par[head[u]]
425            else:
426                rres = self.op(seg.prod(nodein[head[v]], nodein[v] + 1), rres)
427                v = par[head[v]]
428        if dep[u] > dep[v]:
429            lres = self.op(lres, rseg.prod(n - nodein[u] - 1, n - nodein[v]))
430        else:
431            lres = self.op(lres, seg.prod(nodein[u], nodein[v] + 1))
432        return self.op(lres, rres)
433
434    def path_apply(self, u: int, v: int, f: F) -> None:
435        """頂点 ``u`` から頂点 ``v`` へのパスに作用させます。
436        :math:`O(\\log^2{n})` です。
437
438        Args:
439            u (int): パスの **始点** です。
440            v (int): パスの **終点** です。
441            f (F): 作用素です。
442        """
443        head, nodein, dep, par = (
444            self.hld.head,
445            self.hld.nodein,
446            self.hld.dep,
447            self.hld.par,
448        )
449        while head[u] != head[v]:
450            if dep[head[u]] < dep[head[v]]:
451                u, v = v, u
452            self.seg.apply(nodein[head[u]], nodein[u] + 1, f)
453            self.rseg.apply(
454                self.hld.n - (nodein[u] + 1 - 1) - 1,
455                self.hld.n - nodein[head[u]] - 1 + 1,
456                f,
457            )
458            u = par[head[u]]
459        if dep[u] < dep[v]:
460            u, v = v, u
461        self.seg.apply(nodein[v], nodein[u] + 1, f)
462        self.rseg.apply(
463            self.hld.n - (nodein[u] + 1 - 1) - 1, self.hld.n - nodein[v] - 1 + 1, f
464        )
465
466    def get(self, k: int) -> T:
467        """頂点の値を返します。
468        :math:`O(\\log{n})` です。
469
470        Args:
471            k (int): 頂点のインデックスです。
472
473        Returns:
474            T: 頂点の値です。
475        """
476        return self.seg[self.hld.nodein[k]]
477
478    def set(self, k: int, v: T) -> None:
479        """頂点の値を更新します。
480        :math:`O(\\log{n})` です。
481
482        Args:
483            k (int): 頂点のインデックスです。
484            v (T): 更新する値です。
485        """
486        self.seg[self.hld.nodein[k]] = v
487        self.rseg[self.hld.n - self.hld.nodein[k] - 1] = v
488
489    __getitem__ = get
490    __setitem__ = set
491
492    def subtree_prod(self, v: int) -> T:
493        """部分木の集約値を返します。
494        :math:`O(\\log{n})` です。
495
496        Args:
497            v (int): 根とする頂点です。
498
499        Returns:
500            T: 求める集約値です。
501        """
502        return self.seg.prod(self.hld.nodein[v], self.hld.nodeout[v])
503
504    def subtree_apply(self, v: int, f: F) -> None:
505        """部分木に作用させます。
506        :math:`O(\\log{n})` です。
507
508        Args:
509            v (int): 根とする頂点です。
510            f (F): 作用素です。
511        """
512        self.seg.apply(self.hld.nodein[v], self.hld.nodeout[v], f)
513        self.rseg.apply(
514            self.hld.n - self.hld.nodeout[v] - 1 - 1,
515            self.hld.n - self.hld.nodein[v] - 1 + 1,
516            f,
517        )

仕様

class HLDLazySegmentTree(hld: HLD, n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]

Bases: Generic[T, F]

遅延セグ木搭載HLDです。

Note

非可換に対応してます。

__getitem__(k: int) T

頂点の値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 頂点のインデックスです。

Returns:

頂点の値です。

Return type:

T

__setitem__(k: int, v: T) None

頂点の値を更新します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 頂点のインデックスです。

  • v (T) – 更新する値です。

get(k: int) T[source]

頂点の値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 頂点のインデックスです。

Returns:

頂点の値です。

Return type:

T

path_apply(u: int, v: int, f: F) None[source]

頂点 u から頂点 v へのパスに作用させます。 \(O(\log^2{n})\) です。

Parameters:
  • u (int) – パスの 始点 です。

  • v (int) – パスの 終点 です。

  • f (F) – 作用素です。

path_prod(u: int, v: int) T[source]

頂点 u から頂点 v へのパスの集約値を返します。 \(O(\log^2{n})\) です。

Parameters:
  • u (int) – パスの 始点 です。

  • v (int) – パスの 終点 です。

Returns:

求める集約値です。

Return type:

T

set(k: int, v: T) None[source]

頂点の値を更新します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 頂点のインデックスです。

  • v (T) – 更新する値です。

subtree_apply(v: int, f: F) None[source]

部分木に作用させます。 \(O(\log{n})\) です。

Parameters:
  • v (int) – 根とする頂点です。

  • f (F) – 作用素です。

subtree_prod(v: int) T[source]

部分木の集約値を返します。 \(O(\log{n})\) です。

Parameters:

v (int) – 根とする頂点です。

Returns:

求める集約値です。

Return type:

T