세그먼트 트리는 "구간을 저장하기 위한 트리" 이다. 예를 들어서 int Array[5] = { 1, 2, 3, 4, 5 }
2번째 값 부터 4번째 값 까지의 합은 2 + 3 + 4로 계산할 것이며 1번째 값 부터 5번째 값 까지의 합은 1 + 2 + 3 + 4 + 5로 계산하게 될 것이다. 특정 구간에 대한 연산이라면, 모든 합을 다 구해놓고 계산하는 방식을 생각할 수 있다. 이걸 1번 연산 이라고 칭하자. 2번 째 값을 7로 바꾸면 { 1, 7, 3, 4, 5 }. 이렇게, 특정 값을 바꾸는 연산 2번 연산 이라고 칭하자. 이 후에 2번째 값부터 4번째 값 까지의 합을 구하고 3번째 값부터 5번째 값 까지의 합을 구하고, 4번째 값을 6으로 바꾸는거와 같은 연산들이 쭉 이어진다고 생각해보자. 별로 문제 없어보인다 1번연산과 2번연산에 대해서 조금 더 구체적으로 생각해보자. N칸 짜리 배열이 있다. 1번과 2번 연산은 합쳐서 최대 M번 주어진다고 가정해보자. 우리는 1번 연산을 진행하는데 걸리는 시간은 O(N) 만큼의 시간이 소요될 수 있고, 2번 연산을 진행하는데 걸리는 시간은 O(1) 만큼 소요될 수 있다. 결과적으로 2개의 연산을 모두 진행하는데 걸리는 시간은 O(NM) 만큼의 시간복잡도를 갖게 된다. 즉, N값과 M값이 매우 큰 값이 들어올 경우에는 생각보다 간단한 연산이라고 생각했음에도 굉장히 시간이 오래 걸리게 된다. 이런 경우 '세그먼트 트리'를 사용하게 되면 O(logN) 만큼의 시간만으로 굉장히 효율적으로 해결할 수 있다.
각 노드들 안에 적혀있는 숫자들은 배열의 Index번호를 의미한다. 리프노드(자식이 없는 노드)들은 배열의 값을 그대로 가지고 있는 형태이고, 그게 아닌 노드들은 자식 노드들이 가지는 값들에 대한 연산 결과를 저장하고 있는 형태이다. 예를 들어서 각 구간에 대한 합을 구하기 위해서 세그먼트 트리를 만들었다면, '0-1', '2-3' 과 같이, a-b 의 형태로 적혀있는 노드들이 갖는 의미는 a번 Index부터 b번 Index까지의 합을 나타내고 있는 것이다. 즉, 루트노드는 전체 구간에 대한 합을 가지고 있게 되는 것이다.
세그먼트 트리(Segment Tree) 생성과정
Arr[]{ 1, 2, 3, 4, 5 } 라는 배열이 있다고 가정하고. 그리고 1번연산'과 2번연산을 진행해보자. 세그먼트 트리는 트리 중에서도 '이진트리'의 모습을 가진 구조이다. 리프노드(자식이 없는 가장 말단노드)는 배열의 값 그 자체를, 그게 아닌 노드에는 해당 자식들의 합을 저장" 하는 형태이다.
리프노드(자식이 없는 노드)들은(파랑색 원), 배열 { 1, 2, 3, 4, 5 } 의 값을 그대로 가지고 있다. 하지만, 리프노드가 아닌 다른 노드들을 보게되면 다른 숫자가 있을텐데 자식 노드들의 합이다. 루트노드의 '15'라는 값은 왼쪽 자식인 '6'과 오른쪽 자식인 '9'의 합이고 왼쪽 자식인 6은, '3'과 '3'의 합이다. 그 중, 오른쪽 자식은 리프노드로써, 배열의 3번째 값인 '3'을 나타내고 있다. 그리고 왼쪽 자식인 '3'을 보게되면, '1'과 '2'의 합으로 생성된 값을 나타내고 있으며, '1'과 '2'는 리프노드로써 더 이상의 자식을 가지고 있지 않다. 다시 루트노드에서 오른쪽 자식으로 오게되면 '9'라는 값을 가지고 있는데 이 '9'는 '4'와 '5'의 합이고 '4'와 '5'는 리프노드로써 배열의 값을 그대로 가지고 있다. 즉, '15'라는 값은, 모든 자식의 합이다
위의 배열 5개에 대한 세그먼트 트리를 통해서 알아 낼 수 있는 정보들은
- 5개에 대한 세그먼트 트리를 구현하는데 필요한 노드의 수는 9개이다.
- 트리의 높이가 3이다.
2의 제곱꼴로 표현되는 크기의 배열을 생각해보자. 예를 들어서 { 1, 2, 3, 4 }
위의 세그먼트 트리는 노드가 7개이고, 높이가 2인 이진트리의 형태를 갖추고 있다. log2(N) 으로 구할 수가 있다. 조금 더 정확하게 표현해보자면 ceil(log2(N)). 여기서 'ceil(x)' 는 반올림을 시켜주는 함수이다. N = 5일 경우, log2(5) = 2.xxxx 이고, 이 경우에는 소수점이 붙는데, 해당 소수점이 0.5보다 크든 작든 상관없이 무조건 반올림을 하게 된다. 즉, ceil(log2(5)) = 3 이 된다. 즉 높이가 3이 된다.
세그먼트 트리를 구현하기 위해서 필요한 크기는 2 ^ (높이 + 1) 을 한 값 만큼 필요하게 된다. N = 4일 때, 높이가 2인 세그먼트 트리에는 노드가 총 7개가 존재했다. 위의 공식에 대입해보면 2 ^ (2 + 1) = 2 ^ 3 = 8 실제 노드는 7개인데 하나의 노드를 자식과 부모 관계를 계산하는데 있어서 복잡해지기 때문에 배열이든 벡터이든 0번째 Index를 사용하지 않는다. 부모노드번호 x 2 = 왼쪽자식노드번호, 부모노드번호 x 2 + 1 = 오른쪽자식노드번호 이게 이진트리에서 일반적으로 사용하는 부모와 자식 관계를 표현하는 방식이다. 그런데 0번 Index를 사용하는 순간, 위의 방식이 깨져버리기 때문이다.
필요한 크기는 2 ^ (3 + 1) = 2 ^ 4 = 16칸이다. 실제로, 그림을 보게되면 총 9개의 노드가 사용되었고, 16칸 까지는 필요가 없다. 하지만 주어진 N에 따라서 딱 맞게 칸을 생성해주기가 힘들다. 왜냐하면, ceil(log2(N)) = x 라는 값이 나왔다고 가정했을 때, 이 'x'는 2의 거듭제곱꼴에 의해서 딱 맞게 나온 x일 수도 있고, 그게 아니라 더 작은 숫자인데 반올림을 해서 나온 x일 수도 있다. 트리의 높이가 H일 때, 세그먼트 트리의 크기 = 1 << (H + 1)
크기가 N인 배열이 존재할 때
- 트리의 높이 = ceil(log2(N))
- 세그먼트 트리의 크기 = (1 << (트리의 높이 + 1) )
배열의 크기 x 4로 설정해도 무방하다.
세그먼트 트리 만들기
이제 직접 값을 대입해보자. 본인 같은 경우는 보통 벡터를 하나 선언해서 세그먼트 트리로 만들어 준다.
재귀를 사용해서 만들어준다. 재귀에 사용되는 매개변수로는 다음과 같이 3개의 변수가 있다.
{ 현재 노드 번호 , 시작 범위 , 마지막 범위 } 1번 인덱스부터 사용할 것이기 때문에 가장 초기에 호출되는 현재노드번호 = 1 로 호출될 것이다. 시작범위는 배열의 시작범위이다.즉 '0'을 의미한다. 마지막범위는 배열의 마지막범위이다. 즉, 'N - 1'을 의미한다. 다시 { 1, 2, 3, 4, 5 } 라는 5개의 배열을 세그먼트 트리로 만드는 과정을 알아보자.
과정을 먼저 크기 적어보고 시작하자면 다음과 같다.
- 주어진 범위를 반으로 나눈다.
- 나눠진 2개의 범위에 대해서 '왼쪽범위'에 대한 재귀호출을 한다.
- 나눠진 2개의 범위에 대해서 '오른쪽범위'에 대한 재귀호출을 한다.
- 위의 과정을 반복한다.
이진트리로 이루어져 있기 때문에, 항상 범위를 반으로 나누는 것이 중요하다. 중간값은 = (0 + N - 1) / 2 = (0 + 4) / 2 = 2 가 된다. 즉, 우리는 0 ~ N - 1로 설정되어 있는 범위를 0 ~ 2와, 3 ~ 4로 나눠서 구현할 것이다. 시작점 ~ 끝점 으로 이루어져 있던 범위를, 시작점 ~ 중간점, 중간점 + 1 ~ 끝점 이렇게 2개의 범위로 나눈다고 생각하면 된다. 해당 범위를 토대로 재귀를 호출해보자. 노드 번호는 왼쪽 노드로 가게되면, 현재 노드 x 2 의 번호를 가지게 될 것이고, 오른쪽 노드로 가게되면 현재 노드 x 2 + 1 의 번호를 가지게 될 것이다. 나눠진 2개의 범위에 대해서 '왼쪽범위' 에 대한 재귀호출을 진행해보자. 재귀는 [ 2 , 0 , 2 ] 로 호출이 될 것이다. 그 안에서 다시 1번과정이 반복되어진다. 범위를 2개로 나누게 되면 0 ~ 1, 2 ~ 2 로 나뉘게 된다. 나눠진 2개의 범위에 대해서 '왼쪽범위' 에 대해서 재귀호출을 또 진행해보자. 재귀는 [ 4 , 0 , 1 ] 로 호출이 될 것이다. 또 반복하자. 그렇게되면, [ 8 , 0, 0 ] 으로 호출이 된다. 시작하는 범위 = 끝나는 범위일 경우, 재귀는 종료된다. 종료하기 전에 값을 반드시 설정해 줘야 한다. SegmentTree[노드번호] = 배열[시작범위];
[ 4 , 0 , 1 ] 에서 우리는 왼쪽범위인 [ 8 , 0 , 0 ] 으로 온 상태고 끝난다면, 오른쪽범위를 체크하러 갈 것이다.
오른쪽범위는 [ 9, 1, 1 ] 이 될 것이다. 마찬가지로 '시작범위 = 끝나는 범위' 이므로, 그대로 값을 설정해준다.
SegmentTree[9] = Arr[1]. 그럼 이렇게 되면 우리는 노드 4번의 값 또한 구할 수 있다. 바로, 8번과 9번 노드의 값을 구했으니, 이 2노드의 값을 더해주면
위의 상태와 같다고 볼 수 있다. 우리는 지금 [ 4, 0, 1 ] 에서 호출할 수 있는 재귀를 모두 호출 후, 4번 노드의 값 까지 구한 상태이다. 위의 재귀가 끝난다면 [ 5, 1, 1 ] 로 갈 것이다. 이 경우 바로 '시작범위 = 끝 범위' 이므로, 노드 5번의 값이 Arr[2] 로 설정될 것이다. 5번 노드의 값이 설정되면서 동시에 2번 노드의 값이 결정되었다. 바로 4번노드 + 5번 노드를 한 값이다.
이 후, 2번 노드의 값이 설정되었으면 해당 재귀 호출이 모두 끝난다는 것을 의미하고, 이제 [ 3, 3, 4 ] 로 넘어가보자.
[ 3, 3, 4 ] 는 [ 6, 3, 3] 과 [ 7, 4, 4 ]를 호출 하게 된다. 여기서 6번 노드와 7번 노드의 값이 구해질 것이고, 이 값들에 의해서 3번 노드의 값이 구해질 것이다. 3번 노드의 값이 구해짐에 따라서 2번 노드 + 3번 노드를 한 값이 1번 노드의 값이 됨으로써 세그먼트 트리가 완성된다.
#include <iostream>
#include <vector>
using namespace std;
int n;
vector<int>g_vec_segment_tree;
int* g_arr = nullptr;
int Make_segment_tree(int _node, int _start, int _end)
{
if (_start == _end)
return g_vec_segment_tree[_node] = g_arr[_start];
int mid = (_start + _end) / 2;
int left_res = Make_segment_tree(_node * 2, _start, mid);
int right_res = Make_segment_tree(_node * 2 + 1, mid + 1, _end);
return g_vec_segment_tree[_node] = left_res + right_res;
}
int main()
{
cin >> n;
g_arr = new int[n];
for (int i = 0; i < n; i++)
cin >> g_arr[i];
int tree_height = (int)ceil(log2(n));
int tree_size = (1 << (tree_height + 1));
g_vec_segment_tree.resize(tree_size);
Make_segment_tree(1, 0, n - 1);
delete[] g_arr;
}
세그먼트 트리 1번연산 (구간합 구하기)
Arr[] = { 1, 2, 3, 4, 5 } 에 대한 세그먼트 트리이다. 여기서 "2번째 값부터 3번째 값 까지의 합을 구하세요" , "3번째값부터 5번째 값 까지의 합을 구하세요" 라는 연산이 주어졌다고 생각해보자.
우리가 탐색을 할 때에는 크게 3가지 경우로 나눠서 생각할 수 있다.
- 현재 우리가 탐색하는 범위가, 우리가 찾고자 하는 구간과 완전히 겹쳐지지 않는 경우.
- 현재 우리가 탐색하는 범위가, 우리가 찾고자 하는 구간에 완전히 속해있는 경우.
- 1, 2번 경우를 제외한 나머지 경우. 즉, 일부만 걸쳐있는 경우.
1번 같은 경우에는 우리에게 제시된 연산이 "배열의 1번째 값부터 3번째 값 까지의 합을 구하세요" 라고 주어지게 된다면 우리는 세그먼트 트리의 1번 노드에서부터 범위를 2개로 나눠가면서 왼쪽 범위에 대한 탐색, 오른쪽 범위에 대한 탐색을 진행할 것이다. 이 때, 오른쪽 범위로 노드 번호는 '3'번일 것이고, 범위는 '3 ~ 4' 로 호출된 상태일 것이다.(위의 재귀그림 참고) 그런데, 배열의 첫 번째 값부터 3번째 값 까지는 사실상 0번 Index ~ 2번 Index까지의 합을 구하라는 것을 의미한다. 이 때, 3번 Index와 4번 Index에 대한 계산을 해놓은, 3번 노드에 대해서는 우리가 원하는 값은 나오지 않을 것이기에 더 이상의 탐색이 필요 없다.
2번 같은 경우는 첫 번째 값부터 세번째 값 까지의 합을 구하라고 했는데, 2번 노드로 온 경우이다. 위의 그림에서 빨강색 글씨는 '노드번호' 를 의미한다. 여기서 2번 노드의 '6'은 배열의 0번 Index ~ 2번 Index까지의 구간을 포함하고 있는 노드이다. 이 경우에는, "우리가 탐색하고 있는 범위가, 우리가 구하고자 하는 범위에 완전히 속해 있다" 라고 말할 수 있고, 더 이상의 탐색을 하지 않고 그대로 그 Node의 값을 return 시켜주면 된다. 3번 경우는 일부만 걸친 경우를 의미한다. 예를들어서 "배열의 3번째 값 부터 4번째 값 까지의 합을 구하세요" 라는 연산이 주어졌다고 생각하자. 물론, 실제로 배열에서는 2번 Index부터 3번 Index까지의 합을 구해야 하는 것이다. 이 때, 2번 노드(6의 값을 가진 노드) 로 왔고 포함하고 있는 범위는 배열의 "0번 Index ~ 2번 Index". 그런데 "2번 Index ~ 3번 Index"까지의 합을 구하는 범위랑 걸쳐있는데 이 경우에는 마찬가지로 왼쪽자식과 오른쪽 자식으로 더 깊은 탐색을 진행해야 한다.
int Sum(int _node, int _start, int _end, int _left, int _right)
{
if (_left > _end ||
_right < _start)
return 0;
if (_left <= _start &&
_end <= _right)
return g_vec_segment_tree[_node];
int mid = (_start + _end) / 2;
int left_res = Sum(_node * 2, _start, mid, _left, _right);
int right_res = Sum(_node * 2, _start, mid, _left, _right);
}
세그먼트 트리 2번 연산 (값 바꾸기)
배열에서는 특정 Index의 값을 바꾸라고 하면 너무나도 간단하게 바뀌었지만, 세그먼트 트리에서는 그럴 수가 없다.
왜냐하면, 구간에 따른 연산 결과를 저장해 놓은 트리인데, 여기서 하나의 값이 바뀌게 되면, 그 값에 영향을 끼치는 모든 노드들의
값을 바꿔줘야 하기 때문이다. 지금부터는 이 과정을 알아보자.
우리는 먼저 크게 2가지의 경우로 나눠서 생각할 수 있다.
- 바꾸고자 하는 Index값이, 현재 우리가 탐색하는 범위내에 속해있는 경우.
- 바꾸고자 하는 Index값이, 현재 우리가 탐색하는 범위내에 속해있지 않은 경우.
2번 같은 경우에는 더 이상의 탐색을 하지 않아도 된다. 왜냐하면 어차피 우리가 원하는 Index는 탐색하고 있는 범위 내에
속해있지 않은데, 자식노드로 더 깊게 들어간다고 해서 해당 Index가 갑자기 나올 수 있을까 ? 절대 그렇지 않다.
따라서 2번 같은 경우는 그대로 탐색을 종료해버리면 된다.
그럼 1번의 경우를 생각해보자. 우리가 사용하는 배열(Arr[] = { 1, 2, 3, 4, 5 }) 에서 "2번째 값을 5로 바꾸세요 !" 라는 연산이
주어졌다고 생각해보자. 2번째 값이라는 것은 배열에서는 '1번 Index'를 의미한다.
1번 노드에서 부터 탐색을 시작할 것이다. 1번노드에서, 2번 노드와 3번 노드로 갈 것이다. 3번 노드는 (3번 Index ~ 4번 Index)에 대한 정보를 가지고 있는 노드이다. 1번 Index가 완전히 속해있지 않은 경우이다. 이 경우에는 굳이 6, 7번 노드 까지 들어가볼 필요 없이 그대로 탐색을 종료해주면 된다.
2번 노드는 (0번 ~ 2번Index)에 대한 정보를 가지고 있는 노드이고, 우리가 찾는 Index가 해당 범위에 속해있으니, 값을 바꿔줘야 한다. Arr[1] = 2 의 값을 '5'로 바꾸고 싶어한다. 마찬가지로 2번 노드에도 + 3을 더해주면 되는 것이다. 그 후 더 깊은 탐색을 진행해야 한다. 4번노드와 5번 노드로 갈 것이고 이 때 5번노드 더 이상의 탐색이 필요없다 왜냐하면 2번 Index에 대한 정보를 가지고있는 노드이므로, 1번 Index를 찾고 있는 나와는 무관한 노드이다. 4번 노드는 0번 ~ 1번Index 에 대한 정보를 가지고 있는 노드이고, 우리가 찾고자 하는 Index가 이 안에 속해있다. 8번 노드는 리프노드이고, 우리가 원하는 정보를 가진 노드가 아니라서 종료한다.
#include <iostream>
#include <vector>
using namespace std;
int n;
int* g_arr = nullptr;
class Segment_tree
{
vector<int>m_vec_segment_tree;
public:
Segment_tree() = default;
Segment_tree(int _size)
{
m_vec_segment_tree.resize(_size);
}
~Segment_tree()
{
cout << "소멸자 호출" << endl;
delete[] g_arr;
}
public:
int Init(int _node, int _start, int _end)
{
if (_start == _end)
return m_vec_segment_tree[_node] = g_arr[_start];
int mid = (_start + _end) / 2;
int left_res = Init(_node * 2, _start, mid);
int right_res = Init(_node * 2 + 1, mid + 1, _end);
return m_vec_segment_tree[_node] = left_res + right_res;
}
int Sum(int _node, int _start, int _end, int _left, int _right)
{
if (_left > _end ||
_right < _start)
return 0;
if (_left <= _start &&
_end <= _right)
return m_vec_segment_tree[_node];
int mid = (_start + _end) / 2;
int left_res = Sum(_node * 2, _start, mid, _left, _right);
int right_res = Sum(_node * 2, _start, mid, _left, _right);
}
void Update(int _node, int _start, int _end, int _index, int _diff)
{
if (_index<_start ||
_index>_end)
return;
m_vec_segment_tree[_node] = m_vec_segment_tree[_node] + _diff;
if(_start!=_end)
{
int mid = (_start + _end) / 2;
Update(_node * 2, _start, mid, _index, _diff);
Update(_node * 2 + 1, mid + 1, _end, _index, _diff);
}
}
};
int main()
{
cin >> n;
g_arr = new int[n];
for (int i = 0; i < n; i++)
cin >> g_arr[i];
int tree_height = (int)ceil(log2(n));
int tree_size = (1 << (tree_height + 1));
Segment_tree segment_tree(tree_size);
segment_tree.Init(1, 0, n - 1);
// 세그먼트 트리 업데이트
int index = 1, value = 5;
int diff = value - g_arr[index];
g_arr[index] = value;
segment_tree.Update(1, 0, n - 1, index, diff);
}
'CS > 자료구조 & 알고리즘' 카테고리의 다른 글
C++ 그래프를 이용한 BFS / DFS 계속 업데이트할 예정 (0) | 2022.06.18 |
---|---|
BOIDS (군중 알고리즘) RTS AI (0) | 2022.06.18 |
이터레이터 Iterator (반복자) (0) | 2022.06.16 |
선 중 후 순회 트리 (0) | 2022.06.15 |
깊이 우선 탐색 BFS / 너비 우선 탐색 BFS 개념 (0) | 2022.06.14 |