세그먼트 트리

  • 주어진 데이터들의 구간 합과 데이터 업데이트를 빠르게 수행하기 위해 고안해낸 자료구조가 바로 세그먼트 트리이다.

Segment tree 핵심이론

  • 세그먼트 트리의 종류는 구간 합, 최대/최소 구하기로 나눌 수 있고,

  • 구현 단계는 트리 초기화하기, 질의값 구하기, 데이터 업데이트하기로 나눌 수 있다.

트리 초기화하기

  • 리프 노드에 원본 데이터를 입력한다. 이때 리프 노드의 시작 위치를 트리 배열의 인덱스로 구해야 하는데, 2^k를 시작 인덱스로 취하면 된다.

  • 리프 노드를 제외한 나머지 노드 값을 채ㅊ운다. 자신의 자식 노드를 이용해 해당 값을 채울 수 있다.

질의값 구하기

✔️ 질의 인덱스를 세그먼트 트리 인덱스로 변경

세그먼트 트리 index = 주어진 질의 index + 2^k - 1

✔️ 질의값 구하는 과정

  1. start_index % 2 == 1일 때 해당 노드를 선택한다.

  2. end_index % 2 == 0일 때 해당 노드를 선택한다.

  3. 1 ~ 2에서 노드를 선택하지 않았다면 startindex = (startindex + 1) / 2 연산을 실행한다.

  4. 1 ~ 2에서 노드를 선택하지 않았다면 endindex = (endindex - 1) / 2 연산을 실행한다.

  5. 1 ~ 4를 반복하다가 endindex < start_index가 되면 종료한다.

✔️ 질의에 해당하는 노드 선택 방법

  • 구간 합 : 선택된 노드들을 모두 더한다.

  • 최댓값 구하기 : 선택된 노드들 중 MAX값을 선택해 출력한다.

  • 최솟값 구하기 : 선택된 노드들 중 MIN 값을 선택해 출력한다.

데이터 업데이터 업데이트하기

구간 합 구하기 3

Question - 2042

Code

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class N71_P2042_구간합_구하기3 {

    static long[] tree;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int num = Integer.parseInt(st.nextToken());
        int change = Integer.parseInt(st.nextToken());
        int cal = Integer.parseInt(st.nextToken());

        int height = 0;
        int length = num;

        while (length != 0) {
            length /= 2;
            height++;
        }

        int treeSize = (int) Math.pow(2, height + 1);
        int leftNodeStartIndex = treeSize / 2 - 1;
        tree = new long[treeSize + 1];

        for (int i = leftNodeStartIndex + 1; i <= leftNodeStartIndex + num; i++) {
            tree[i] = Long.parseLong(br.readLine());
        }

        setTree(treeSize - 1);

        for (int i = 0; i < change + cal; i++) {
            st = new StringTokenizer(br.readLine());
            long type = Integer.parseInt(st.nextToken());
            int s = Integer.parseInt(st.nextToken());
            long e = Long.parseLong(st.nextToken());
            if (type == 1) { // 변경
                changeVal(leftNodeStartIndex + s, e);
            } else if (type == 2) {
                s = s + leftNodeStartIndex;
                e = e + leftNodeStartIndex;
                System.out.println(getSum(s, (int) e));
            } else {
                return;
            }
        }

        br.close();
    }

    private static long getSum(int s, int e) {
        long partSum = 0;
        while (s <= e) {
            if (s % 2 == 1) {
                partSum = partSum + tree[s];
                s++;
            }
            if (e % 2 == 0) {
                partSum = partSum + tree[e];
                e--;
            }

            s = s / 2;
            e = e / 2;
        }
        return partSum;
    }

    private static void changeVal(int index, long val){
        long diff = val - tree[index];
        while (index > 0) {
            tree[index] = tree[index] + diff;
            index = index / 2;
        }
    }

    private static void setTree(int i) {
        while ( i != 1) {
            tree[i / 2] += tree[i];
            i--;
        }
    }
}

Idea

reference

최솟값 찾기 2

Question - 10868

