알고리즘이론

LCA (최소공통조상)과 sparse table (희소배열) python

배우겠습니다 2023. 3. 29. 20:05

LCA 

최소공통조상은 트리에서 서로다른 두 노드 A,B에 대해

두 노드가 계속 루트로 가는 경로로 이동했을때, 처음으로 만날 수 있는 깊이가 가장 깊은 노드이다.

 

가장 단순하게 구하는방법

1. 루트에서 탐색한다. BFS로 각 노드마다 깊이를 저장한다. (깊이란 부모쪽으로 x번 거슬러 올라가야 루트가 나올때 x이다.)

2. 두 노드 A,B를 받는다. 

3. 만약 두 노드의 깊이가 다르다면, 깊이가 깊은쪽에서 거슬러 올라가 얕은쪽과 같은 깊이가 되게 맞춘다.

4. 두 노드가 같아질때 까지 하나씩 거슬러 올라간다.

복잡도는 O(노드수)이다.

하지만 O(N)보다 효율적인 알고리즘이 존재한다.

 

O(log(노드수))로 구하는 방법: sparse table

(루트에선 몇번이라도 거슬러올라가봤자 루트임을 가정한다.)

만약 특정노드가 2^0,2^1,2^2,....2^k번 거슬러올라갔을 때 모슨 노드인가 저장하는 것이 sparse table이다.

왜 2의 거듭제곱꼴을 사용하나면

이진수에서 알듯이, 자연수 = sum(2^(비트가 1인 자리의 자릿수 ))이다.

예를들어서

9(1001) = 2^3 + 2^0 (3번째와 0번째비트가 1이다.

7(111) = 2^2 + 2^1 + 2^0 (0번째,1번째,2번째비트가 1이다.)

.....

따라서

특정노드에서 몇번을 이동하든 sparse table의 정보로 이동한 결과를 빠르게 찾을 수 있단 것이다.

대충 이런그래프가 있다하고 8번 노드에서 2^k번 이동했을때를 생각해보자.

1번에선 몇번이동해도 1이라고 가정하자.

1. 2^0번 이동했을때: 7

2. 2^1번 이동했을때: 6 (7번노드가 2^0번이동)

3. 2^2번 이동했을때: 4 (6번노드가 2^1번 이동)

4. 2^3번 이동했을때: 1 (4번노드가 2^2번 이동)

노드 A에서 2^k번 이동했을때는

(A에서 2^(k-1)번 이동한 결과) 가 2^(k-1)번 이동한 결과라고 생각하면 된다.

이것이 sparse table을 구현하는 것의 전부이다.

# ROOT:트리의 루트 N:노드의 개수
capacity = 22  # 넉넉하게 잡자. 어차피 오버돼도 답은 안변한다.
sparse_table = [[ROOT] * 20 for _ in range(N + 1)]
# 1. sparse_table[node][0]는 자신의 부모로 초기화한다.
for node in range(1, N + 1):
    sparse_table[node][0] = parent[node]
# 2. sparse_table을 채워준다.
for cap in range(1, capacity):
    for node in range(1, N + 1):
        par = sparse_table[node][cap - 1]  # node에서 2^(cap-1) 거슬러 올라갔을 때
        parofpar = sparse_table[par][cap - 1]  # par에서 2^(cap-1) 거슬러 올라갔을 때
        sparse_table[node][cap] = parofpar

 

그래서 이걸로 뭘할 것인가

sparse table로 LCA를 구하는 법은 O(N)으로 구하는 법과 유사하다.

1. 깊이를 서로 맞춰준다. 만약 그 결과로 LCA를 찾을 수 있다면 다음과정은 진행하지 않는다.

2. cap을 capatity - 1부터 0까지 순회하면서 sparse_table[node1][cap]과 sparse_table[node2][cap]이 다르다면

node1 = sparse_table[node1][cap], node2 = sparse_table[node2][cap]으로 갱신시켜준다. 둘의 값이 같다면 투머치하게 거슬러 올라간 것이기 때문이다.

3. sparse_table[node1][0]가 LCA의 결과이다.

def lca(n1, n2):
    if depth[n1] > depth[n2]:
        n1, n2 = n2, n1  # 편의를 위해 n1이 더 얕다고 가정
    diff = depth[n2] - depth[n1]
    # diff를 2의 거듭제곱수들의 합으로 나타낸다.
    for cap in range(capacity):
        # 만약 diff를 합으로 나태낸 것에 2^cap가 포함된다면
        if diff & (1 << i):
            n2 = sparse_table[n2][cap]
    if n1 == n2: # 이미 찾았다면 생략 
        return n1
    for cap in range(capacity - 1, -1, -1):
        if sparse_table[n1][cap] != sparse_table[n2][cap]:
            n1 = sparse_table[n1][cap]
            n2 = sparse_table[n2][cap]
    return sparse_table[n1][0]  # == sparse_table[n2][0]

왜 3번과정이 필요하냐면, 우리는 sparse_table의 값이 달라질때

두 노드를 갱신하는데, 그러다보면 최종적으로 n1과 n2가 LCA의 바로 밑 자식노드로 갱신될 것이기 때문이다.

그래서 3번과정으로 최종적인 답을 리턴해준다.

 

더 나아가서

sparse_table로 트리가중치의 최대 또는 최소 또는 합 역시 쉽게 구할 수 있다.

가중치를 담는 희소배열을 새로 만들어주자.

par = sparse_table[node][cap-1]은 node->par로 거슬러 올라간것이고

parofpar = sparse_table[par][cap-1]은 par->parofpar로 거슬러올라간 것이다.

결국 sparse_table[node][cap]은 node->parofpar로 거슬러올라간 것이 된다.

이 힌트로 직접 구현해보자.