get_scc_graph
ソースコード (source code)
from titan_pylib.graph.get_scc_graph import get_scc_graph
展開済みコード (expanded code)
1# from titan_pylib.graph.get_scc_graph import get_scc_graph
2# from titan_pylib.others.antirec import antirec
3from types import GeneratorType
4
5# ref: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
6# ref: https://twitter.com/onakasuita_py/status/1731535542305907041
7
8
def antirec(func):
    """Run a recursive generator function without growing the Python call stack.

    Write ``func`` as a generator: every recursive call ``f(x)`` becomes
    ``yield f(x)`` (the ``yield`` expression evaluates to that call's result),
    and the function delivers its own result with a final ``yield value``
    (a bare ``yield`` delivers ``None``). The outermost call then drives all
    generators from an explicit stack, so the recursion depth is not bounded
    by the interpreter's recursion limit.

    A generator that finishes without a final ``yield`` lets ``StopIteration``
    escape from the driver loop.

    Args:
        func: the generator function to wrap. A non-generator function is
            also accepted; its return value is passed through unchanged.

    Returns:
        A wrapper with the same call signature as ``func``.
    """
    from functools import wraps

    # Suspended generators of the computation currently being driven.
    stack = []

    @wraps(func)
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Nested call made from inside a running generator: hand the new
            # generator back to the driver loop instead of recursing.
            return func(*args, **kwargs)
        to = func(*args, **kwargs)
        if not isinstance(to, GeneratorType):
            # Plain function body: there is nothing to drive.
            return to
        try:
            while True:
                if isinstance(to, GeneratorType):
                    # A new "call": suspend it on the stack and start it.
                    stack.append(to)
                    to = next(to)
                else:
                    # The top generator yielded its result: drop it and resume
                    # its caller with that value.
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
        finally:
            # If a generator raised, stale frames would otherwise stay on the
            # stack and every later call would take the "nested call" branch
            # above, returning a raw generator instead of a result.
            stack.clear()
        return to

    return wrappedfunc
28
29
def antirec_cache(func):
    """Memoizing variant of ``antirec`` for recursive generator functions.

    Uses the same calling convention as ``antirec``: recursive calls are
    written ``yield f(*args)`` and the result is delivered by a final
    ``yield value``. Results are cached by the positional-argument tuple.
    Arguments must therefore be hashable, and keyword arguments are not
    supported. The cache lives as long as the decorated function and is
    never evicted.

    On a cache hit the wrapped function is still called to create a generator
    object, but that generator is discarded without ever being started.

    Args:
        func: the generator function to wrap. A non-generator function is
            also accepted; its return value is cached and returned.

    Returns:
        A wrapper taking the same positional arguments as ``func``.
    """
    from functools import wraps

    stack = []  # suspended generators of the computation being driven
    memo = {}  # args tuple -> result
    # Argument tuples of the pending calls: one per generator on `stack`,
    # plus the call whose generator or result is currently held in `to`.
    args_list = []

    @wraps(func)
    def wrappedfunc(*args):
        args_list.append(args)
        if stack:
            # Nested call from inside a running generator: let the driver
            # loop below decide whether to run it or answer from the cache.
            return func(*args)
        to = func(*args)
        try:
            while True:
                if args_list[-1] in memo:
                    # Cache hit: skip the pending call and feed the stored
                    # result to its caller.
                    res = memo[args_list.pop()]
                    if not stack:
                        return res
                    to = stack[-1].send(res)
                    continue
                if isinstance(to, GeneratorType):
                    # A new "call": suspend it on the stack and start it.
                    stack.append(to)
                    to = next(to)
                else:
                    # The pending call produced its result: cache it, drop its
                    # generator and resume the caller.
                    memo[args_list.pop()] = to
                    # The stack is empty here only when `func` is not a
                    # generator function and the top-level call returned directly.
                    if stack:
                        stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
        finally:
            # After an exception, stale frames/arguments would make every later
            # call take the "nested call" branch and return a raw generator.
            stack.clear()
            args_list.clear()
        return to

    return wrappedfunc
59
60
def get_scc_graph(
    G: list[list[int]],
) -> tuple[list[list[int]], list[list[int]], list[int]]:
    """Decompose a directed graph into strongly connected components (SCCs).

    Runs Tarjan's algorithm, with the DFS driven by ``antirec`` so that deep
    graphs do not hit the recursion limit. Each SCC is then condensed into a
    single vertex.

    Args:
        G: adjacency list of a directed graph on vertices ``0 .. len(G) - 1``;
            ``G[v]`` lists the heads of the edges leaving ``v``.

    Returns:
        ``(groups, F, ids)``. All three use one component numbering, and that
        numbering is a topological order of the condensed graph:

        - ``groups[k]``: original vertices in component ``k``, ascending.
        - ``F[k]``: sorted, duplicate-free list of components reached from
          ``k`` by one edge. Every such edge ``k -> j`` has ``k < j``.
        - ``ids[v]``: component index of original vertex ``v``.
    """
    n = len(G)
    stack = [0] * n  # Tarjan's vertex stack, preallocated; `ptr` is its size
    ptr = 0
    lowlink = [-1] * n
    # DFS preorder index: -1 = unvisited, n = already assigned to an SCC.
    order = [-1] * n
    ids = [0] * n
    cur_time = 0
    group_cnt = 0

    @antirec
    def dfs(v: int):
        nonlocal cur_time, ptr, group_cnt
        order[v] = lowlink[v] = cur_time
        cur_time += 1
        stack[ptr] = v
        ptr += 1
        for x in G[v]:
            if order[x] == -1:
                yield dfs(x)
                lowlink[v] = min(lowlink[v], lowlink[x])
            else:
                # Vertices already assigned to an SCC have order == n, which
                # can never lower lowlink[v], so no on-stack flag is needed.
                lowlink[v] = min(lowlink[v], order[x])
        if lowlink[v] == order[v]:
            # v is the root of an SCC: pop it and everything above it.
            while True:
                ptr -= 1
                u = stack[ptr]
                order[u] = n
                ids[u] = group_cnt
                if u == v:
                    break
            group_cnt += 1
        # antirec treats a yielded non-generator as this call's return value.
        yield

    for v in range(n):
        if order[v] == -1:
            dfs(v)

    # Tarjan finishes components in reverse topological order. Renumber so
    # that groups, F and ids all share the same topological indices.
    for v in range(n):
        ids[v] = group_cnt - 1 - ids[v]

    # Sized by group_cnt, which also covers the empty graph (n == 0).
    groups = [[] for _ in range(group_cnt)]
    F = [set() for _ in range(group_cnt)]
    for v in range(n):
        k = ids[v]
        groups[k].append(v)
        for x in G[v]:
            if ids[x] != k:
                F[k].add(ids[x])
    F = [sorted(f) for f in F]
    return groups, F, ids