get_scc_lowlink
ソースコード (Source code)
from titan_pylib.graph.get_scc_lowlink import get_scc_lowlink
展開済みコード (Expanded code)
1# from titan_pylib.graph.get_scc_lowlink import get_scc_lowlink
2# from titan_pylib.others.antirec import antirec
3from types import GeneratorType
4
5# ref: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
6# ref: https://twitter.com/onakasuita_py/status/1731535542305907041
7
8
def antirec(func):
    """Run a generator-based recursive function without deep Python recursion.

    ``func`` must be a generator function that writes every recursive call as
    ``result = yield func(...)`` and hands back its own result with a final
    ``yield value`` (a frame that finishes without that final ``yield`` makes
    the driver's ``next``/``send`` raise ``StopIteration``).  The decorated
    function is then called normally and returns that result, while the
    pending frames live in an explicit list instead of the interpreter stack.

    If ``func`` returns something that is not a generator, that value is
    returned unchanged.  If any frame raises, the exception propagates to the
    caller and the internal frame list is reset, so later calls still work.
    """
    from functools import wraps

    stack = []  # suspended generator frames, innermost call last

    @wraps(func)
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Called from inside a running frame: return the (unstarted)
            # generator so the driver loop below runs it instead of recursing.
            return func(*args, **kwargs)
        to = func(*args, **kwargs)
        if not isinstance(to, GeneratorType):
            # Plain (non-generator) result: nothing to drive.
            return to
        try:
            while True:
                if isinstance(to, GeneratorType):
                    # A new sub-call: push it and run it to its first yield.
                    stack.append(to)
                    to = next(to)
                else:
                    # `to` is the result of the innermost frame: drop that
                    # frame and resume its caller with the value.
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
        finally:
            # Already empty after normal completion; after an exception this
            # discards the abandoned frames so the next call starts clean
            # instead of being mistaken for a nested call.
            stack.clear()
        return to

    return wrappedfunc
28
29
def antirec_cache(func):
    """Like ``antirec``, but also memoize results keyed by the positional args.

    ``func`` must be a generator function that writes every recursive call as
    ``result = yield func(...)`` and hands back its own result with a final
    ``yield value``.  Only positional arguments are supported and they must be
    hashable, because the argument tuple is the memo key.  The memo is never
    cleared, so it grows for as long as the decorated function is alive.

    If ``func`` returns something that is not a generator, that value is
    memoized and returned unchanged.  If any frame raises, the exception
    propagates and the in-flight bookkeeping is reset (results that were
    already completed stay memoized), so later calls still work.
    """
    from functools import wraps

    stack = []  # suspended generator frames, innermost call last
    memo = {}  # args tuple -> completed result
    args_list = []  # args of calls awaiting resolution; top = call being resolved

    @wraps(func)
    def wrappedfunc(*args):
        args_list.append(args)
        if stack:
            # Nested call from a running frame: the driver loop below decides
            # whether to run this generator or answer from the memo.
            return func(*args)
        try:
            if args in memo:
                return memo[args]
            to = func(*args)
            if not isinstance(to, GeneratorType):
                # Plain (non-generator) result: nothing to drive.
                memo[args] = to
                return to
            while True:
                if args_list[-1] in memo:
                    # Memo hit for the latest nested call: its generator is
                    # discarded unstarted and the cached value is sent back.
                    res = memo[args_list.pop()]
                    if not stack:
                        return res
                    to = stack[-1].send(res)
                    continue
                if isinstance(to, GeneratorType):
                    stack.append(to)
                    to = next(to)
                else:
                    # Innermost frame produced its result: cache it, drop the
                    # frame and resume the caller.
                    memo[args_list.pop()] = to
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to
        finally:
            # Both lists are already empty after normal completion (except for
            # the early memo-hit return); after an exception this discards the
            # abandoned frames so the next call is not treated as nested.
            stack.clear()
            args_list.clear()

    return wrappedfunc
59
60
def get_scc_lowlink(G: list[list[int]]) -> list[list[int]]:
    """Decompose a directed graph into strongly connected components (Tarjan).

    Args:
        G: adjacency list; ``G[v]`` lists the targets of the edges leaving
            vertex ``v``.  Targets are used directly as indices into
            per-vertex lists, so they are presumably expected to lie in
            ``range(len(G))``.

    Returns:
        The components as lists of vertices.  Components are ordered
        topologically with respect to the condensation: every edge between
        two different components goes from an earlier list to a later one.
        Within a component, vertices appear in increasing order.

    Each vertex is entered once and each edge is scanned once.  The DFS is
    driven through ``antirec``, so its depth is not bounded by Python's
    recursion limit.
    """
    n = len(G)
    UNSEEN = -1

    # Tarjan's vertex stack: a fixed-size buffer plus an explicit top index.
    pending = [0] * n
    top = 0
    pre = [UNSEEN] * n  # preorder index; set to n once the vertex's component is emitted
    low = [UNSEEN] * n  # lowlink value
    comp = [0] * n  # component id, numbered in order of completion
    clock = 0
    n_comps = 0

    @antirec
    def visit(v: int):
        nonlocal clock, top, n_comps
        pre[v] = low[v] = clock
        clock += 1
        pending[top] = v
        top += 1
        for w in G[v]:
            if pre[w] == UNSEEN:
                yield visit(w)
                if low[w] < low[v]:
                    low[v] = low[w]
            elif pre[w] < low[v]:
                # Finished vertices have pre[w] == n, which can never be
                # smaller than low[v], so only on-stack vertices matter here.
                low[v] = pre[w]
        if low[v] == pre[v]:
            # v roots a component: pop every vertex down to and including v.
            while True:
                top -= 1
                u = pending[top]
                pre[u] = n
                comp[u] = n_comps
                if u == v:
                    break
            n_comps += 1
        # antirec resumes frames via next/send, so each frame ends with a yield.
        yield

    for s in range(n):
        if pre[s] == UNSEEN:
            visit(s)

    # Components complete sinks-first, so reverse the ids to get topological order.
    groups = [[] for _ in range(n_comps)]
    for v, c in enumerate(comp):
        groups[n_comps - 1 - c].append(v)
    return groups