Source code for titan_pylib.graph.rerooting_dp

  1from typing import Callable, TypeVar
  2
  3T = TypeVar("T")
  4E = TypeVar("E")
  5
  6
def rerooting_dp(
    G: list[list[tuple[int, E]]],
    merge: Callable[[T, T], T],
    apply_vertex: Callable[[T, int], T],
    apply_edge: Callable[[T, E, int, int], T],
    e: T,
) -> list[T]:
    """Rerooting DP: evaluate a tree DP once for every choice of root.

    ``ans[v]`` is the value the ordinary bottom-up tree DP would produce at
    ``v`` if ``v``'s component were rooted at ``v``. Each adjacency entry
    causes O(1) callback invocations, so the total work is O(sum of degrees).

    Args:
        G (list[list[tuple[int, E]]]): Adjacency list. ``G[v]`` holds one
            ``(x, c)`` pair per edge ``v - x`` carrying edge data ``c``. Every
            undirected edge must be listed at both endpoints. The graph is
            expected to be a tree or a forest; each connected component is
            processed independently.
        merge (Callable[[T, T], T]): Associative operation with identity
            ``e``. Neighbour contributions are always merged in the order they
            appear in ``G[v]``, so commutativity is not required.
        apply_vertex (Callable[[T, int], T]): ``apply_vertex(acc, v)`` turns the
            merged contributions of ``v``'s neighbours into the value of the
            subtree rooted at ``v``.
        apply_edge (Callable[[T, E, int, int], T]): ``apply_edge(dp_x, c, x, v)``
            turns the value ``dp_x`` of the subtree rooted at ``x`` into its
            contribution to the adjacent vertex ``v`` across the edge with
            data ``c``.
        e (T): Identity element of ``merge``.

    Returns:
        list[T]: ``ans[v]`` for every vertex ``v``. An empty graph gives ``[]``.
    """
    n = len(G)

    # dp[v][i]: contribution of the neighbour G[v][i] = (x, c) to v, i.e.
    # apply_edge(value of x's side of the edge, c, x, v).
    dp: list[list[T]] = [[e] * len(g) for g in G]
    ans: list[T] = [e] * n

    # Iterative DFS started from every unvisited vertex, so forests work.
    # par[v] == -2: not yet visited; -1: root of its component.
    # ``order`` lists every vertex after its parent.
    par = [-2] * n
    order: list[int] = []
    for root in range(n):
        if par[root] != -2:
            continue
        par[root] = -1
        order.append(root)
        stack = [root]
        while stack:
            v = stack.pop()
            for x, _ in G[v]:
                if par[x] != -2:
                    continue
                par[x] = v
                order.append(x)
                stack.append(x)

    # Pass 1 (leaves -> roots): sub[v] is the value of v's subtree with
    # respect to the rooting fixed above.
    sub: list[T] = [e] * n
    for v in reversed(order):
        dp_v = dp[v]
        p = par[v]
        acc = e
        for i, (x, c) in enumerate(G[v]):
            if x == p:
                continue
            dp_v[i] = apply_edge(sub[x], c, x, v)
            acc = merge(acc, dp_v[i])
        sub[v] = apply_vertex(acc, v)

    # Pass 2 (roots -> leaves): fill in the contribution coming from the
    # parent side, then use prefix/suffix merges to get "all neighbours except
    # the i-th" for each child without recomputation.
    dp_par: list[T] = [e] * n
    acc_l: list[T] = [e] * (n + 1)
    acc_r: list[T] = [e] * (n + 1)
    for v in order:
        dp_v = dp[v]
        p = par[v]
        for i, (x, _) in enumerate(G[v]):
            if x == p:
                dp_v[i] = dp_par[v]
                break
        d = len(dp_v)
        # acc_l[k] = dp_v[0] * ... * dp_v[k-1]
        # acc_r[k] = dp_v[d-k] * ... * dp_v[d-1]   (adjacency order preserved)
        # acc_l[0] and acc_r[0] stay equal to e.
        for i in range(d):
            acc_l[i + 1] = merge(acc_l[i], dp_v[i])
            acc_r[i + 1] = merge(dp_v[d - i - 1], acc_r[i])
        ans[v] = apply_vertex(acc_l[d], v)
        for i, (x, c) in enumerate(G[v]):
            if x == p:
                continue
            # Value at v with x's side removed, pushed across edge v - x to x.
            without_x = merge(acc_l[i], acc_r[d - i - 1])
            dp_par[x] = apply_edge(apply_vertex(without_x, v), c, v, x)
    return ans


# apply_vertex(dp_x: T, v: int) -> T:
#          v          } return
#   --------------    }
#   |  /   |   \ |    }
#   |  o   o   o |    } dp_x (the merged value)
#   |  △   △   △ |
#   --------------

# apply_edge(dp_x: T, e: E, x: int, v: int) -> T:
#     v     } return
#     |     } e
#     x |   } dp_x
#     △ |

# def merge(s: T, t: T) -> T:
#     """Merge ``s`` and ``t``."""
#     ...

# def apply_vertex(dp_x: T, v: int) -> T:
#     ...

# def apply_edge(dp_x: T, e: E, x: int, v: int) -> T:
#     ...

# e: T = ...