get_scc_graph

ソースコード

from titan_pylib.graph.get_scc_graph import get_scc_graph

View on GitHub

展開済みコード

  1# from titan_pylib.graph.get_scc_graph import get_scc_graph
  2# from titan_pylib.others.antirec import antirec
  3from types import GeneratorType
  4
  5# ref: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
  6# ref: https://twitter.com/onakasuita_py/status/1731535542305907041
  7
  8
  9def antirec(func):
 10    stack = []
 11
 12    def wrappedfunc(*args, **kwargs):
 13        if stack:
 14            return func(*args, **kwargs)
 15        to = func(*args, **kwargs)
 16        while True:
 17            if isinstance(to, GeneratorType):
 18                stack.append(to)
 19                to = next(to)
 20            else:
 21                stack.pop()
 22                if not stack:
 23                    break
 24                to = stack[-1].send(to)
 25        return to
 26
 27    return wrappedfunc
 28
 29
 30def antirec_cache(func):
 31    stack = []
 32    memo = {}
 33    args_list = []
 34
 35    def wrappedfunc(*args):
 36        args_list.append(args)
 37        if stack:
 38            return func(*args)
 39        to = func(*args)
 40        while True:
 41            if args_list[-1] in memo:
 42                res = memo[args_list.pop()]
 43                if not stack:
 44                    return res
 45                to = stack[-1].send(res)
 46                continue
 47            if isinstance(to, GeneratorType):
 48                stack.append(to)
 49                to = next(to)
 50            else:
 51                memo[args_list.pop()] = to
 52                stack.pop()
 53                if not stack:
 54                    break
 55                to = stack[-1].send(to)
 56        return to
 57
 58    return wrappedfunc
 59
 60
 61def get_scc_graph(G: list[list[int]]):
 62    """
 63    scc, 頂点を縮約した隣接リスト, もとの頂点->新たなグラフの頂点
 64    """
 65    n = len(G)
 66    stack = [0] * n
 67    ptr = 0
 68    lowlink = [-1] * n
 69    order = [-1] * n
 70    ids = [0] * n
 71    cur_time = 0
 72    group_cnt = 0
 73
 74    @antirec
 75    def dfs(v: int):
 76        nonlocal cur_time, ptr
 77        order[v] = cur_time
 78        lowlink[v] = cur_time
 79        cur_time += 1
 80        stack[ptr] = v
 81        ptr += 1
 82        for x in G[v]:
 83            if order[x] == -1:
 84                yield dfs(x)
 85                lowlink[v] = min(lowlink[v], lowlink[x])
 86            else:
 87                lowlink[v] = min(lowlink[v], order[x])
 88        if lowlink[v] == order[v]:
 89            nonlocal group_cnt
 90            while True:
 91                u = stack[ptr - 1]
 92                ptr -= 1
 93                order[u] = n
 94                ids[u] = group_cnt
 95                if u == v:
 96                    break
 97            group_cnt += 1
 98        yield
 99
100    for v in range(n):
101        if order[v] == -1:
102            dfs(v)
103    groups = [[] for _ in range(group_cnt)]
104    for v in range(n):
105        groups[group_cnt - 1 - ids[v]].append(v)
106
107    F = [set() for _ in range(max(ids) + 1)]
108    for v in range(n):
109        for x in G[v]:
110            if ids[v] != ids[x]:
111                F[ids[v]].add(ids[x])
112    F = [list(f) for f in F]
113    return groups, F, ids

仕様

get_scc_graph(G: list[list[int]]) [source]

戻り値は (groups, F, ids) の組：groups は強連結成分 (SCC) ごとの頂点リスト、F は各 SCC を 1 頂点に縮約したグラフの隣接リスト、ids[v] は元の頂点 v に対応する縮約後のグラフの頂点番号。