warshall_floyd
ソースコード
from titan_pylib.graph.warshall_floyd import warshall_floyd
from titan_pylib.graph.warshall_floyd import warshall_floyd
展開済みコード
1# from titan_pylib.graph.warshall_floyd import warshall_floyd
2from typing import Union
3
4
def warshall_floyd(
    G: list[list[tuple[int, int]]], INF: Union[int, float] = float("inf")
) -> list[list[Union[int, float]]]:
    """Return all-pairs shortest path distances for a weighted adjacency list ``G``.

    Runs the Warshall-Floyd algorithm in :math:`O(n^3)` time.
    Negative edge weights are accepted; negative cycles are not detected.

    Args:
        G (list[list[tuple[int, int]]]): Weighted adjacency list. ``G[v]`` holds
            ``(x, c)`` pairs, each meaning an edge ``v -> x`` of weight ``c``.
            When several parallel edges ``v -> x`` exist, the lightest is used.
        INF (Union[int, float], optional): Value meaning "unreachable".
            Defaults to ``float("inf")``; a large finite int such as ``10**18``
            may be passed instead.

    Returns:
        list[list[Union[int, float]]]: ``dist`` with ``dist[a][b]`` the shortest
        distance from ``a`` to ``b``, or exactly ``INF`` if ``b`` is unreachable.
    """
    n = len(G)
    dist = [[INF] * n for _ in range(n)]
    for v in range(n):
        dist_v_ = dist[v]
        for x, c in G[v]:
            # Keep the lightest of any parallel edges v -> x.
            if c < dist_v_[x]:
                dist_v_[x] = c
        dist_v_[v] = 0
    for k in range(n):
        dist_k_ = dist[k]
        for i in range(n):
            dist_i_ = dist[i]
            dist_i_k_ = dist_i_[k]
            if dist_i_k_ == INF:
                continue
            for j, dist_k_j_ in enumerate(dist_k_):
                # The `!= INF` guard matters when INF is a finite sentinel and
                # dist_i_k_ is negative: INF + (negative) would otherwise be
                # stored as a bogus finite distance. It is tested second so it
                # only runs when an update is about to happen.
                if dist_i_[j] > dist_i_k_ + dist_k_j_ and dist_k_j_ != INF:
                    dist_i_[j] = dist_i_k_ + dist_k_j_
    return dist
36
37
38from typing import Union
39
def warshall_floyd(
    D: list[list[Union[int, float]]], INF: Union[int, float] = float("inf")
) -> list[list[Union[int, float]]]:
    """Return min dist s.t. dist[a][b] -> a to b, computed from matrix ``D``.

    Runs the Warshall-Floyd algorithm in :math:`O(n^3)` time. ``D`` itself is
    not modified; a row-by-row copy is relaxed and returned.
    Negative edge weights are accepted; negative cycles are not detected.

    Args:
        D (list[list[Union[int, float]]]): ``n x n`` matrix whose ``D[a][b]`` is
            the direct edge weight from ``a`` to ``b`` (``INF`` if no edge).
            The diagonal is used as given, so it should normally be ``0``.
        INF (Union[int, float], optional): Value meaning "unreachable".
            Defaults to ``float("inf")``; a large finite int may be passed.

    Returns:
        list[list[Union[int, float]]]: ``dist`` with ``dist[a][b]`` the shortest
        distance from ``a`` to ``b``, or exactly ``INF`` if ``b`` is unreachable.
    """
    n = len(D)
    dist = [d[:] for d in D]
    for k in range(n):
        dist_k_ = dist[k]
        for i in range(n):
            dist_i_ = dist[i]
            dist_i_k_ = dist_i_[k]
            if dist_i_k_ == INF:
                continue
            for j, dist_k_j_ in enumerate(dist_k_):
                # The `!= INF` guard keeps a finite INF sentinel from turning
                # into a bogus finite distance when dist_i_k_ is negative; it
                # is tested second so it only runs when an update would happen.
                if dist_i_[j] > dist_i_k_ + dist_k_j_ and dist_k_j_ != INF:
                    dist_i_[j] = dist_i_k_ + dist_k_j_
    return dist
59
60
61# from typing import Union
62# '''Return min dist s.t. dist[a][b] -> a to b. / O(|n|^3)'''
63# def warshall_floyd(G: list[list[tuple[int, int]]], INF: Union[int, float]=float('inf')) -> list[list[Union[int, float]]]:
64# n = len(G)
65# # dist = [dijkstra(G, s, INF) for s in range(n)]
66# dist = [[INF]*n for _ in range(n)]
67# for v in range(n):
68# for x, c in G[v]:
69# dist[v][x] = c
70# dist[v][v] = 0
71# for k in range(n):
72# for i in range(n):
73# if dist[i][k] == INF: continue
74# for j in range(n):
75# if dist[i][j] > dist[i][k] + dist[k][j]:
76# dist[i][j] = dist[i][k] + dist[k][j]
77# # elif dist[i][j] == dist[i][k] + dist[k][j]:
78# # dist[i][j] = dist[i][k] + dist[k][j]
79# '''
80# for i in range(n):
81# if dist[i][i] < 0:
82# return 'NEGATIVE CYCLE'
83# '''
84# return dist