최소 스패닝 트리 (Minimum Spaning Tree)

스패닝 트리란 그래프에서 일부 간선을 제거하고(= 일부 간선만 선택하여) 모든 정점이 연결되도록 만든 트리를 말한다. 즉, 모든 정점이 직접적으로 연결되지 않아도 다른 정점을 거쳐 가면 다른 모든 정점에 도달할 수 있는 구조이다.

최소 스패닝 트리는 어떤 그래프의 스패닝 트리 중 간선의 가중치 합이 가장 작은 스패닝 트리를 말한다.

최소 스패닝 트리를 만드는 알고리즘

최소 스패닝 트리를 만드는 알고리즘은 프림 알고리즘과 크루스칼 알고리즘이 있다. 이 글에서는 해당 알고리즘들을 BOJ 1197번 문제를 이용하여 최소 스패닝 트리를 만드는 알고리즘을 구현한다.

프림 알고리즘 (Prim’s Algorithm)

프림 알고리즘은 아래와 같이 동작한다.

  1. 임의의 시작 정점을 스패닝 트리에 포함시킨다.
  2. 스패닝 트리에 포함된 모든 정점에서 포함되지 않은 정점으로 연결된 간선 중 가장 가중치가 작은 간선을 스패닝 트리에 포함시킨다.
  3. 모든 정점이 스패닝 트리에 포함 될 때까지 2를 반복한다.

프림 알고리즘이 어떻게 동작하는지 아래 그래프를 예제로 들어 알아보자. 먼저 임의의 정점을 스패닝 트리에 포함시킨다. 아무 정점이나 선택해도 상관이 없으므로 이 예제 에서는 1번 정점을 먼저 포함시킨다.

스패닝 트리에 포함된 정점과(현재는 1번 정점밖에 없다) 연결된 간선들 (1, 2)와 (1, 3) 중 가중치가 가장 작은 간선 (1, 3)을 스패닝 트리에 포함시키고 3번 정점을 스패닝 트리에 포함시킨다.

스패닝 트리에 포함된 정점과 연결된 간선들 (1, 2)와 (3, 2), (3, 4) 중 가중치가 가장 작은 간선은 (3, 2)이다. 이 간선을 선택하고 2번 정점을 스패닝 트리에 포함시킨다.

스패닝 트리에 포함된 정점과 연결된 간선들 (1, 2)와 (3, 4), (2, 4) 중 가중치가 가장 작은 간선은 (2, 4)이다. 이 간선을 선택하고 4번 정점을 스패닝 트리에 포함 시킨다.

스패닝 트리에 포함된 정점과 연결된 간선들 (1, 2)와 (3, 4), (4, 5) 중 가중치가 가장 작은 간선은 (3, 4)이지만 3번 정점과 4번 정점 모두 스패닝 트리에 포함되어 있으므로 이 간선은 무시한다. 간선 (1, 2) 역시 같은 이유로 무시한다. 이제 마지막으로 남은 간선 (4, 5)이 남았는데 5번 정점은 스패닝 트리에 포함되어 있지 않았으므로 이 간선을 선택하고 5번 정점을 스패닝 트리에 포함시킨다.

이제 더 이상 확인할 간선이 없다.

최소 스패닝 트리가 완성 되었다.

위 예제를 통하여 프림 알고리즘은 한 정점에서 시작해 계속 트리 형태를 유지해가며 간선을 선택, 확장하는 방식이라는 것을 알수 있다. 스패닝 트리를 구성하는 과정에서 가장 가중치가 작은 간선을 찾는 것이 큰 문제인데, 이 부분은 힙을 사용하면 쉽게 처리할 수 있다.

#include <cstdio>
#include <queue>
#include <utility>
#include <functional>
#define P pair<int, int >
using namespace std;

int prim(vector<P> edges[]) {
	int ans = 0;
	// selected[u] : 정점 u가 스패닝 트리에 포함되어 있는지 여부를 나타낸다
	bool selected[10001] = { false, };
	//간선의 비용이 낮은 순으로 처리해야 하므로 최소 힙을 선언한다
	priority_queue<P, vector<P>, greater<P>> pq;
	//임의의 정점에서 시작한다. 이 함수는 항상 1번 정점에서 시작한다.
	pq.push(P(0, 1));

	while (!pq.empty()) {
		P cur = pq.top(); pq.pop();
		//간선으로 연결된 정점이 이미 스패닝 트리에 포함되어 있다면 스킵
		if (selected[cur.second]) continue;
		//포함되어 있지 않다면 MST에 포함시키고 간선 가중치의 누적합을 계산한다
		selected[cur.second] = true;
		ans += cur.first;
		//새로 포함된 정점과 연결되어 있는 모든 간선을 힙에 삽입한다.
		for (int i = 0; i < edges[cur.second].size(); i++) {
			//해당 간선과 연결된 정점이 이미 스패닝 트리에 포함되어 있다면 무시한다
			if (selected[edges[cur.second][i].second]) continue;
			pq.push(edges[cur.second][i]);
		}
	}
	//모든 간선을 탐색했다면 최소 스패닝 트리 간선 가중치의 합을 반환한다
	return ans;
}

