Source code for titan_pylib.graph.warshall_floyd

 1from typing import Union
 2
 3
 4def warshall_floyd(
 5    G: list[list[tuple[int, int]]], INF: Union[int, float] = float("inf")
 6) -> list[list[Union[int, float]]]:
 7    """重み付き隣接リスト ``G`` に対し、全点対最短経路を返します。
 8    :math:`O(n^3)` です。
 9
10    Args:
11        G (list[list[tuple[int, int]]]): 重み付き隣接リストです。
12        INF (Union[int, float], optional): 無限大です。
13
14    Returns:
15        list[list[Union[int, float]]]: dist[a][b] -> a to b
16    """
17    n = len(G)
18    dist = [[INF] * n for _ in range(n)]
19    for v in range(n):
20        dist_v_ = dist[v]
21        for x, c in G[v]:
22            dist_v_[x] = c
23        dist_v_[v] = 0
24    for k in range(n):
25        dist_k_ = dist[k]
26        for i in range(n):
27            dist_i_ = dist[i]
28            dist_i_k_ = dist_i_[k]
29            if dist_i_k_ == INF:
30                continue
31            for j, dist_k_j_ in enumerate(dist_k_):
32                if dist_i_[j] > dist_i_k_ + dist_k_j_:
33                    dist_i_[j] = dist_i_k_ + dist_k_j_
34    return dist
35
36
37from typing import Union
38
"""Return the all-pairs shortest distances: dist[a][b] is the shortest distance from a to b. Runs in O(n^3) time."""
40
41
def warshall_floyd(
    D: list[list[Union[int, float]]], INF: Union[int, float] = float("inf")
) -> list[list[Union[int, float]]]:
    """Return all-pairs shortest distances starting from a distance matrix ``D``.

    ``D`` itself is not modified; a row-by-row copy is relaxed in place and
    returned. Runs in :math:`O(n^3)` time, where ``n = len(D)``.

    Args:
        D (list[list[Union[int, float]]]): Initial distances; ``D[a][b]`` is
            the direct ``a -> b`` distance, or ``INF`` when there is no edge.
            The diagonal is copied as given, so callers normally set
            ``D[v][v] = 0``.
        INF (Union[int, float], optional): Value meaning "unreachable".
            May be ``float("inf")`` or a finite sentinel such as ``10**18``.

    Returns:
        list[list[Union[int, float]]]: ``dist[a][b]`` is the shortest distance
        from ``a`` to ``b``, or ``INF`` when ``b`` is unreachable from ``a``.
        Negative cycles are not detected.
    """
    n = len(D)
    dist = [d[:] for d in D]
    for k in range(n):
        dist_k_ = dist[k]
        for i in range(n):
            dist_i_ = dist[i]
            dist_i_k_ = dist_i_[k]
            if dist_i_k_ == INF:
                continue
            for j, dist_k_j_ in enumerate(dist_k_):
                # Skip unreachable k -> j: with a finite INF and a negative
                # dist_i_k_, INF + dist_i_k_ would drop below INF and leak a
                # bogus "reachable" distance into dist[i][j].
                if dist_k_j_ == INF:
                    continue
                candidate = dist_i_k_ + dist_k_j_
                if candidate < dist_i_[j]:
                    dist_i_[j] = candidate
    return dist
58 59 60# from typing import Union 61# '''Return min dist s.t. dist[a][b] -> a to b. / O(|n|^3)''' 62# def warshall_floyd(G: list[list[tuple[int, int]]], INF: Union[int, float]=float('inf')) -> list[list[Union[int, float]]]: 63# n = len(G) 64# # dist = [dijkstra(G, s, INF) for s in range(n)] 65# dist = [[INF]*n for _ in range(n)] 66# for v in range(n): 67# for x, c in G[v]: 68# dist[v][x] = c 69# dist[v][v] = 0 70# for k in range(n): 71# for i in range(n): 72# if dist[i][k] == INF: continue 73# for j in range(n): 74# if dist[i][j] > dist[i][k] + dist[k][j]: 75# dist[i][j] = dist[i][k] + dist[k][j] 76# # elif dist[i][j] == dist[i][k] + dist[k][j]: 77# # dist[i][j] = dist[i][k] + dist[k][j] 78# ''' 79# for i in range(n): 80# if dist[i][i] < 0: 81# return 'NEGATIVE CYCLE' 82# ''' 83# return dist