백준

백준 13547번: 수열과 쿼리 5 (Mo's)

츄츄츄츄츄츄츄 2023. 2. 22. 00:30

문제링크: 13547번: 수열과 쿼리 5 (acmicpc.net)

 

13547번: 수열과 쿼리 5

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j: Ai, Ai+1, ..., Aj에 존재하는 서로 다른 수의 개수를 출력한다.

www.acmicpc.net

이 문제는 길이가 N인 수열에서 (i, j)가 입력인 쿼리가 주어졌을때 i번째 수부터 j번째 수까지 다른 수의 개수를 출력하는 쿼리를 구현하는 문제이다.

 

이 문제를 풀기 위해서는 mo's 알고리즘에 대해 알고 있어야 한다. mo's 알고리즘은 주어진 쿼리들의 순서를 오프라인 쿼리처럼 특정한 방법으로 배열 한 뒤, 이전의 쿼리의 답을 이용해 가며 다음 쿼리의 답을 구하는 방법이다.

 

여기서 특정한 방법은 쿼리의 i, j가 있을때 쿼리의 순서를 (i//sqrt(N), j)의 우선순위로 더 작은 것을 먼저 처리한다. (i//sqrt(N)) 에서 sqrt(N) 은 sqrt decomposition 에서 나온 것인데, 전체 N개의 수열을 루트N 개의 그룹으로 나누면 각 그룹은 대표값과 루트 N개의 값들을 갖게 된다. 세그먼트 트리처럼 그룹의 대표값을 구하고 싶은 범위 안에 그룹이 있다면 그룹 내의 모든 값들을 탐색하지 않고 그룹의 대표값만 탐색하여 시간을 줄이는 방식이다.

 

쿼리를 정렬하는 방식에서 sqrt decomposition을 사용하는 방법은 구간이 여러개 있을 때, 시작과 끝 범위가 계속해서 달라지므로 시작과 끝을 +=1, -=1 시키며 계속 변경시키며 탐색할텐데 여기서 모든 쿼리를 탐색할 때 변경되는 정도를 최소화하기 위해서이다. sqrt decomposition을 사용해 쿼리를 정렬하고 시작과 끝을 변경시키면서 탐색하면, 시작과 끝을 변경하는 횟수를 N * sqrt(N) 으로 최소화 할 수 있다고 한다.

 

과연 N*sqrt(N) 에 해결할 수 있을까? 생각해 보면

 

start가 한 그룹에서 최대로 움직이는 경우는 sqrt(N) (한 그룹의 개수) 이다. 여기서 end는 최대 N번 움직일 수 있다.

다음 그룹은 start = sqrt(N), end 는 최대 N-sqrt(N),

다음 그룹은 start = sqrt(N), end는 N - 2*sqrt(N), ...

형식으로 움직일 수 있다.

그러면 start는 최대 sqrt(N) * sqrt(N)으로 N번 움직일 수 있고

end는 최대 N * sqrt(N) - (0.5) * N * sqrt(N) ~= 0.5 * N * sqrt(N)  번 움직일 수 있다. 총 시간 복잡도는 N + 0.5 N(sqrt(N)) ~= N * sqrt(N) 이 된다.

import sys

N = int(sys.stdin.readline().rstrip())

nums = list(map(int,sys.stdin.readline().rstrip().split()))

M = int(sys.stdin.readline().rstrip())
sqrt = N**0.5

queries = []

for i in range(M):
    s,e = map(int,sys.stdin.readline().rstrip().split())
    queries.append((s-1,e-1,i))

queries = sorted(queries, key = lambda x : (x[0]//sqrt, x[1]))

ans = [0 for _ in range(M)]

def query(start, end, cache):
    if cache: #캐시를 통해 이전 쿼리의 정보얻기
        p_start, p_end, p_set = cache

        while p_end < end:
            p_end += 1
            if nums[p_end] in p_set.keys():
                p_set[nums[p_end]] += 1
            else:
                p_set[nums[p_end]] = 1
                
        while p_end > end:
            p_set[nums[p_end]] -= 1
            if p_set[nums[p_end]] == 0:
                del p_set[nums[p_end]]
            p_end -= 1

        while p_start > start:
            p_start -= 1
            if nums[p_start] in p_set.keys():
                p_set[nums[p_start]] += 1
            else:
                p_set[nums[p_start]] = 1
                
        while p_start < start:
            p_set[nums[p_start]] -= 1
            if p_set[nums[p_start]] == 0:
                del p_set[nums[p_start]]
            p_start += 1

        cache = (p_start, p_end, p_set)

        return len(p_set), cache
            
    else: #처음 쿼리
        p_set = {} #dict로 저장
        for i in range(start,end+1):
            if nums[i] in p_set.keys():
                p_set[nums[i]] +=1
            else:
                p_set[nums[i]] = 1
                
        cache = (start, end, p_set)
        
        return len(p_set), cache

cache = 0
for x in queries:
    s, e, idx = x
    cnt, cache = query(s, e, cache)
    ans[idx] = cnt

for x in ans:
    print(x)

나는 i부터 j까지의 서로 다른 값들을 p_set{} 의 딕셔너리 형태로 저장하는 방법을 사용했다. dict에 서로 다른 수를 key로 저장하고, 중복되는 수가 있다면 그 수의 갯수를 값으로 저장한다. 추가할때에는 해당 키가 있으면 +=1, 없으면 값을 1로 할당하여 딕셔너리에 추가해준다. 만약 해당 범위에서 수가 사라져 삭제하면 -=1 해주고, 값이 0 이면 그 수가 없다는것이므로 해당 키를 삭제해 주었다. 이렇게 하면 p_set의 길이를 구하면 서로 다른 수의 개수를 얻을 수 있다.

 

먼저 쿼리를 처음 수행할때에는 이전의 쿼리가 없으므로 naive하게 s 부터 e 까지 탐색해 가며 p_set을 구한다. 이전의 (start, end, p_set)를 캐시로 저장하여 다음 쿼리에서 사용할 수 있도록 넘겨준다. 다음 함수는 cache를 통해 이전 쿼리의 시작, 끝, 딕셔너리 정보를 p_start, p_end, p_set으로 얻을 수 있다. 이를 바탕으로 p_start 는 start와 일치할 때까지, p_end는 end와 일치할 때까지 변경시켜 가며 p_set을 수정한다.

 

p_end를 먼저 변경해 주는 이유는 우리가 쿼리를 정렬할 때 오름차순으로 정리하였기 때문에 (1,1) (3,3) 등의 쿼리가 있을 때 만약 p_start를 먼저 변경하면 (1,1) -> (2,1) -> (3,1) 이렇게 start보다 end가 작은 상황이 발생한다. 이런 현상을 막기 위해 p_end부터 변경해준다.