hld_segment_tree

ソースコード

from titan_pylib.graph.hld.hld_segment_tree import HLDSegmentTree

view on github

展開済みコード

  1# from titan_pylib.graph.hld.hld_segment_tree import HLDSegmentTree
  2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  4#     SegmentTreeInterface,
  5# )
  6from abc import ABC, abstractmethod
  7from typing import TypeVar, Generic, Union, Iterable, Callable
  8
  9T = TypeVar("T")
 10
 11
 12class SegmentTreeInterface(ABC, Generic[T]):
 13
 14    @abstractmethod
 15    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def set(self, k: int, v: T) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def get(self, k: int) -> T:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def prod(self, l: int, r: int) -> T:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def all_prod(self) -> T:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def tolist(self) -> list[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __getitem__(self, k: int) -> T:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __setitem__(self, k: int, v: T) -> None:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def __str__(self):
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def __repr__(self):
 60        raise NotImplementedError
 61from typing import Generic, Iterable, TypeVar, Callable, Union
 62
 63T = TypeVar("T")
 64
 65
 66class SegmentTree(SegmentTreeInterface, Generic[T]):
 67    """セグ木です。非再帰です。"""
 68
 69    def __init__(
 70        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 71    ) -> None:
 72        """``SegmentTree`` を構築します。
 73        :math:`O(n)` です。
 74
 75        Args:
 76            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 77                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 78            op (Callable[[T, T], T]): 2項演算の関数です。
 79            e (T): 単位元です。
 80        """
 81        self._op = op
 82        self._e = e
 83        if isinstance(n_or_a, int):
 84            self._n = n_or_a
 85            self._log = (self._n - 1).bit_length()
 86            self._size = 1 << self._log
 87            self._data = [e] * (self._size << 1)
 88        else:
 89            n_or_a = list(n_or_a)
 90            self._n = len(n_or_a)
 91            self._log = (self._n - 1).bit_length()
 92            self._size = 1 << self._log
 93            _data = [e] * (self._size << 1)
 94            _data[self._size : self._size + self._n] = n_or_a
 95            for i in range(self._size - 1, 0, -1):
 96                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 97            self._data = _data
 98
 99    def set(self, k: int, v: T) -> None:
100        """一点更新です。
101        :math:`O(\\log{n})` です。
102
103        Args:
104            k (int): 更新するインデックスです。
105            v (T): 更新する値です。
106
107        制約:
108            :math:`-n \\leq n \\leq k < n`
109        """
110        assert (
111            -self._n <= k < self._n
112        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113        if k < 0:
114            k += self._n
115        k += self._size
116        self._data[k] = v
117        for _ in range(self._log):
118            k >>= 1
119            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121    def get(self, k: int) -> T:
122        """一点取得です。
123        :math:`O(1)` です。
124
125        Args:
126            k (int): インデックスです。
127
128        制約:
129            :math:`-n \\leq n \\leq k < n`
130        """
131        assert (
132            -self._n <= k < self._n
133        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134        if k < 0:
135            k += self._n
136        return self._data[k + self._size]
137
138    def prod(self, l: int, r: int) -> T:
139        """区間 ``[l, r)`` の総積を返します。
140        :math:`O(\\log{n})` です。
141
142        Args:
143            l (int): インデックスです。
144            r (int): インデックスです。
145
146        制約:
147            :math:`0 \\leq l \\leq r \\leq n`
148        """
149        assert (
150            0 <= l <= r <= self._n
151        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152        l += self._size
153        r += self._size
154        lres = self._e
155        rres = self._e
156        while l < r:
157            if l & 1:
158                lres = self._op(lres, self._data[l])
159                l += 1
160            if r & 1:
161                rres = self._op(self._data[r ^ 1], rres)
162            l >>= 1
163            r >>= 1
164        return self._op(lres, rres)
165
166    def all_prod(self) -> T:
167        """区間 ``[0, n)`` の総積を返します。
168        :math:`O(1)` です。
169        """
170        return self._data[1]
171
172    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174        assert (
175            0 <= l <= self._n
176        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177        # assert f(self._e), \
178        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179        if l == self._n:
180            return self._n
181        l += self._size
182        s = self._e
183        while True:
184            while l & 1 == 0:
185                l >>= 1
186            if not f(self._op(s, self._data[l])):
187                while l < self._size:
188                    l <<= 1
189                    if f(self._op(s, self._data[l])):
190                        s = self._op(s, self._data[l])
191                        l |= 1
192                return l - self._size
193            s = self._op(s, self._data[l])
194            l += 1
195            if l & -l == l:
196                break
197        return self._n
198
199    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201        assert (
202            0 <= r <= self._n
203        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204        # assert f(self._e), \
205        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206        if r == 0:
207            return 0
208        r += self._size
209        s = self._e
210        while True:
211            r -= 1
212            while r > 1 and r & 1:
213                r >>= 1
214            if not f(self._op(self._data[r], s)):
215                while r < self._size:
216                    r = r << 1 | 1
217                    if f(self._op(self._data[r], s)):
218                        s = self._op(self._data[r], s)
219                        r ^= 1
220                return r + 1 - self._size
221            s = self._op(self._data[r], s)
222            if r & -r == r:
223                break
224        return 0
225
226    def tolist(self) -> list[T]:
227        """リストにして返します。
228        :math:`O(n)` です。
229        """
230        return [self.get(i) for i in range(self._n)]
231
232    def show(self) -> None:
233        """デバッグ用のメソッドです。"""
234        print(
235            f"<{self.__class__.__name__}> [\n"
236            + "\n".join(
237                [
238                    "  "
239                    + " ".join(
240                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241                    )
242                    for i in range(self._log + 1)
243                ]
244            )
245            + "\n]"
246        )
247
248    def __getitem__(self, k: int) -> T:
249        assert (
250            -self._n <= k < self._n
251        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252        return self.get(k)
253
254    def __setitem__(self, k: int, v: T):
255        assert (
256            -self._n <= k < self._n
257        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258        self.set(k, v)
259
260    def __len__(self) -> int:
261        return self._n
262
263    def __str__(self) -> str:
264        return str(self.tolist())
265
266    def __repr__(self) -> str:
267        return f"{self.__class__.__name__}({self})"
268# from titan_pylib.graph.hld.hld import HLD
269from typing import Any, Iterator
270
271
272class HLD:
273
274    def __init__(self, G: list[list[int]], root: int):
275        """``root`` を根とする木 ``G`` を HLD します。
276        :math:`O(n)` です。
277
278        Args:
279          G (list[list[int]]): 木を表す隣接リストです。
280          root (int): 根です。
281        """
282        n = len(G)
283        self.n: int = n
284        self.G: list[list[int]] = G
285        self.size: list[int] = [1] * n
286        self.par: list[int] = [-1] * n
287        self.dep: list[int] = [-1] * n
288        self.nodein: list[int] = [0] * n
289        self.nodeout: list[int] = [0] * n
290        self.head: list[int] = [0] * n
291        self.hld: list[int] = [0] * n
292        self._dfs(root)
293
294    def _dfs(self, root: int) -> None:
295        dep, par, size, G = self.dep, self.par, self.size, self.G
296        dep[root] = 0
297        stack = [~root, root]
298        while stack:
299            v = stack.pop()
300            if v >= 0:
301                dep_nxt = dep[v] + 1
302                for x in G[v]:
303                    if dep[x] != -1:
304                        continue
305                    dep[x] = dep_nxt
306                    stack.append(~x)
307                    stack.append(x)
308            else:
309                v = ~v
310                G_v, dep_v = G[v], dep[v]
311                for i, x in enumerate(G_v):
312                    if dep[x] < dep_v:
313                        par[v] = x
314                        continue
315                    size[v] += size[x]
316                    if size[x] > size[G_v[0]]:
317                        G_v[0], G_v[i] = G_v[i], G_v[0]
318
319        head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
320        curtime = 0
321        stack = [~root, root]
322        while stack:
323            v = stack.pop()
324            if v >= 0:
325                if par[v] == -1:
326                    head[v] = v
327                nodein[v] = curtime
328                hld[curtime] = v
329                curtime += 1
330                if not G[v]:
331                    continue
332                G_v0 = G[v][0]
333                for x in reversed(G[v]):
334                    if x == par[v]:
335                        continue
336                    head[x] = head[v] if x == G_v0 else x
337                    stack.append(~x)
338                    stack.append(x)
339            else:
340                nodeout[~v] = curtime
341
342    def build_list(self, a: list[Any]) -> list[Any]:
343        """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
344        :math:`O(n)` です。
345
346        Args:
347            a (list[Any]): 元の配列です。
348
349        Returns:
350            list[Any]: 振りなおし後の配列です。
351        """
352        return [a[e] for e in self.hld]
353
354    def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
355        """``u-v`` パスに対応する区間のインデックスを返します。
356        :math:`O(\\log{n})` です。
357        """
358        head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
359        while head[u] != head[v]:
360            if dep[head[u]] < dep[head[v]]:
361                u, v = v, u
362            yield nodein[head[u]], nodein[u] + 1
363            u = par[head[u]]
364        if dep[u] < dep[v]:
365            u, v = v, u
366        yield nodein[v], nodein[u] + 1
367
368    def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
369        """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
370        :math:`O(1)` です。
371        """
372        yield self.nodein[v], self.nodeout[v]
373
374    def path_kth_elm(self, s: int, t: int, k: int) -> int:
375        """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
376        存在しないときは ``-1`` を返します。
377        :math:`O(\\log{n})` です。
378        """
379        head, dep, par = self.head, self.dep, self.par
380        lca = self.lca(s, t)
381        d = dep[s] + dep[t] - 2 * dep[lca]
382        if d < k:
383            return -1
384        if dep[s] - dep[lca] < k:
385            s = t
386            k = d - k
387        hs = head[s]
388        while dep[s] - dep[hs] < k:
389            k -= dep[s] - dep[hs] + 1
390            s = par[hs]
391            hs = head[s]
392        return self.hld[self.nodein[s] - k]
393
394    def lca(self, u: int, v: int) -> int:
395        """``u``, ``v`` の LCA を返します。
396        :math:`O(\\log{n})` です。
397        """
398        nodein, head, par = self.nodein, self.head, self.par
399        while True:
400            if nodein[u] > nodein[v]:
401                u, v = v, u
402            if head[u] == head[v]:
403                return u
404            v = par[head[v]]
405
406    def dist(self, u: int, v: int) -> int:
407        return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
408
409    def is_on_path(self, u: int, v: int, a: int) -> bool:
410        """Return True if (a is on path(u - v)) else False. / O(logN)"""
411        return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)
412from typing import Union, Iterable, Callable, TypeVar, Generic
413
414T = TypeVar("T")
415
416
417class HLDSegmentTree(Generic[T]):
418    """セグ木搭載HLDです。
419
420    Note:
421        **非可換に対応していません。**
422    """
423
424    def __init__(
425        self, hld: HLD, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
426    ) -> None:
427        self.hld: HLD = hld
428        n_or_a = (
429            n_or_a if isinstance(n_or_a, int) else self.hld.build_list(list(n_or_a))
430        )
431        self.seg: SegmentTree[T] = SegmentTree(n_or_a, op, e)
432        self.op: Callable[[T, T], T] = op
433        self.e: T = e
434
435    def path_prod(self, u: int, v: int) -> T:
436        """頂点 ``u`` から頂点 ``v`` へのパスの集約値を返します。
437        :math:`O(\\log^2{n})` です。
438
439        Note:
440            **非可換に対応していません。**
441
442        Args:
443            u (int): パスの端点です。
444            v (int): パスの端点です。
445
446        Returns:
447            T: 求める集約値です。
448        """
449        head, nodein, dep, par = (
450            self.hld.head,
451            self.hld.nodein,
452            self.hld.dep,
453            self.hld.par,
454        )
455        res = self.e
456        while head[u] != head[v]:
457            if dep[head[u]] < dep[head[v]]:
458                u, v = v, u
459            res = self.op(res, self.seg.prod(nodein[head[u]], nodein[u] + 1))
460            u = par[head[u]]
461        if dep[u] < dep[v]:
462            u, v = v, u
463        return self.op(res, self.seg.prod(nodein[v], nodein[u] + 1))
464
465    def get(self, k: int) -> T:
466        """頂点の値を返します。
467        :math:`O(\\log{n})` です。
468
469        Args:
470            k (int): 頂点のインデックスです。
471
472        Returns:
473            T: 頂点の値です。
474        """
475        return self.seg[self.hld.nodein[k]]
476
477    def set(self, k: int, v: T) -> None:
478        """頂点の値を更新します。
479        :math:`O(\\log{n})` です。
480
481        Args:
482            k (int): 頂点のインデックスです。
483            v (T): 更新する値です。
484        """
485        self.seg[self.hld.nodein[k]] = v
486
487    __getitem__ = get
488    __setitem__ = set
489
490    def subtree_prod(self, v: int) -> T:
491        """部分木の集約値を返します。
492        :math:`O(\\log{n})` です。
493
494        Args:
495            v (int): 根とする頂点です。
496
497        Returns:
498            T: 求める集約値です。
499        """
500        return self.seg.prod(self.hld.nodein[v], self.hld.nodeout[v])

仕様

class HLDSegmentTree(hld: HLD, n_or_a: int | Iterable[T], op: Callable[[T, T], T], e: T)[source]

Bases: Generic[T]

セグ木搭載HLDです。

Note

非可換に対応していません。

__getitem__(k: int) T

頂点の値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 頂点のインデックスです。

Returns:

頂点の値です。

Return type:

T

__setitem__(k: int, v: T) None

頂点の値を更新します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 頂点のインデックスです。

  • v (T) – 更新する値です。

get(k: int) T[source]

頂点の値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 頂点のインデックスです。

Returns:

頂点の値です。

Return type:

T

path_prod(u: int, v: int) T[source]

頂点 u から頂点 v へのパスの集約値を返します。 \(O(\log^2{n})\) です。

Note

非可換に対応していません。

Parameters:
  • u (int) – パスの端点です。

  • v (int) – パスの端点です。

Returns:

求める集約値です。

Return type:

T

set(k: int, v: T) None[source]

頂点の値を更新します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 頂点のインデックスです。

  • v (T) – 更新する値です。

subtree_prod(v: int) T[source]

部分木の集約値を返します。 \(O(\log{n})\) です。

Parameters:

v (int) – 根とする頂点です。

Returns:

求める集約値です。

Return type:

T