백준

백준 1761번: 정점들의 거리 (Lowest Common Ancestor)

츄츄츄츄츄츄츄 2023. 1. 9. 23:33

문제링크:1761번: 정점들의 거리 (acmicpc.net)

 

1761번: 정점들의 거리

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩

www.acmicpc.net

이 문제는 정점 N개와 N개를 모두 이어줄 수 있는 가중치가 있는 N-1개의 간선들이 주어 질 때, 두 노드 사이의 거리를 얻는 문제이다. 문제는 정점의 개수인 N이 40,000개까지 주어질 수 있고, 거리를 얻어야 하는 두 노드 케이스가 10,000개까지 주어진다. 최소 공통 조상, LCA 알고리즘과 이분탐색을 연계하여 사용하면 두 노드 사이의 거리를 O(logN)의 시간 복잡도로 구할 수 있다.

 

최소 공통 조상이란 a와 b 노드가 주어졌을 때, a의 조상노드(부모,부모의부모,..) 와 b의 조상노드 중 겹치는 조상 중 가장 적게 대를 올라가는 노드이다. 만약 조상노드 중 같은 조상이 있다면 두 노드는 이어 질 수 있다는 뜻이고, 가장 적게 대를 올라가는 최소 공통 조상은 두 노드를 이어주는 최단 거리에 포함되는 노드 일 것이다.

 

import sys
sys.setrecursionlimit(999999)
N=int(sys.stdin.readline().rstrip())

graph=[[] for _ in range(N+1)]

maxdepth=1
while 2**(maxdepth)<N:
    maxdepth+=1

for _ in range(N-1):
    a,b,c=map(int,sys.stdin.readline().rstrip().split())
    graph[a].append([b,c])
    graph[b].append([a,c])

parent=[[0 for _ in range(maxdepth)] for _ in range(N+1)]
depthlst=[0 for _ in range(N+1)]
visited=[0 for _ in range(N+1)]
dislst=[0 for _ in range(N+1)]
visited[1]=1

def dfs(depth,root,totaldis):
    for i in graph[root]:
        child=i[0]
        dis=i[1]
        if visited[child]==0:
            visited[child]=1
            totaldis+=dis
            depthlst[child]=depth+1
            dislst[child]=totaldis
            parent[child][0]=root
            for k in range(1,maxdepth):
                parent[child][k]=parent[parent[child][k-1]][k-1]
            dfs(depth+1,child,totaldis)
            totaldis-=dis

dfs(0,1,0)
        
def LCA(x,y):
    xdepth=depthlst[x]
    ydepth=depthlst[y]

    if xdepth>ydepth:
        for k in range(maxdepth-1,-1,-1):
            if xdepth-ydepth>=(1<<k):
                xdepth-=(1<<k)
                x=parent[x][k]
                
    if ydepth>xdepth:
        for k in range(maxdepth-1,-1,-1):
            if ydepth-xdepth>=(1<<k):
                ydepth-=(1<<k)
                y=parent[y][k]
                
    if x==y:
        return x

    for k in range(maxdepth-1,-1,-1):
        if parent[x][k]!=0:
            if parent[x][k]!=parent[y][k]:
                x=parent[x][k]
                y=parent[y][k]
           
    return parent[x][0]

M=int(sys.stdin.readline().rstrip())

for _ in range(M):
    a,b=map(int,sys.stdin.readline().rstrip().split())
    print(dislst[a]+dislst[b]-2*dislst[LCA(a,b)])

먼저 우리는 이분 탐색을 이용할 것이기에 최고 깊이를 2**k 로 표현할 수 있는 최소 정수 k를 찾는다. 만약 N개의 노드가 있을때, 모두 일직선으로 이어진다면 최대 depth는 N이므로 2**k가 N보다 크거나 같아질때의 k를 찾으면 된다.

 

모든 입력은 모든 노드들이 연결될 수 있게 주어지므로, 그리고 가중치에는 방향성이 없으므로 임의의 노드 중 아무거나를 root노드로 정한다. 여기서는 1로 정했다. 1부터 DFS를 통해 트리를 만들고, DFS로 탐색해가며 1부터 각 노드까지의 총 거리 totaldis를 dislst에 저장 해 주어 나중에 두 노드 사이의 거리를 구할 수 있도록 한다. 그리고 depthlst에 각 노드의 depth도 저장해 준다. 만약 a, b 사이 거리를 구해야 한다고 칠때, c가 둘의 LCA라면 1부터 a까지 거리와 1부터 b까지 거리를 더해준다. 그리고 1부터 c까지의 거리를 두번 곱해서 빼주면 a와 b사이의 거리가 될 것이다.

 

parent정보도 DP를 활용하여 저장해 준다. 여기서 우리는 이분 탐색을 이용하기로 했기 떄문에 parent[child][k] 는, child 노드의 2**k번쨰 위 부모라는 뜻이다. 바로 위 부모 (root) 는 2**0번째 부모이고, 2**1번째(2up) 부모는 2**0번째(1up) 부모의 2**0번(1up)쨰 부모이고, 2**2(4up)번째 부모는 2**1번째(2up) 부모의 2**1번째(2up)부모로 이렇게 DP를 통해 maxdepth까지 찾아주면 해당 child노드에 대한 이분탐색을 위한 부모 정보가 저장된다. DFS를 통해 조사하기 때문에 child노드의 조상 노드 들의 정보들은 모두 이미 저장된 후이므로 조상 노드 정보가 누락될 일이 없다.

 

다음으로 LCA()함수를 통해 최소 공통 조상을 찾아준다. x와 y의 LCA를 찾는다고 할 때, 이분 탐색을 통해 먼저 x와 y의 depth를 맞춰준다. 깊은 노드를 다른 노드에 맞춰준다. 만약 x가 y보다 depth가 깊다면 비트를 이용해 큰값부터 2**k를 xdepth에 빼었을때  ydepth보다 작아지면 실행하지 않는다. k를 줄여가며 만약 2**(k-1)을 뻈을때는 ydepth보다 크거나 같다면 빼는것을 실행 해주고 x와 xdepth를 그에 맞게 업데이트 해 준다. 그렇게 k를 maxdepth-1부터 0까지 해주면 된다. 이분 탐색보다 제곱의 분할 정복과 비슷한 개념일 수도 있겠다.

 

만약 2**k를 이용하지 않고 1씩 빼주어서 부모노드로 업데이트 하는 방식을 사용한다면, k번 연산할 것을 최악의 경우 2**(k+1)연산해야 한다. 두 노드의 depth를 맞춰 주고 다시 2**k를 뺴가는 방식으로 최소 공통 조상을 찾아낸다. 만약 x와 y의 2**k 번째 위 부모가 같다면 최소가 아닐수도 있는 조상 노드라는 뜻이다. 아까처럼 공통 조상이 아닐 때에만 maxdepth-1부터 0까지 빼어주면, x와 y 둘 다 최소 공통 조상 바로 1 아래에 위치하게 된다. 업데이트 된 x나 y의 바로 위 부모 노드가 최소 공통 조상이 될 것이다.

 

만약 노드의 개수가 이 문제처럼 엄청 많고, 간선내에 사이클이 있다거나 연결되지 않는 노드들이 있다거나 하지 않다면 최소 공통 조상을 사용해 노드 사이의 거리를 구하는 것이 효율적일 수도 있겠다.