알고리즘이론

MST를 응용해보자

배우겠습니다 2023. 3. 29. 15:18

서론

크루스칼알고리즘이나 프림알고리즘은 학교알고리즘수업이나 다른 개인의 자료에서도 필수적으로 소개되는 알고리즘이다. (나름 관련 문제가 많은 union find를 그것을 구현하기 위해 처음 배울 정도니까!)

그것을 응용한 문제를 풀어보자.

 

알아야 할 것

트리의 성질, 크루스칼 알고리즘, LCA (sparse table)가 무엇인지 알아야하고 관련 문제를 풀어봤어야한다.

 

MST (Minimum Spanning Tree)

Tree란 특수한 모양의 그래프이다. 그래프가 한 컴포넌트로 이어져있고 간선의 수가 N-1개라면 사이클이 없는 트리란 그래프가 완성된다.

https://csacademy.com/app/graph_editor/로 수제제작한 트리이다.

1. 트리에서 자신위에 있는 노드를 부모노드라 하고 밑에 있는 노드들을 자식노드라 한다.

2번의 부모는 1번이고 2번의 자식은 4,5번이다.

2. 부모노드가 없는 노드는 루트라고 한다. 여기서는 1번이다. 루트는 자신만의 기준이다.(탐색을 시작할 노드라고 생각하면 된다.) 트리를 잘 돌리고 잡아 당기면 아무 노드가 루트가 될 수 있다.

3. 루트를 제외하고 자신의 부모는 유일하다. (즉 트리는 Tree[Node] = Parent꼴의 벡터로 표현 가능하다.)

Tree = [1,1,1,1,2,2,3,6,6,] (0,1번인덱스의 값은 그냥 루트 노드를 줬다.)

4. 같은 노드를 중복해서 방문하지않는 경로를 생각했을 때 트리안의 두 노드  A, B사이의 경로는 유일하다. 

 

Spanning Tree란 그래프 탐색의 결과이다.

for node in adjlist[start]:
    if not visit[node]:
        visit[node] = True 
        nextsearch(node)

우리들이 탐색할때 다음과 같은 코드가 들어갈 텐데(DFS든 BFS든)

여기서 start는 부모,  방문하지않은 node는 자석이라 생각한다면

그래프탐색의 결과로 트리가 만들어진다.

 

MST

MST란 트리인데, 무방향 그래프에서 만들어질 수 있는 Spanning Tree 중 가장 가중치합이 적은 트리이다.

이는 프림과 크루스칼알고리즘을 이용해 구한다. 난 둘 중 하나만 알면된다고 생각하고 크루스칼알고리즘로만 풀었다.

 크루스칼 알고리즘은 다음과 같다.

1. 간선정보(노드1,노드2,가중치)를 받는 배열을 만들고 그것을 가중치를 기준으로 정렬한다.

2. 그것을 첫번째부터 마지막까지 순회하며 다음과 같은 과정을 거친다.

노드1과 노드2가 같은 집합에 속했다면{
넘어간다;
}
다른 집합이라면{
이 간선은 MST에 포함된 간선이다;
}

3. MST완성

여기서 봐야할건 같은 집합으로 만들어주고 같은 집합인지 판별할 때 쓰는 unionfind 알고리즘이다.

unionfind의 find함수는 경로를 압축시켜서 시간을 절약하는데 이는 특정노드 x에 대해 parent[x]의 값은 x가 포함된 (sub)tree의 루트라고 생각하면 된다.

즉 parent의 값이 같다면 같은 트리에 속한다는 것이다.

노드1과 노드2를 이어주는 간선E에 대해 노드1과 노드2가 같은 트리에 속한다면 MST에 포함안시키는 이유는

노드1,노드2가 속한 트리에서 노드1과 노드2를 잇는 경로는 유일해야 하고(트리의 성질) 해당 경로에 가중치는 모두 E의 가중치이하라는 것이 보장되기 때문이다. (가중치가 작은 순으로 정렬시키고 앞쪽부터 트리를 만들어줬으니까)

 

