본문 바로가기
Algorithm/문제코드

BOJ_트리의 독립집합 [DFS, DP]

by Thinking 2024. 11. 9.

트리의 독립집합

 

2 초 128 MB 7738 3809 2834 48.511%

그래프 G(V, E)에서 정점의 부분 집합 S에 속한 모든 정점쌍이 서로 인접하지 않으면 (정점쌍을 잇는 간선이 없으면) S를 독립 집합(independent set)이라고 한다. 독립 집합의 크기는 정점에 가중치가 주어져 있지 않을 경우는 독립 집합에 속한 정점의 수를 말하고, 정점에 가중치가 주어져 있으면 독립 집합에 속한 정점의 가중치의 합으로 정의한다. 독립 집합이 공집합일 때 그 크기는 0이라고 하자. 크기가 최대인 독립 집합을 최대 독립 집합이라고 한다.

 

문제는 일반적인 그래프가 아니라 트리(연결되어 있고 사이클이 없는 그래프)와 각 정점의 가중치가 양의 정수로 주어져 있을 때, 최대 독립 집합을 구하는 것이다.

 

입력

첫째 줄에 트리의 정점의 수 n이 주어진다. n은 10,000이하인 양의 정수이다. 1부터 n사이의 정수가 트리의 정점이라고 가정한다. 둘째 줄에는 n개의 정수 w1, w2, ..., wn이 주어지는데, wi는 정점 i의 가중치이다(1 ≤ i ≤ n). 셋째 줄부터 마지막 줄까지는 간선의 리스트가 주어지는데, 한 줄에 하나의 간선을 나타낸다. 간선은 정점의 쌍으로 주어진다. 입력되는 정수들 사이에는 빈 칸이 하나 있다. 가중치들의 값은 10,000을 넘지 않는 자연수이다.

출력

첫째 줄에 최대 독립집합의 크기를 출력한다. 둘째 줄에는 최대 독립집합에 속하는 정점을 오름차순으로 출력한다. 최대 독립 집합이 하나 이상일 경우에는 하나만 출력하면 된다.

 

 

1. 첫번째로 풀이한 나의 코드이다. (정답X) 해당 문제를 해결하기 위해 고안했던 방법은 2가지였다. 이 코드는, 전체 노드의 가중치와 인덱스를 부여해서, PriorityQueue를 사용하여, 가중치가 높으면서, 인접하지 않은 노드를 뽑아 결과를 도출하는 방법이었다. 하지만 이 방식의 문제는 전체 트리 노드의 최댓값을 구하기에 적합하지 않다. 왜냐하면, 가중치가 높은 순대로 인접하지 않은 노드를 뽑는 것 보다 해당 노드를 선택하지 않았을 때의 노드 합이 더 클 수 있반례가 존재하기 때문이다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Main {
    static class Node implements Comparable<Node>{
        int index;
        int weight;

        Node(int index, int weight){
            this.index = index;
            this.weight = weight;
        }

        @Override
        public int compareTo(Node o){
            return o.weight - this.weight;
        }
    }
    static boolean[] visited;
    static List<Integer>[] list;
    static PriorityQueue<Integer> nodePq = new PriorityQueue<>();
    static int answer = 0;
    public static void main(String[] args)throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        StringTokenizer st = new StringTokenizer(br.readLine());
        PriorityQueue<Node> pq = new PriorityQueue<>();
        for(int i=1; i<=N; i++){
            pq.add(new Node(i, Integer.parseInt(st.nextToken())));
        }

        visited = new boolean[N+1];
        list = new ArrayList[N+1];

        for(int i=1; i<=N; i++){
            list[i] = new ArrayList<>();
        }

        while(true){
            String line = br.readLine();
            if(line == null || line.isEmpty()) break;
            String[] tmp = line.split(" ");

            int from = Integer.parseInt(tmp[0]);
            int to = Integer.parseInt(tmp[1]);

            list[from].add(to);
            list[to].add(from);
        }
        // 7(2,6) 3(2,4)
        while(!pq.isEmpty()){
            Node cur = pq.poll();

            if(!visited[cur.index]){
                visited[cur.index] = true;
                nodePq.add(cur.index);
                answer += cur.weight;

                for(int next : list[cur.index]){
                    visited[next] = true;
                }
            }
        }

        System.out.println(answer);
        while(!nodePq.isEmpty()){
            System.out.print(nodePq.poll() +" ");
        }
    }
}

 

 

 

