Hanbit the Developer

[C++] 백준 14942번: 개미(Sparse Table) 본문

Algorithm/백준

[C++] 백준 14942번: 개미(Sparse Table)

hanbikan 2022. 1. 7. 19:08

https://www.acmicpc.net/problem/14942

 

14942번: 개미

자연수 n이 주어진다. n은 방의 개수이다. (1 ≤ n ≤ 105) 다음 n개의 줄에는 차례대로 현재 각각의 개미가 보유하고 있는 에너지 값이 주어진다. i+1번째 줄에는 i번째 방에 있는 개미가 가진 에너

www.acmicpc.net

 

 > 접근

각 노드에서 1까지 얼마나 걸리는지를 일일히 측정하면 N이 100000이므로 TLE에 걸린다. 따라서 '희소테이블'을 통해 1까지의 빠른 접근을 구현한다.

 

 > 풀이

우선 그래프를 마치 '무빙워크'처럼, 1로 향하게 되는 방향그래프로 만들어준다. 이를 위해서 1에서부터 시작하는 BFS를 한번 해준다. 이 때 이동 경로는 '무빙워크'의 반대방향일 것이다. 따라서 지나오면서 이용한 edge들을 모두 삭제해준다.

또한, 현재 위치로부터 얼마나 이동해야 1에 도달하는지를 저장(배열 level)함으로써, 희소테이블을 효율적으로 이용할 수 있게 된다.

 

 > 희소테이블

위 사진에서, 2번 노드에서 16번 이동하면 어디에 도달하게 되는지를 고려해보자. 일일히 세면서 계산을 해보면 6번 노드에 도달하게 된다는 것을 알 수 있게 된다.

그럼, 32번 이동하면 어디에 도달하게 되는가? 이 때 우리가 노드 6에서 16번 이동하면 4번에 도달하게 된다는 정보를 알고 있다고 해보자. 이렇게 되면 매우 쉬워진다.

2번 노드에서 16번 이동: 6번 노드
6번 노드에서 16번 이동: 4번 노드

위 정보를 조합하면 4번 노드라는 결론에 다다르게 된다!

 

정보를 미리 '희소테이블'이라는 곳에 저장한 뒤, 이것을 우리는 활용하게 된다. table[n][i]은, 노드 i에서 2^n번 이동했을 때 도달하게 되는 곳을 저장한다.

 

테이블을 구성하는 함수는 다음과 같다.

void setTable() {
	int i;
	for (i = 1; i < n + 1; i++) {
		table[0][i] = graph[i].at(0);
	}

	int k;
	for (k = 1; k < MAX_LOG; k++) {
		for (i = 1; i < n + 1; i++) {
			int midNode = table[k - 1][i].first;
			int midCost = table[k - 1][i].second;

			table[k][i].first = table[k - 1][midNode].first;
			table[k][i].second = table[k - 1][midNode].second + midCost;
		}
	}
}

 

이 문제는 희소테이블을 응용한 것이어서 cost의 개념이 추가가 된 것이다. 보통은 노드만 저장해준다. 하지만 아래의 사진을 참고하면 응용을 이해하는 것이 그리 어렵진 않을 것으로 생각한다.

그리고 테이블은 다음과 같이 초기화된다.

pair<int, int> table[MAX_LOG][100001];

이 때 MAX_LOG는 log2(100000)을 한 값으로, 16이 들어가게 된다.

 

 

이렇게 테이블을 구성했으면 2의 배수만큼 이동했을 때 어디에 도달하게 되며, 그 때의 비용은 얼마나 드는지를 알 수 있다. 그럼 특정 노드에서 100000번 이동했을 때는 어떻게 해야하는가? 이 때는 값을 2의 배수로 나누어주면 된다.

100000 = 65536 + 32768 + 1024 + 512 + 128 + 32

이 문제의 최댓값을 넣어줬음에도 불구하고 테이블을 6번만 참조하면 된다!

 

node에서 k번 이동했을 때의 노드와 비용을 반환해주는 함수는 다음과 같다.

pair<int, int> moveAndGetPair(int node, int k) {
	int curNode = node;
	int curCost = 0;
	int i;
	for (i = MAX_LOG; i >= 0; i--) {
		if ((k & (1 << i)) != 0) {
			curCost += table[i][curNode].second;
			curNode = table[i][curNode].first;
		}
	}

	return {curNode, curCost};
}

 

 

이제, moveAndGetPair() 함수를 활용만 해주면 된다.

solution() 함수에서 이를 활용해주는데, 여기에 설명을 덧붙이자면 이진탐색에서 bisect_right 방식을 이용하여, '비용이 개미가 지닌 에너지 이하이면서, 가장 많이 이동하여 도달할 수 있는 노드'를 찾게 된다.

 

#include <bits/stdc++.h>
#define MAX_LOG 16 //log2(100000)
using namespace std;

int n;
int energy[100001];
vector<pair<int, int>> graph[100001];
pair<int, int> table[MAX_LOG][100001];
int level[100001];

void setGraphSparseAndLevel() {
	level[1] = 0;

	queue<int> q;
	q.push(1);

	int i, cur, next, curLevel = 1;
	while (!q.empty()) {
		queue<int> nq;

		while (!q.empty()) {
			cur = q.front();
			q.pop();

			vector<int> toErase;

			for (i = 0; i < graph[cur].size(); i++) {
				next = graph[cur].at(i).first;

				if (level[next] == 0 && next != 1) {
					nq.push(next);
					level[next] = curLevel;

					toErase.push_back(i);
				}
			}

			// 1부터 시작하는 edges들을 모두 지움
			for (i = toErase.size()-1; i >= 0; i--) {
				graph[cur].erase(graph[cur].begin() + toErase.at(i));
			}
		}

		q = nq;
		curLevel++;
	}

	graph[1].push_back({ 1,0 });
}

void setTable() {
	int i;
	for (i = 1; i < n + 1; i++) {
		table[0][i] = graph[i].at(0);
	}

	int k;
	for (k = 1; k < MAX_LOG; k++) {
		for (i = 1; i < n + 1; i++) {
			int midNode = table[k - 1][i].first;
			int midCost = table[k - 1][i].second;

			table[k][i].first = table[k - 1][midNode].first;
			table[k][i].second = table[k - 1][midNode].second + midCost;
		}
	}
}

pair<int, int> moveAndGetPair(int node, int k) {
	int curNode = node;
	int curCost = 0;
	int i;
	for (i = MAX_LOG; i >= 0; i--) {
		if ((k & (1 << i)) != 0) {
			curCost += table[i][curNode].second;
			curNode = table[i][curNode].first;
		}
	}

	return {curNode, curCost};
}

void solution() {
	setGraphSparseAndLevel();
	setTable();

	int i;
	for (i = 1; i < n + 1; i++) {
		int left = 0;
		int right = level[i];

		while (left <= right) {
			int mid = (left + right) / 2;

			pair<int, int> pair = moveAndGetPair(i, mid);
			if (pair.second <= energy[i]) {
				left = mid + 1;
			}
			else {
				right = mid - 1;
			}
		}

		cout << moveAndGetPair(i, right).first << "\n";
	}
}

int main()
{
	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

	cin >> n;
	
	int i;
	for (i = 1; i < n+1; i++) {
		cin >> energy[i];
	}

	int a, b, c;
	for (i = 0; i < n - 1; i++) {
		cin >> a >> b >> c;
		graph[a].push_back({ b, c });
		graph[b].push_back({ a, c });
	}

	solution();

	return 0;
}