문제추천

이중 MST와 LCA유형을 혼합한 두 문제를 추천해보고자 한다.

 

백준 15481

https://www.acmicpc.net/problem/15481

 

15481번: 그래프와 MST

첫째 줄에 정점의 개수 N과 간선의 개수 M (2 ≤ N ≤ 200,000, N-1 ≤ M ≤ 200,000) 둘째 줄부터 M개의 줄에 간선의 정보 u, v, w가 주어진다. u와 v를 연결하는 간선의 가중치가 w라는 뜻이다. (1 ≤ u, v ≤

www.acmicpc.net

 우선 크루스칼로 MST를 구한다. MST에 포함되는 간선은 원래 MST의 가중치가 답임을 알 수 있다.

포함되지 않은 간선에 대해선 어떨까? 

암튼 억지로 포함시켜서 spanning tree를 만든다면 가중치는 MST이상인 것은 알겠으나, 어떤 간선을 삭제해야 그중에서도 최소일까를 고민해봐야 한다.

위에서 find(노드1) == find(노드2)라면 왜 MST에 포함 안 시키는지를 고민해 봤다.

경로가 유일해지지 않기 때문인데, 억지로 포함시키려면 경로에 속한 간선을 하나 삭제해야 한다.

경로에 속한 간선 중에서도 가중치가 최대인 간선을 삭제시켜야 최소가 된다.

더보기

왜냐하면 새로운 MST의 가중치를 new, 원래 MST의 가중치를 origin이라 하면

new = origin + w(포함시킬 간선 가중치) - e(삭제할 간선 가중치) 이기 때문이다.

즉 이 문제는

트리가 주어졌을 때, 쿼리는 다음과 같이 들어온다.

쿼리의 정보는  노드 1, 노드 2이다.

쿼리를 질의받으면 우리는 노드 1과 노드 2 경로에 있는 간선 중 가중치가 최대인 간선을 구해야 한다.

라는 익숙한 문제의 테크닉을 가져다 쓰면 된다.

우리들은 이 문제를 LCA를 배울 때 한번 쯤 풀어봤다.(solve ac class에도 동일한 문제가 있는 것으로 안다.)

하지만, 15841의 input범위를 봤을 때 가장 기초적인 LCA풀이로는 시간초과가 뜰 것을 알 수 있다.

따라서 sparse table을 이용해 풀어야 한다.

 

백준 1626

https://www.acmicpc.net/problem/1626

 

1626번: 두 번째로 작은 스패닝 트리

첫째 줄에 그래프의 정점의 수 V(1 ≤ V ≤ 50,000)와 간선의 수 E(1 ≤ E ≤ 200,000)가 들어온다. 둘째 줄부터 E+1번째 줄까지 한 간선으로 연결된 두 정점과 그 간선의 가중치가 주어진다. 가중치는 100,

www.acmicpc.net

문제는 G의 최소 스패닝 트리보다는 크면서 가장 작은 스패닝 트리인 'The second minimum spanning tree'를 구하는 것이다.

참고로 이 문제는 입력받는 그래프가 하나의 컴포넌트가 아닌 (MST를 못 만드는) 테스트케이스가 들어온다. 예외처리하자.

언뜻 보면 위의 문제와 풀이가 완벽히 일치하는 문제 같으나,

사용하는 간선이 달라도 MST의 가중치가 달라야한다.

경로에 속한 간선 중 최대가중치를 X라하자.

X는 포함시키는 간선의 가중치 이하이다. 따라서 두 가중치가 같을 수 있다 아뿔소!

이는 단순하게 풀 수 없다.

우선 우리는 2번째로 큰 최대값 X는 포함시키는 간선의 가중치보다 항상 작음을 알 수 있다.

즉 우리가 해야할 것은

더보기

1. 2번째로 큰 최대값을 담은 sparse table도 생성

2. LCA를 구할때 간선의 최대값은 포함시켜야하는 간선의 최대값보다 작은 경우에만 갱신

1번이 아이디어가 필요해보인다.

원래 간선최대값을 담은 sparse table을 table1이라하자.

2번째 최댓값을 담은 sparse table을 table2라하자.

par은 node의 2^i번째 위에있는 노드이다.

table2[node][i] < table1[node][i] 여야한다.

table1[node][i] = max(table1[node][i-1],table1[par][i-1])를 구하는 것과 비슷하게,

table2[node][i] = max(table1[node][i-1],table1[par][i-1],table2[node][i-1],table2[par][i-1] 단 table1[node][i]이상이라면 0으로 대체한다.) 이다.

왜냐하면

(table2[node][i-1]가 제대로 된 값(2번째 최대값)을 담고있다고 가정)

1. table1[node][i-1]과 table1[par][i-1]이 모두 table1[node][i]와 같다.

그렇다면 table2[node][i-1],table2[par][i-1]중 큰 것을 취해야한다.

2. table1[node][i-1]만 table1[node][i]와 같다.

그럼 table2[node][i-1]과 table1[par][i-1]중 큰 것을 취해야 한다.

3. table1[par][i-1]만 table1[node][i]와 같다.

그럼 table2[par][i-1]과 table1[node][i-1]중 큰 것을 취해야 한다.

 

코드 

 

백준 15841 코드

더보기
import sys
from collections import deque

input = sys.stdin.readline

N, M = map(int, input().split())
edges = []
for i in range(M):
    a, b, w = map(int, input().split())
    edges.append((w, a, b, i))
edges.sort()
MST = [False] * M  # 초기 크루스칼결과 mst에 포함
parent = [i for i in range(N + 1)]  # for union find
cost = 0  # mst 가중치
ans = [0] * M  # 정답 답는 배열
tree = [[] for _ in range(N + 1)]  # tree


def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]