Code

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class N72_P10868_최솟값_찾기2 {
    static long[] tree;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int num = Integer.parseInt(st.nextToken());
        int pair = Integer.parseInt(st.nextToken());

        int height = 0;
        int length = num;
        while (length != 0) {
            length /= 2;
            height++;
        }

        int treeSize = (int) Math.pow(2, height + 1);
        int leftNodeStartIndex = treeSize / 2 - 1;

        // 트리 초기화하기
        tree = new long[treeSize + 1];
        for (int i = 0; i < tree.length; i++) {
            tree[i] = Integer.MAX_VALUE;
        }

        // 데이터 입력받기
        for (int i = leftNodeStartIndex + 1; i <= leftNodeStartIndex + num; i++) {
            tree[i] = Long.parseLong(br.readLine());
        }

        setTree(treeSize - 1);

        for (int i = 0; i < pair; i++) {
            st = new StringTokenizer(br.readLine());
            int s = Integer.parseInt(st.nextToken());
            int e = Integer.parseInt(st.nextToken());
            s = s + leftNodeStartIndex;
            e = e + leftNodeStartIndex;
            System.out.println(getMin(s, e));
        }
    }

    private static long getMin(int s, int e) {
        long min = Long.MAX_VALUE;
        while (s <= e) {
            if (s % 2 == 1) {
                min = Math.min(min, tree[s]);
                s++;
            }
            s = s / 2;
            if (e % 2 == 0) {
                min = Math.min(min, tree[e]);
                e--;
            }
            e = e / 2;
        }
        return min;
    }

    private static void setTree(int i) {
        while (i != 1) {
            if (tree[i / 2] > tree[i]) {
                tree[i / 2] = tree[i];
            }
            i--;
        }
    }
}

Idea

reference

구간 곱 구하기

Question - 11505

Code

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class N73_P11505_구간곱_구하기 {

    static long[] tree;
    static int MOD;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int num = Integer.parseInt(st.nextToken());
        int change = Integer.parseInt(st.nextToken());
        int cal = Integer.parseInt(st.nextToken());

        int height = 0;
        int length = num;
        while (length != 0) {
            length /= 2;
            height++;
        }

        int treeSize = (int) Math.pow(2, height + 1);
        int leftNodeStartIndex = treeSize / 2 - 1;
        MOD = 1000000007;
        tree = new long[treeSize + 1];

        for (int i = 0; i < tree.length; i++) {
            tree[i] = 1;
        }

        for (int i = leftNodeStartIndex + 1; i <= leftNodeStartIndex + num; i++) {
            tree[i] = Long.parseLong(br.readLine());
        }

        setTree(treeSize - 1);
        for (int i = 0; i < change + cal; i++) {
            st = new StringTokenizer(br.readLine());
            long a = Integer.parseInt(st.nextToken());
            int s = Integer.parseInt(st.nextToken());
            long e = Long.parseLong(st.nextToken());
            if (a == 1) {
                changeVal(leftNodeStartIndex + s, e);
            } else if (a == 2) {
                s = s + leftNodeStartIndex;
                e = e + leftNodeStartIndex;
                System.out.println(getMul(s, (int) e));
            } else {
                return;
            }
        }

        br.close();
    }

    private static long getMul(int s, int e) {
        long partMul = 1;
        while (s <= e) {
            if (s % 2 == 1) {
                partMul = partMul * tree[s] % MOD;
                s++;
            }
            if (e % 2 == 0) {
                partMul = partMul * tree[e] % MOD;
                e--;
            }
            s = s / 2;
            e = e / 2;
        }
        return partMul;
    }

    private static void changeVal(int index, long val) {
        tree[index] = val;
        while (index > 1) {
            index = index / 2;
            tree[index] = tree[index * 2] % MOD * tree[index * 2 + 1] % MOD;
        }
    }

    private static void setTree(int i) {
        while (i != 1) {
            tree[i / 2] = tree[i / 2]*tree[i]%MOD;
            i--;
        }
    }
}

Idea

reference

Last updated