bipartite_max_matching

ソースコード

from titan_pylib.graph.flow.bipartite_max_matching import BipartiteMaxMatching

View on GitHub

展開済みコード

  1# from titan_pylib.graph.flow.bipartite_max_matching import BipartiteMaxMatching
  2# from titan_pylib.graph.flow.max_flow_dinic import MaxFlowDinic
  3from collections import deque
  4# from titan_pylib.others.antirec import antirec
  5from types import GeneratorType
  6
  7# ref: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
  8# ref: https://twitter.com/onakasuita_py/status/1731535542305907041
  9
 10
 11def antirec(func):
 12    stack = []
 13
 14    def wrappedfunc(*args, **kwargs):
 15        if stack:
 16            return func(*args, **kwargs)
 17        to = func(*args, **kwargs)
 18        while True:
 19            if isinstance(to, GeneratorType):
 20                stack.append(to)
 21                to = next(to)
 22            else:
 23                stack.pop()
 24                if not stack:
 25                    break
 26                to = stack[-1].send(to)
 27        return to
 28
 29    return wrappedfunc
 30
 31
 32def antirec_cache(func):
 33    stack = []
 34    memo = {}
 35    args_list = []
 36
 37    def wrappedfunc(*args):
 38        args_list.append(args)
 39        if stack:
 40            return func(*args)
 41        to = func(*args)
 42        while True:
 43            if args_list[-1] in memo:
 44                res = memo[args_list.pop()]
 45                if not stack:
 46                    return res
 47                to = stack[-1].send(res)
 48                continue
 49            if isinstance(to, GeneratorType):
 50                stack.append(to)
 51                to = next(to)
 52            else:
 53                memo[args_list.pop()] = to
 54                stack.pop()
 55                if not stack:
 56                    break
 57                to = stack[-1].send(to)
 58        return to
 59
 60    return wrappedfunc
 61
 62
 63class MaxFlowDinic:
 64    """mf.G[v]:= [x, cap, ind, flow]"""
 65
 66    def __init__(self, n: int):
 67        self.n: int = n
 68        self.G: list[list[list[int]]] = [[] for _ in range(n)]
 69        self.level = [-1] * n
 70
 71    def add_edge(self, u: int, v: int, w: int) -> None:
 72        assert (
 73            0 <= u < self.n
 74        ), f"Indexerror: {self.__class__.__name__}.add_edge({u}, {v})"
 75        assert (
 76            0 <= v < self.n
 77        ), f"Indexerror: {self.__class__.__name__}.add_edge({u}, {v})"
 78        G_u = len(self.G[u])
 79        G_v = len(self.G[v])
 80        self.G[u].append([v, w, G_v, 0])
 81        self.G[v].append([u, 0, G_u, 0])
 82
 83    def _bfs(self, s: int) -> None:
 84        level = self.level
 85        for i in range(len(level)):
 86            level[i] = -1
 87        dq = deque([s])
 88        level[s] = 0
 89        while dq:
 90            v = dq.popleft()
 91            for x, w, _, _ in self.G[v]:
 92                if w > 0 and level[x] == -1:
 93                    level[x] = level[v] + 1
 94                    dq.append(x)
 95        self.level = level
 96
 97    @antirec
 98    def _dfs(self, v: int, g: int, f: int):
 99        if v == g:
100            yield f
101        else:
102            for i in range(self.it[v], len(self.G[v])):
103                self.it[v] += 1
104                x, w, rev, _ = self.G[v][i]
105                if w > 0 and self.level[v] < self.level[x]:
106                    fv = yield self._dfs(x, g, min(f, w))
107                    if fv > 0:
108                        self.G[v][i][3] += f
109                        self.G[x][rev][3] -= f
110                        self.G[v][i][1] -= fv
111                        self.G[x][rev][1] += fv
112                        yield fv
113                        break
114            else:
115                yield 0
116
117    def max_flow(self, s: int, g: int, INF: int = 10**18) -> int:
118        """:math:`O(V^2 E)`"""
119        assert (
120            0 <= s < self.n
121        ), f"Indexerror: {self.__class__.__class__}.max_flow(), {s=}"
122        assert (
123            0 <= g < self.n
124        ), f"Indexerror: {self.__class__.__class__}.max_flow(), {g=}"
125        ans = 0
126        while True:
127            self._bfs(s)
128            if self.level[g] < 0:
129                break
130            self.it = [0] * self.n
131            while True:
132                f = self._dfs(s, g, INF)
133                if f == 0:
134                    break
135                ans += f
136        return ans
137
138
class BipartiteMaxMatching:
    """Maximum bipartite matching reduced to a unit-capacity max-flow problem.

    Vertex layout in the underlying flow network:

    - left vertices ``0 .. n-1``,
    - right vertices ``n .. n+m-1``,
    - source ``n+m``,
    - sink ``n+m+1``.
    """

    def __init__(self, n: int, m: int) -> None:
        """Initialise an empty bipartite graph.

        Args:
            n (int): Number of left-side vertices.
            m (int): Number of right-side vertices.
        """
        self.n = n
        self.m = m
        self.s = n + m
        self.t = n + m + 1
        self.mf = MaxFlowDinic(n + m + 2)
        for i in range(n):
            self.mf.add_edge(self.s, i, 1)
        for i in range(m):
            self.mf.add_edge(n + i, self.t, 1)

    def add_edge(self, l: int, r: int) -> None:
        """Add an edge between left vertex ``l`` and right vertex ``r``.

        Args:
            l (int): Left vertex, ``0 <= l < n``.
            r (int): Right vertex, ``0 <= r < m``.
        """
        assert 0 <= l < self.n
        assert 0 <= r < self.m
        self.mf.add_edge(l, self.n + r, 1)

    def max_matching(self) -> tuple[int, list[tuple[int, int]]]:
        r"""Compute a maximum matching and return its size and edges.

        Runs max flow from the source to the sink, then collects every
        left->right edge that carries positive flow. :math:`O(E \sqrt{V})`.

        The flow state is kept between calls. Calling this again (for
        example after more ``add_edge`` calls) continues from the current
        flow, and the returned size is always that of the whole matching.

        Returns:
            tuple[int, list[tuple[int, int]]]: ``(size, pairs)``. Each pair is
            ``(l, r)``, where ``l`` is a left vertex and ``r`` a right vertex,
            both 0-indexed within their own side.
        """
        self.mf.max_flow(self.s, self.t)
        lo, hi = self.n, self.n + self.m  # index range of right vertices
        pairs = [
            (a, b - lo)
            for a in range(self.n)
            for b, _, _, f in self.mf.G[a]
            if lo <= b < hi and f > 0
        ]
        # Unit capacities: every left/right vertex appears in at most one pair,
        # so the matching size equals the number of pairs.
        return len(pairs), pairs

仕様

class BipartiteMaxMatching(n: int, m: int)[source]

Bases: object

add_edge(l: int, r: int) → None[source]

左側の頂点 l と右側の頂点 r の間に辺を張ります。

Parameters:
  • l (int)

  • r (int)

max_matching() → tuple[int, list[tuple[int, int]]][source]

最大マッチングを求め、マッチングの個数と、使用する辺（左側の頂点 l と右側の頂点 r の組 (l, r)）のリストを返します。 \(O(E \sqrt{V})\) です。

Return type:

tuple[int, list[tuple[int, int]]]