def union(x, y):
    x = find(x)
    y = find(y)
    if x < y:
        parent[y] = x
    else:
        parent[x] = y


# 크루스칼
for w, a, b, idx in edges:
    if find(a) != find(b):
        union(a, b)
        # 크루스칼
        cost += w
        MST[idx] = True
        # 트리에 간선 추가
        tree[a].append((b, w))
        tree[b].append((a, w))
# mst에 포함된 것들 처리
for w, a, b, idx in edges:
    if MST[idx]:
        ans[idx] = cost
# 희소배열 초기화
CAP = 20
ptable = [[1] * (N + 1) for _ in range(CAP)]  # 부모희소배열
wtable = [[0] * (N + 1) for _ in range(CAP)]  # 간선 희소배열
q = deque([1])
depth = [-1] * (N + 1)  # 루트로부터 거리
depth[1] = 0
while q:
    node = q.popleft()
    for n, w in tree[node]:
        if depth[n] == -1:
            q.append(n)
            depth[n] = depth[node] + 1
            ptable[0][n] = node
            wtable[0][n] = w
for i in range(1, CAP):
    for n in range(1, N + 1):
        p = ptable[i - 1][n]
        ptable[i][n] = ptable[i - 1][p]
        wtable[i][n] = max(wtable[i - 1][n], wtable[i - 1][p])


def lca(a, b):
    if depth[a] > depth[b]:
        a, b = b, a
    d = depth[b] - depth[a]
    tmp = 0
    for j in range(CAP):
        if d & (1 << j):
            tmp = max(tmp, wtable[j][b])
            b = ptable[j][b]
    if a == b:
        return tmp
    for j in range(CAP - 1, -1, -1):
        if ptable[j][a] != ptable[j][b]:
            tmp = max((tmp, wtable[j][a], wtable[j][b]))
            a = ptable[j][a]
            b = ptable[j][b]
    tmp = max((tmp, wtable[0][a], wtable[0][b]))
    return tmp