int main() {
	int V, E;
	//정점의 수 V와 간선의 수 E를 입력받는다
	scanf("%d%d", &V, &E);
	//간선 정보를 저장할 배열을 선언한다.
	//edges[u] = (c, v) : 정점에 u에서 v로 가는데 c 비용이 드는 간선
	vector<P> edges[10001];
	for (int i = 0; i < E; i++) {
		int u, v, c;
		scanf("%d%d%d", &u, &v, &c);
		edges[u].push_back(P(c, v));
		edges[v].push_back(P(c, u));
	}
	//최소 스패닝 트리 간선들의 가중치 합을 출력한다
	printf("%d", prim(edges));
}

크루스칼 알고리즘(Kruskal’s Algorthim)

크루스칼 알고리즘은 아래와 같이 동작한다.

  1. 모든 간선을 가중치를 기준으로 오름차순 정렬한다.
  2. 정렬된 간선들을 순회하며 간선 e(u, v)의 정점 u, v가 서로 같은 집합에 속해있는지 확인한다.
  3. 만약 서로 다른 집합에 속해있다면 두 집합을 합친다.

크루스칼 알고리즘이 어떻게 동작하는지 아래 그래프를 예로 들어 알아보자. 아래 그래프에서 간선의 가중치가 가장 작은 간선 (2, 4)에서 2번 정점과 4번 정점은 서로 같은 집합에 있지 않으므로 두 정점을 같은 집합에 포함시킨다.

그 다음으로 가중치가 가장 작은 간선은 (2, 3)이다. 두 정점 역시 서로 다른 집합에 있으므로 두 집합을 합쳐 같은 집합으로 만든다.

남은 간선 중 가중치가 가장 작은 간선은 (3, 4)이지만 현재 3번 정점과 4번 정점은 같은 집합에 속해있으므로(=연결하면 사이클이 생기게 되므로) 이 간선은 무시한다. 그 다음으로 가중치가 작은 간선은 (3, 1)이다. 1번 정점과 3번 정점은 서로 다른 집합이므로 두 집합을 합친다.

남은 간선 중 가중치가 가장 작은 간선은 (1, 2)이다. 하지만 1번 정점과 2번 정점은 같은 집합에 속해있으므로 이 간선은 무시한다. 그 다음으로 가중치가 작은 간선은 (4, 5)이다. 4번 정점과 5번 정점은 서로 다른 집합에 속해있으므로 두 집합을 합친다.

더 이상 확인할 간선이 없다. 최소 스패닝 트리가 완성 되었다.

위 과정을 보면 크루스칼 알고리즘은 트리에 포함되지 않은 전체 간선 중 가중치가 가장 낮은 간선부터 선택해 나가므로 탐욕적 알고리즘이다. 프림 알고리즘과는 다르게 스패닝 트리를 구성하는 과정에서 트리가 산발적으로 생성된다는 것이 특징이다.

이 알고리즘을 구현하는데 두 정점이 같은 집합에 속해있는지를 검사하는 부분. 즉, 두 간선을 연결 했을 때 사이클이 생기는지에 대한 검사를 해야하는데 이 부분은 Union-Find(Disjoint Set[상호 배타적 집합])를 사용해서 빠른 시간에 처리할 수 있다.

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

typedef struct edge{
	int u, v, c;
	edge(int u, int v, int c) : u(u), v(v), c(c) {};
} edge;

int group[10001];

//Find(x) : 정점 x의 집합 번호를 반환한다
int Find(int x) {
	if (group[x] == x) return x;
	return group[x] = Find(group[x]);
}

//Union(x, y) : 정점 y의 집합을 정점 x의 집합에 포함시킨다.
void Union(int x, int y) {
	x = Find(x);
	y = Find(y);
	group[y] = x;
}

int compare_edge(const edge &a, const edge &b) {
	return a.c < b.c;
}

int kruskal(int V, vector<edge> edges) {
	int ans = 0;
	//집합을 초기화 한다.
	for (int i = 1; i <= V; i++)
		group[i] = i;

	//간선을 정렬한다
	sort(edges.begin(), edges.end(), compare_edge);

	for (int i = 0; i < edges.size(); i++) {
		//간선의 출발 정점과 도착 정점을 연결했을때
		//사이클이 생긴다면(= 같은 집합이라면) 이 간선은 무시한다.
		if (Find(edges[i].u) == Find(edges[i].v)) continue;
		//사이클이 생기지 않는다면(= 다른 집합이라면) 두 정점의 집합을 하나로 합친다.
		Union(edges[i].u, edges[i].v);

		//스패닝 트리를 구성하는 간선들의 가중치 누적합을 계산한다.
		ans += edges[i].c;
	}
	return ans;
}

int main() {
	int V, E;
	//간선 정보를 저장할 배열을 선언한다.
	vector<edge> edges;
	//정점의 수 V와 간선의 수 E를 입력받는다
	scanf("%d%d", &V, &E);

	for (int i = 0; i < E; i++) {
		//간선의 정보를 입력받고 배열에 삽입한다.
		int u, v, c;
		scanf("%d%d%d", &u, &v, &c);
		edges.push_back(edge(u, v, c));
		edges.push_back(edge(v, u, c));
	}
	//최소 스패닝 트리 간선들의 가중치 합을 출력한다
	printf("%d",  kruskal(V, edges));
}