2. 해당 문제의 허용 시간은 2초내이면서, N 범위는 10000 이하이다. 자바 기준 N^2으로 풀이가 가능하다고 생각했고, 그렇다면 모든 노드를 기준으로 최댓값을 구해서, 답을 도출하는 방식으로 작성했다.

 

(정답 X) 만약 1번 노드를 기준으로 탐색을 시작했을 때, 인접하지 않으면서, 가중치가 높은것을 기준으로 뽑아줬다. 프림 알고리즘과 유사하다고 생각할 수 있을 것 같다. 그리고 처음 노드를 기준으로 시작할 때. visited 방문 배열과 answer(모든 가중치의 최대합 변수), answerNodePq(합이 최대일 때 방문한 노드의 인덱스)를 각각의 조건에 맞을 때 수행한다.

 

구현이나 방법으로는 문제가 원하는 것을 도출할 수 있다고 생각하지만, Pq에 해당 노드를 집어넣고, HashSet, 인접한 노드를 계속 체크해준다는 점에서 메모리초과가 일어난다.

 

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
    static class Node implements Comparable<Node>{
        int index;
        int weight;

        Node(int index, int weight){
            this.index = index;
            this.weight = weight;
        }

        @Override
        public int compareTo(Node o){
            return o.weight - this.weight;
        }
    }
    static List<Integer>[] list;
    static PriorityQueue<Integer> nodePq;
    static PriorityQueue<Integer> answerNodePq = new PriorityQueue<>();
    static int answer = Integer.MIN_VALUE;
    static int[] data;
    static int N;
    public static void main(String[] args)throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        N = Integer.parseInt(br.readLine());
        data = new int[N+1];
        String[] point = br.readLine().split(" ");
        for(int i=1; i<=N; i++){
            data[i] = Integer.parseInt(point[i-1]);
        }
        list = new ArrayList[N+1];
        for(int i=1; i<=N; i++){
            list[i] = new ArrayList<>();
        }

        while(true){
            String line = br.readLine();
            if(line == null || line.isEmpty()) break;
            String[] tmp = line.split(" ");

            int from = Integer.parseInt(tmp[0]);
            int to = Integer.parseInt(tmp[1]);

            list[from].add(to);
            list[to].add(from);
        }

        for(int i=1; i<=N; i++){
            nodePq = new PriorityQueue<>();
            int cnt = check(i);

            if(cnt > answer){
                answerNodePq.clear();
                while(!nodePq.isEmpty()){
                    answerNodePq.add(nodePq.poll());
                }
                answer = cnt;
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append(answer).append("\n");

        while(!answerNodePq.isEmpty()){
            sb.append(answerNodePq.poll()).append(" ");
        }
        System.out.print(sb.toString());
    }

    private static int check(int val) {
        HashSet<Integer> set = new HashSet<>();
        PriorityQueue<Node> pq = new PriorityQueue<>();
        pq.add(new Node(val, data[val]));
        int total = 0;

        while(!pq.isEmpty()){
            Node cur = pq.poll();

            if(!set.contains(cur.index)){
                set.add(cur.index);
                nodePq.add(cur.index);
                total += cur.weight;

                set.addAll(list[cur.index]);

                for(int i=1; i<=N; i++){
                    if(!set.contains(i)) pq.add(new Node(i, data[i]));
                }
            }
        }

        return total;
    }
}

 

 

 

그래서 해결할 수 있는 방안으로 아래 2가지를 추가 적용하면 된다.

  • 중복된 PriorityQueue 사용: 현재 코드에서 PriorityQueue를 nodePq, answerNodePq 및 pq로 여러 번 생성하고 사용하고 있어 메모리 사용이 높습니다. 각 노드를 기준으로 트리의 독립 집합을 구하는 방식으로 전체 노드에 대해 모든 조합을 다 시도하게 되어 성능이 저하됩니다.
  • 재귀적 DFS 사용 및 메모이제이션 적용: 트리 구조의 독립 집합 문제는 동적 계획법(DP)을 사용하여 메모이제이션을 통해 중복 계산을 줄일 수 있습니다. dp[i][0]과 dp[i][1]을 사용하여 각각의 노드가 독립 집합에 포함되지 않는 경우와 포함되는 경우의 최대 가중치를 저장하는 방식으로 문제를 해결할 수 있습니다.

 