# 쿼리 처리
for w, a, b, idx in edges:
    if MST[idx]:  # 이미 처리한건 넘김
        continue
    val = lca(a, b)
    ans[idx] = cost - val + w  # 기존mst - 경로에서제일큰 간선 + 새로운간선
for a in ans:
    print(a)

 

백준 1626 코드

더보기
import sys
from collections import deque

input = sys.stdin.readline

V, E = map(int, input().split())
edges = []
for i in range(E):
    a, b, w = map(int, input().split())
    edges.append((w, a, b))
edges.sort()
MST = [False] * E  # 초기 크루스칼결과 mst에 포함
parent = [i for i in range(V + 1)]  # for union find
cost = 0  # mst 가중치
ans = float('inf')
tree = [[] for _ in range(V + 1)]  # tree


def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]


def union(x, y):
    x = find(x)
    y = find(y)
    if x < y:
        parent[y] = x
    else:
        parent[x] = y


# 크루스칼
for i in range(E):
    w, a, b = edges[i]
    if find(a) != find(b):
        union(a, b)
        cost += w
        MST[i] = True
        # 트리에 간선 추가
        tree[a].append((b, w))
        tree[b].append((a, w))
# MST를 못이뤄ㅛ을때
if MST.count(True) != V - 1:
    print(-1)
    sys.exit()
# 트리 구조화 및 희소배열 초기화
CAP = 20
ptable = [[1] * (V + 1) for _ in range(CAP)]  # 부모희소배열
wtable = [[0] * (V + 1) for _ in range(CAP)]  # 간선 희소배열
q = deque([1])
depth = [-1] * (V + 1)  # 루트로부터 거리
depth[1] = 0
while q:
    node = q.popleft()
    for n, w in tree[node]:
        if depth[n] == -1:
            q.append(n)
            depth[n] = depth[node] + 1
            ptable[0][n] = node
            wtable[0][n] = w
# 희소배열 채우기
for i in range(1, CAP):
    for n in range(1, V + 1):
        p = ptable[i - 1][n]
        ptable[i][n] = ptable[i - 1][p]
        wtable[i][n] = max(wtable[i - 1][n], wtable[i - 1][p])
stable = [[-float('inf')] * (V + 1) for _ in range(CAP)]  # 2번째로 작은 간선
for i in range(1, CAP):
    for n in range(1, V + 1):
        value = wtable[i][n]  # 현재 최대값
        tmp = -float('inf')
        p = ptable[i - 1][n]
        if wtable[i - 1][n] < value:
            tmp = wtable[i - 1][n]
        if wtable[i - 1][p] < value:
            tmp = max(tmp, wtable[i - 1][p])
        if stable[i - 1][n] < value:
            tmp = max(tmp, stable[i - 1][p])
        if stable[i - 1][p] < value:
            tmp = max(tmp, stable[i - 1][n])
        stable[i][n] = tmp


def lca(a, b, table, bias):
    if depth[a] > depth[b]:
        a, b = b, a
    d = depth[b] - depth[a]
    tmp = -float('inf')
    for j in range(CAP):
        if d & (1 << j):
            if table[j][b] < bias:
                tmp = max(tmp, table[j][b])
            b = ptable[j][b]
    if a == b:
        return tmp
    for j in range(CAP - 1, -1, -1):
        if ptable[j][a] != ptable[j][b]:
            if table[j][a] < bias:
                tmp = max(tmp, table[j][a])
            if table[j][b] < bias:
                tmp = max(tmp, table[j][b])
            a = ptable[j][a]
            b = ptable[j][b]
    if table[0][a] < bias:
        tmp = max(tmp, table[0][a])
    if table[0][b] < bias:
        tmp = max(tmp, table[0][b])
    return tmp


ans = float('inf')
for i in range(E):
    if MST[i]:
        continue
    w, a, b = edges[i]
    ans = min((ans, cost - lca(a, b, wtable, w) + w, cost - lca(a, b, stable, w) + w))
print(ans if ans < float('inf') else -1)