1from typing import Callable, TypeVar
2
# Type of the DP values combined by merge / apply_vertex / apply_edge.
T = TypeVar("T")
# Type of the per-edge data stored in the adjacency list and passed to apply_edge.
E = TypeVar("E")
5
6
[docs]
def rerooting_dp(
    G: list[list[tuple[int, E]]],
    merge: Callable[[T, T], T],
    apply_vertex: Callable[[T, int], T],
    apply_edge: Callable[[T, E, int, int], T],
    e: T,
) -> list[T]:
    """Rerooting DP: compute a tree-DP value for every possible root at once.

    For each vertex ``v`` the answer is ``apply_vertex(M, v)``. Here ``M`` is the
    ``merge``, in the order of ``G[v]``, of one contribution per neighbour ``x``.
    That contribution is ``apply_edge(S, c, x, v)``, where ``S`` is the DP value
    of the part of the tree on ``x``'s side of the edge ``(v, x)`` (rooted at
    ``x``) and ``c`` is the data of that edge.

    ``merge`` must be associative and ``e`` must be its identity element.
    Commutativity is NOT required: contributions are always merged in
    adjacency-list order.

    For a tree (or forest) on ``n`` vertices the callbacks are invoked O(n)
    times in total.

    Args:
        G: Adjacency list of an undirected tree or forest on vertices
            ``0 .. len(G) - 1``. ``G[v]`` holds ``(x, c)`` pairs, meaning an
            edge to ``x`` carrying data ``c``. Every edge must be listed in
            both endpoints' lists.
        merge: Associative operation combining two contributions.
        apply_vertex: ``apply_vertex(acc, v)`` turns ``acc``, the merged
            contributions of ``v``'s children in the current rooting, into the
            DP value of the subtree rooted at ``v``.
        apply_edge: ``apply_edge(dp_x, c, x, v)`` lifts ``dp_x``, the DP value
            of the subtree rooted at ``x``, across the edge ``(x, v)`` with data
            ``c``. The result is ``x``'s contribution to its neighbour ``v``.
        e: Identity element of ``merge``.

    Returns:
        list[T]: ``ans`` where ``ans[v]`` is the DP value of ``v``'s connected
        component when rooted at ``v``. The list is empty when ``G`` is empty.
    """
    n = len(G)
    if n == 0:
        return []

    # Phase 1: root every component at its smallest vertex. Record parents and
    # an order in which every parent precedes its children.
    # par[v] == -2 means "not visited yet"; par[v] == -1 means "v is a root".
    par = [-2] * n
    order: list[int] = []
    for root in range(n):
        if par[root] != -2:
            continue
        par[root] = -1
        order.append(root)
        stack = [root]
        while stack:
            v = stack.pop()
            for x, _ in G[v]:
                if par[x] != -2:
                    continue
                par[x] = v
                order.append(x)
                stack.append(x)

    # dp[v][i] is the contribution of neighbour G[v][i][0] to v.
    dp: list[list[T]] = [[e] * len(g) for g in G]

    # Phase 2 (bottom-up): sub[v] is the DP value of v's subtree in the fixed
    # rooting. This fills every dp[v][i] except the slot of v's parent.
    sub = [e] * n
    for v in reversed(order):
        dp_v = dp[v]
        p = par[v]
        acc = e
        for i, (x, c) in enumerate(G[v]):
            if x == p:
                continue
            dp_v[i] = apply_edge(sub[x], c, x, v)
            acc = merge(acc, dp_v[i])
        sub[v] = apply_vertex(acc, v)

    # Phase 3 (top-down): the parent's side of v is known before v is visited.
    # Once dp[v] is complete, prefix/suffix merges give each child x the
    # merge of "every contribution around v except x's own".
    ans = [e] * n
    dp_par = [e] * n  # dp_par[x]: contribution of x's parent side to x.
    width = max(map(len, G)) + 1
    pre = [e] * width  # pre[k]: merge of dp_v[0 .. k-1]; pre[0] stays e.
    suf = [e] * width  # suf[k]: merge of the last k entries of dp_v; suf[0] stays e.
    for v in order:
        dp_v = dp[v]
        p = par[v]
        for i, (x, _) in enumerate(G[v]):
            if x == p:
                dp_v[i] = dp_par[v]
                break
        d = len(dp_v)
        for i in range(d):
            pre[i + 1] = merge(pre[i], dp_v[i])
            # Prepend, so the suffix stays in adjacency order even when
            # merge is not commutative.
            suf[i + 1] = merge(dp_v[d - 1 - i], suf[i])
        ans[v] = apply_vertex(pre[d], v)
        for i, (x, c) in enumerate(G[v]):
            if x == p:
                continue
            # pre[i] covers dp_v[0 .. i-1]; suf[d-1-i] covers dp_v[i+1 .. d-1].
            rest = merge(pre[i], suf[d - 1 - i])
            dp_par[x] = apply_edge(apply_vertex(rest, v), c, v, x)
    return ans
84
85
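# apply_vertex(dp_x: T, v: int) -> T:
#   Return the DP value of the subtree rooted at v, given dp_x: the merge of
#   the contributions of v's children in the current rooting (each
#   contribution already covers the edge up to v, see apply_edge).
#
#              v             <- added by apply_vertex
#      -----------------
#      |    /  |  \    |
#      |   o   o   o   |     <- dp_x: merged contributions of v's children
#      |   △   △   △   |
#      -----------------
#   return value = v together with everything inside the box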
93
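# apply_edge(dp_x: T, e: E, x: int, v: int) -> T:
#   Extend dp_x, the DP value of the subtree rooted at x, across the edge
#   (x, v). Here e is the data of that edge, not the identity element of merge.
#   The result is x's contribution to its neighbour v; v itself is added
#   later by apply_vertex.
#
#        v
#        |     <- edge with data e      \
#        x                               } return value
#        △     <- dp_x (subtree of x)   /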
99
# Template for the callbacks passed to rerooting_dp:
#
# def merge(s: T, t: T) -> T:
#     """Merge ``s`` and ``t`` (must be associative)."""
#     ...
#
# def apply_vertex(dp_x: T, v: int) -> T:
#     ...
#
# def apply_edge(dp_x: T, e: E, x: int, v: int) -> T:
#     ...
#
# e: T = ...  # identity element of merge