아래는 정답 코드이고, 중심적으로 봐야할 부분은 dp[N+1][2] 선언으로, 해당 노드가 선택되었는지 안되었는지 각각의 상황마다 가중치를 구하는 것(dfs 메서드), 그리고 해당 노드가 선택되었다면, 인접한 노드의 선택은 불가하며 해당 노드가 선택되지 않았다면, 인접한 노드의 선택은 될 수 있고 안될 수 있다.(trace 메서드)

 

3. 정답 코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class boj_2213 {
    static List<Integer>[] tree;
    static int[] weights;
    static int[][] dp;
    static boolean[] visited;
    static PriorityQueue<Integer> selectedNodes = new PriorityQueue<>();

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());

        weights = new int[N + 1];
        String[] weightStr = br.readLine().split(" ");
        for (int i = 1; i <= N; i++) {
            weights[i] = Integer.parseInt(weightStr[i - 1]);
        }

        tree = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) {
            tree[i] = new ArrayList<>();
        }

        String line;
        while ((line = br.readLine()) != null && !line.isEmpty()) {
            String[] edge = line.split(" ");
            int u = Integer.parseInt(edge[0]);
            int v = Integer.parseInt(edge[1]);
            tree[u].add(v);
            tree[v].add(u);
        }

        dp = new int[N + 1][2]; // 각각의 노드가 독립 집합에 포함되지 않는 경우와 포함되는 경우
        visited = new boolean[N + 1];

        dfs(1);
        StringBuilder sb = new StringBuilder();
        int maxWeight = Math.max(dp[1][0], dp[1][1]);
        sb.append(maxWeight).append("\n");

        visited = new boolean[N + 1];
        if (dp[1][1] >= dp[1][0]) {
            trace(1, 1);
        } else {
            trace(1, 0);
        }

        while(!selectedNodes.isEmpty()){
            sb.append(selectedNodes.poll()).append(" ");
        }

        System.out.println(sb.toString());
    }

    private static void dfs(int node) {
        visited[node] = true;
        dp[node][0] = 0;
        dp[node][1] = weights[node];

        for (int child : tree[node]) {
            if (!visited[child]) {
                dfs(child);
                dp[node][0] += Math.max(dp[child][0], dp[child][1]);
                dp[node][1] += dp[child][0];
            }
        }
    }

    private static void trace(int node, int isSelected) {
        visited[node] = true;
        if (isSelected == 1) {
            selectedNodes.add(node);
            for (int child : tree[node]) {
                if (!visited[child]) {
                    trace(child, 0);
                }
            }
        } else {
            for (int child : tree[node]) {
                if (!visited[child]) {
                    if (dp[child][1] >= dp[child][0]) {
                        trace(child, 1);
                    } else {
                        trace(child, 0);
                    }
                }
            }
        }
    }
}

 

 

1. dfs 함수

dfs 함수는 트리의 특정 노드를 기준으로 해당 노드를 포함하거나 포함하지 않는 경우의 최대 가중치 합을 계산하여 dp[node][0]과 dp[node][1]에 저장합니다.

private static void dfs(int node) {
    visited[node] = true;  // 현재 노드를 방문 처리
    dp[node][0] = 0;       // 현재 노드를 포함하지 않는 경우의 초기 가중치 합
    dp[node][1] = weights[node]; // 현재 노드를 포함하는 경우의 초기 가중치 합 (자기 자신의 가중치)

    // 현재 노드의 자식 노드를 탐색
    for (int child : tree[node]) {
        if (!visited[child]) {
            dfs(child); // 자식 노드에 대해 dfs 호출

            // 현재 노드를 포함하지 않는 경우: 자식 노드는 포함될 수도 있고, 포함되지 않을 수도 있다.
            dp[node][0] += Math.max(dp[child][0], dp[child][1]);

            // 현재 노드를 포함하는 경우: 자식 노드는 포함되지 않아야 함.
            dp[node][1] += dp[child][0];
        }
    }
}



1. 방문 처리:
   - visited[node] = true 는 현재 노드를 방문 처리하여, 이미 방문한 노드를 다시 방문하지 않도록 합니다. 이를 통해 무한 루프를 방지합니다.

2. 초기 가중치 설정:
   - dp[node][0] = 0 -> dp[node][0]은 현재 노드를 포함하지 않는 경우의 최대 가중치 합입니다. 초기값은 0으로 설정합니다.
   - dp[node][1] = weights[node] -> dp[node][1]은 현재 노드를 포함하는 경우의 최대 가중치 합 입니다. 초기값은 해당 노드의 가중치(weights[node])로 설정합니다.

