본문 바로가기

알고리즘

백준 - 1717 집합의 표현(Java)

https://www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는

www.acmicpc.net

 

유니온-파인드(Uion-Find) 문제입니다.

0~n까지의 집합이 주어지고, m번 만큼 특정 두 개의 원소가 같은 집합에 속하는지 알아내는 문제입니다.

 

문제를 해결하기 위해서는 길이가 n인 배열을 사용해야 합니다.

각 배열의 값은 자신의 부모 노드를 가리키게 해야 합니다.

물론 아무런 연결이 없는 초기에는 자기 자신을 가리키게 하면 됩니다.

 

다음 그림과 같은 연결 그래프가 있다고 하겠습니다.

각 인덱스의 값을 자기 자신의 값으로 바꿔 자기자신이 곧 자기의 부모 노드가 되도록 합니다.

이제 각 인덱스의 값을 자기 자신의 부모의 값으로 바꿉니다.

2의 부모는 1이기 때문에, arr[2] = 1입니다.

3의 부모는 2이기 때문에, arr[3] = 2입니다.

4의 부모는 0이기 때문에, arr[4] = 0입니다.

이제 재귀를 통해 부모를 찾아가면 됩니다.

재귀를 시작할 때의 값이 자기 자신이 되도록 짜면 됩니다.

private static int find(int a) {
    if(a == arr[a]){
        return a;
    }
    return find(arr[a]);
}

값을 넣어보면 결국 find(3)의 값과 find(2)의 값은 1로, 2와 3은 루트가 1인 집합에 속한다고 할 수 있습니다.

즉, 두 개를 find를 했을 때 같은 값이 나오면 같은 집합에 속하는 것입니다.

이것이 바로 파인드(Find)연산 입니다.

 

그러면 어떻게 그림과 같이 같은 집합을 만들까요?

이제 유니온(Union)연산이 필요합니다.

그림과 같이 서로 다른 두 개의 집합을 하나로 만드는 방법은 간단합니다.

두 개의 루트 노드를 찾아 한 쪽에 연결하면 됩니다.

즉, 다음과 같은 식이 만들어집니다.

private static void union(int a, int b) {
    int x1 = find(a);
    int x2 = find(b);

    arr[x1] = x2;
}

루트 노드를 찾기위해 union 연산에도 find 연산이 필요합니다.(이 이유가 파인드 연산을 먼저 설명한 이유입니다.)

예를 들어 3, 4를 같은 집합으로 만들기 위해서는 각각의 루트인 1과 0을 찾아 한쪽에 연결하며 됩니다.

0쪽에 1을 연결한다고 했을 때, 1의 부모는 0이되기 때문에 나중에 모든 노드의 find연산의 값은 0이 됩니다.

 

하지만, 1717번 문제는 위 두 개의 유니온-파인드 함수로는 시간초과가 발생합니다.

그 이유 역시 그림으로 설명하도록 하겠습니다.

위 그림과 같이 find(4)의 네 번의 재귀가 필요합니다.

즉, 그림과 같은 조건일 때 시간복잡도는 O(N)입니다.

 

1717번은 탐색 시간 N과 연산의 개수 M을 곱하면 O(NM)입니다.

이는, 문제 조건에 의해 1,000,000*100,000 = 100,000,000,000으로 시간초과가 발생합니다.

 

이를 어떻게 줄일 수 있을 까요?

파인드 연산에 약간의 변화면 주면 됩니다.

private static int find(int a) {
    if(a == arr[a]){
        return a;
    }
    return arr[a] = find(arr[a]);
}

코드를 위과 수정하게 된다면 한 집합의 모든 노드는 루트 노드만을 가리키게 되어 연산의 속도가 O(1)이 되어 

기존의 시간복잡도 O(NM)이 O(M)으로 줄어 시간초과가 발생하지 않게 됩니다.

 

[구현 코드]

import java.io.*;
import java.util.StringTokenizer;

public class Main {
    private static int n, m;
    private static int[] arr;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer stz = new StringTokenizer(br.readLine(), " ");

        n = Integer.parseInt(stz.nextToken());
        m = Integer.parseInt(stz.nextToken());

        arr = new int[n + 1];

        for(int i=0; i<n+1; i++){
            arr[i] = i;
        }

        for (int i = 0; i < m; i++) {
            stz = new StringTokenizer(br.readLine(), " ");

            int op = Integer.parseInt(stz.nextToken());
            int a = Integer.parseInt(stz.nextToken());
            int b = Integer.parseInt(stz.nextToken());

            if (op == 0) {
                union(a, b);
            } else {
                int x = find(a);
                int y = find(b);
                String ans = (x == y) ? "YES" : "NO";
                bw.write(ans + "\n");
            }
        }

        bw.flush();
        bw.close();
        br.close();
    }

    private static void union(int a, int b) {
        int x1 = find(a);
        int x2 = find(b);

        arr[x1] = x2;
    }

    private static int find(int a) {
        if(a == arr[a]){
            return a;
        }
        return arr[a] = find(arr[a]);
    }
}