跳转至

最小生成树#
发布于2020-12-13
上次编辑2021-04-19

Kruskal 算法#

将所有边按权值从小到大排序后依次考察:若当前边的两个端点不在同一个集合(即尚不连通)中,就选择这条边并合并这两个集合;否则跳过它,以避免形成环路。集合用并查集维护。

时间复杂度为 \(\mathcal{O}(E\log E)=\mathcal{O}(E\log V)\),瓶颈在于对边排序;由于 \(E \le V^2\),有 \(\log E \le 2\log V\),故两种写法等价。

Python
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class UnionFind:
    """Disjoint-set forest with path compression and union by size.

    Call ``init_set(n)`` before use; the elements are the integers ``0..n-1``.
    """

    def __init__(self):
        # Parent pointers (fa) and per-root set sizes; allocated by init_set.
        self.fa = None
        self.size = None

    def init_set(self, n: int) -> None:
        """Reset the structure to ``n`` singleton sets ``{0}, ..., {n-1}``."""
        self.size = [1 for _ in range(n)]
        self.fa = [i for i in range(n)]

    def find_set(self, x: int) -> int:
        """Return the representative of ``x``'s set, compressing the path."""
        root = x
        while self.fa[root] != root:
            root = self.fa[root]
        # Second pass: re-point every node on the walked path at the root.
        while x != root:
            parent = self.fa[x]
            self.fa[x] = root
            x = parent
        return root

    def union_set(self, x: int, y: int) -> bool:
        """Merge the sets containing ``x`` and ``y``.

        Returns ``False`` if they were already in the same set, ``True``
        otherwise. The smaller tree is hung beneath the larger one (ties keep
        ``x``'s root on top).
        """
        rx = self.find_set(x)
        ry = self.find_set(y)
        if rx == ry:
            return False
        big, small = (ry, rx) if self.size[rx] < self.size[ry] else (rx, ry)
        self.fa[small] = big
        self.size[big] += self.size[small]
        return True

def kruskal(n: int, edges: List[Tuple[int, int, int]]) -> int:
    """Return the weight of a minimum spanning tree via Kruskal's algorithm.

    Args:
        n: number of vertices, labelled ``0..n-1``.
        edges: undirected weighted edges given as ``(u, v, w)`` triples.
            The caller's list is not modified.

    Returns:
        Total weight of the minimum spanning tree, or ``-1`` if the graph is
        not connected (the same convention as the C++ version).
    """
    uf = UnionFind()
    uf.init_set(n)
    cost = 0
    picked = 0
    # The greedy choice only yields an MST when edges are scanned by
    # ascending weight, so sort first (a copy, to leave `edges` untouched).
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        if uf.union_set(u, v):
            cost += w
            picked += 1
            if picked == n - 1:
                break  # tree is complete; remaining edges would only close cycles
    # A spanning tree on n vertices has exactly n - 1 edges (0 when n == 0).
    if picked != max(n - 1, 0):
        return -1
    return cost
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class UnionFind {
  public:
    // fa[i]: parent of i (i is a root when fa[i] == i).
    // size[r]: number of elements in the set rooted at r (meaningful for roots).
    vector<int> fa, size;

    // Reset to n singleton sets {0}, {1}, ..., {n-1}.
    void init_set(int n) {
        fa.assign(n, 0);
        size.assign(n, 1);
        for (int i = 0; i < n; ++i) fa[i] = i;
    }

    // Return the root of x's set, re-pointing every node on the path at it.
    int find_set(int x) {
        int root = x;
        while (fa[root] != root) root = fa[root];
        while (x != root) {
            int next = fa[x];
            fa[x] = root;
            x = next;
        }
        return root;
    }

    // Merge the sets of x and y; false if they were already joined.
    // The smaller tree goes under the larger (ties keep x's root on top).
    bool union_set(int x, int y) {
        int rx = find_set(x), ry = find_set(y);
        if (rx == ry) return false;
        const int big = size[rx] < size[ry] ? ry : rx;
        const int small = big == rx ? ry : rx;
        fa[small] = big;
        size[big] += size[small];
        return true;
    }
};

int kruskal(int n, vector<tuple<int, int, int>> &edges) {
    // edges: [u, v, w]
    auto uf = UnionFind(); uf.init_set(n);
    int cost = 0, cnt = 0;

    sort(edges.begin(), edges.end(), [] (const auto &a, const auto &b) {
        return get<2>(a) < get<2>(b);
    });

    for (auto &[u, v, w] : edges) if (uf.union_set(u, v)) cost += w, ++cnt;

    if (cnt != n - 1) return -1;
    return cost;
}

Prim 算法#

从顶点 0 出发,每次在未选择顶点集合中选出距离已选择顶点集合最近的顶点,将其加入已选择顶点集合,并用它的邻边更新其余未选择顶点到已选择集合的距离。下面基于邻接矩阵的实现时间复杂度为 \(\mathcal{O}(V^2)\)。

Python
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def prim(adj: List[List[float]]) -> float:
    """Return the weight of a minimum spanning tree via dense O(V^2) Prim.

    Args:
        adj: adjacency matrix; ``adj[u][v]`` is the weight of edge ``u-v``, or
            ``float("inf")`` when there is no edge. Assumed symmetric
            (undirected graph).

    Returns:
        Total MST weight; ``0`` for an empty matrix, and ``float("inf")`` if
        the graph is disconnected (an unreachable vertex contributes ``inf``).
    """
    n = len(adj)
    if n == 0:
        return 0
    vis = [False] * n
    # d[i]: weight of the lightest known edge from the tree to unvisited vertex i.
    d = [float("inf")] * n

    d[0] = cost = 0
    for _ in range(n):
        # Pick the unvisited vertex closest to the current tree.
        v = -1
        for i in range(n):
            if not vis[i] and (v == -1 or d[i] < d[v]):
                v = i

        vis[v] = True
        cost += d[v]

        # Relax the remaining vertices through the newly added vertex v.
        for i in range(n):
            if not vis[i] and adj[v][i] < d[i]:
                d[i] = adj[v][i]
    return cost
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
int prim(int n, vector<vector<int>> &adj) {
    // Dense O(n^2) Prim. adj[j][v] is the weight of edge j-v; presumably absent
    // edges are stored as INT_MAX (NOTE(review): confirm with callers).
    // Returns the MST weight, 0 when n <= 0, or -1 if the graph is disconnected
    // (same convention as kruskal).
    if (n <= 0) return 0;  // avoid writing d[0] into an empty vector

    vector<int> d(n, INT_MAX);  // d[j]: lightest known edge from the tree to j
    d[0] = 0;
    vector<bool> vis(n, false);

    int cost = 0;
    for (int i = 0; i < n; ++i) {
        int v = -1;
        for (int j = 0; j < n; ++j) {
            if (!vis[j] && (v == -1 || d[j] < d[v])) v = j;
        }

        // The closest unvisited vertex is unreachable: the graph is disconnected.
        // Bail out before `cost += INT_MAX` overflows.
        if (d[v] == INT_MAX) return -1;

        vis[v] = true;
        cost += d[v];

        // Relax only vertices still outside the tree.
        for (int j = 0; j < n; ++j) {
            if (!vis[j] && adj[j][v] < d[j]) d[j] = adj[j][v];
        }
    }
    return cost;
}

例题#

返回顶部

在手机上阅读