3. 자식 노드 탐색:
   - for (int child : tree[node])-> 현재 노드의 자식 노드들을 탐색합니다.
   - if (!visited[child]) -> 방문하지 않은 자식 노드에 대해서만 dfs(child)를 호출하여 재귀적으로 탐색을 진행합니다.

4. DP 점화식 적용:
   - 현재 노드를 포함하지 않는 경우 (dp[node][0]):
     - 이 경우 자식 노드는 포함될 수도, 포함되지 않을 수도 있습니다.
     - 따라서 자식 노드의 dp[child][0] (자식 노드가 포함되지 않은 경우의 최대 가중치 합)과 dp[child][1] (자식 노드가 포함된 경우의 최대 가중치 합) 중 더 큰 값을 선택하여 누적합니다.
     - dp[node][0] += Math.max(dp[child][0], dp[child][1]);

   - 현재 노드를 포함하는 경우 (dp[node][1]):
     - 이 경우 자식 노드는 반드시 포함될 수 없습니다(인접한 노드가 독립 집합에 포함될 수 없기 때문).
     - 따라서 자식 노드가 포함되지 않은 경우의 가중치 합인 dp[child][0]을 누적합니다.
     - dp[node][1] += dp[child][0];


2. trace 함수

trace 함수는 dp 배열을 활용하여 최적의 독립 집합을 구성하는 노드를 선택합니다. 즉, dp에 저장된 정보를 바탕으로 어떤 노드들이 독립 집합에 포함될지 결정합니다.

private static void trace(int node, int isSelected) {
    visited[node] = true;

    if (isSelected == 1) {
        selectedNodes.add(node);  // 독립 집합에 포함되는 노드 저장
        for (int child : tree[node]) {
            if (!visited[child]) {
                trace(child, 0);  // 포함된 노드의 자식은 포함되지 않아야 함
            }
        }
    } else {
        for (int child : tree[node]) {
            if (!visited[child]) {
                if (dp[child][1] >= dp[child][0]) {
                    trace(child, 1);  // 자식 노드를 포함하는 경우
                } else {
                    trace(child, 0);  // 자식 노드를 포함하지 않는 경우
                }
            }
        }
    }
}



1. 현재 노드를 포함하는 경우 (isSelected == 1):
   - selectedNodes.add(node) -> 현재 노드를 독립 집합에 추가합니다.
   - 자식 노드들은 반드시 독립 집합에 포함되지 않아야 합니다.
     - 따라서 각 자식 노드에 대해 trace(child, 0) 을 호출하여, 자식 노드가 독립 집합에 포함되지 않도록 처리합니다.

2. 현재 노드를 포함하지 않는 경우 (isSelected == 0):
   - 이 경우 자식 노드는 독립 집합에 포함될 수도, 포함되지 않을 수도 있습니다.
   - 각 자식 노드에 대해 dp[child][1] (자식 노드가 포함된 경우의 최대 가중치 합)과 dp[child][0] (자식 노드가 포함되지 않은 경우의 최대 가중치 합)을 비교하여 더 큰 값을 선택합니다.
     - dp[child][1] >= dp[child][0]인 경우, 자식 노드를 포함하도록 trace(child, 1)를 호출합니다.
     - 그렇지 않은 경우, 자식 노드를 포함하지 않도록 trace(child, 0)를 호출합니다.

 

 

요약
- dfs 함수는 DP 배열을 채우면서 각 노드의 최적 독립 집합 가중치를 계산합니다.
- trace 함수는 dp 배열을 참조하여 독립 집합에 포함될 노드들을 추적하며, 최적의 독립 집합을 구성합니다.

아래는 차례로 다른 방식을 활용하여 동작한 시간을 정리한 것이다.

 

1) StringBuilder X, PriorityQueue O

2) StringBuilder O, PriorityQueue O

3) StringBuilder X, PriorityQueue X -> ArrayList O

 

 

'Algorithm > 문제코드' 카테고리의 다른 글

BOJ - 점프2253 [DP] Java  (1) 2024.11.04
알고리즘 기록 Github Link  (0) 2024.10.01
codeTree - Grid Compression [Hard]  (0) 2024.08.16
codeTree - PriortyQueue [Hard]  (0) 2024.08.16
codeTree - TreeSet [easy]  (0) 2024.08.16