본문 바로가기

알고리즘 공부 및 문제 풀이/알고리즘(ALGORITHM)

[알고리즘] 세그먼트 트리(Segment Tree)

1. 세그먼트 트리

세그먼트 트리특정 구간의 합(최솟값, 최댓값, 곱 등)을 구할 때 사용하는 자료구조이다.

 

문제 상황은 다음과 같다. (백준 설명)

배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제를 생각해봅시다.
1) 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구해서 출력하기
2) i번째 수를 v로 바꾸기. A[i] = v
수행해야하는 연산은 최대 M번입니다.
세그먼트 트리나 다른 방법을 사용하지 않고 문제를 푼다면, 1번 연산을 수행하는데 O(N), 2번 연산을 수행하는데 O(1)이 걸리게 됩니다. 총 시간 복잡도는 O(NM) + O(M) = O(NM)이 나오게 됩니다.

 

N과 M이 매우 큰 경우, 너무 오랜 시간이 걸리게 된다. 하지만 세그먼트 트리를 이용하면 O(logN)만에 수행할 수 있다.

 

세그먼트 트리를 사용하기 위해서는 배열을 이진 트리 구조로 만들어야 한다. 이때 세그먼트 트리의 리프 노드와 리프 노드가 아닌 다른 노드는 다음과 같은 의미를 가진다.

리프노드: 배열의 그 수 자체를 저장

다른 노드: 부모노드는 양 쪽 자식 노드의 값의 합을 저장

따라서 N=10일 때의 세그먼트 트리는 다음과 같다.

 

 


2. 세그먼트 트리 구현

 

#include <iostream>
#include <vector>
#include <queue>
#include <cmath>
#include <cstring>

using namespace::std;

vector<long long> tree;
long long arr[11] = {0,1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

long long init(int index, int start, int end){
    if(start==end){
        tree[index] = arr[start];
    }
    else{
        int mid = (start+end)/2;
        tree[index] = init(2*index, start, mid) + init(2*index+1, mid+1, end);
    }
    return tree[index];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);

    int N = 10;
    int h = ceil(log2(N));

    tree = vector<long long>(1<<(h+1));

    for(int i=1; i<=N; i++){
        cin >> arr[i];
    }

    init(1, 1, N);

}

 

양쪽 자식 노드에 대해 시작 인덱스가 1이므로 2*index, 2*index+1로 표기한다. (0부터 시작하다면 2*index+1, 2*index+2)

트리의 크기는 2^h+1이므로 노드를 (1<<(h+1))만큼 할당해주었다.

 

- 세그먼트 트리 update

 

void update(int changed_index, long long nvalue, int index, int start, int end){
    if (changed_index < start || changed_index > end)
        return;

    if(start==end){
        tree[index] = nvalue;
        return;
    };

    int mid = (start+end) / 2;
    update(changed_index, nvalue, index*2, start, mid);
    update(changed_index, nvalue, index*2+1, mid+1, end);
    tree[index] = tree[index*2] * tree[index*2 + 1];
}

 

역시 재귀를 사용하여 업데이트 하며, 리프노드일 때만 바꿀 수 있도록 한다.

 

-세그먼트 트리 구간 합

 

long long sum(int index, int start, int end, int left, int right){
    if (start > right || end < left)
        return 0;
    else if(left<=start && end<=right){
        return tree[index];
    }
    else{
        int mid = (start+end)/2;
        return sum(2*index, start, mid, left, right) + sum(2*index+1, mid+1, end, left, right);
    }
}