본문 바로가기
Problem Solving/백준

백준 13511번 | 트리와 쿼리 2 (C++ 풀이)

by kadokok 2022. 10. 2.

문제

https://www.acmicpc.net/problem/13511

 

13511번: 트리와 쿼리 2

N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 아래의 두 쿼리를 수행하

www.acmicpc.net


풀이

  • 사용 알고리즘 : LCA, Sparse Table
  • 시간 복잡도 : O(M * logN)

LCA를 이용하여 풀 수 있는 문제입니다.

주의점은 N, M의 상한이 10만이므로, Naive LCA는 O(NM)이 되어 TLE를 피할 수 없습니다.
따라서 Sparse Table을 정의해서 LCA를 O(M logN)으로 최적화해주면 됩니다.
이제 문제의 두 가지 쿼리에 대해 어떤 방식으로 처리할지를 보겠습니다.

u에서 v로 가는 경로의 비용

기본적으로 정점들의 거리(https://www.acmicpc.net/problem/1761) 문제와 같은 아이디어입니다.
우리는 u에서 v로 가는 경로를 u → LCA(u, v) → v 라는 경로로 정의할 수 있습니다.
그다음은 u → LCA(u, v) → v 경로의 거리를 구해줄 것인데, 다음과 같은 배열을 이용해서 거리를 쉽게 구할 수 있습니다.

dist[a] = b; // 루트 노드에서 a 노드까지의 거리는 b 이다.

 

그림을 참고해보시면, 결과적으로 u → v 거리의 값은 dist[u] + dist[v] - 2 * dist[LCA(u, v)] 가 된다는 것을 아실 수 있습니다.

u에서 v로 가는 경로 중 k번째 정점

여기서도 마찬가지로 u에서 v로 가는 경로를 u → LCA(u, v) → v 경로로 정의해 줍시다.
그 후에 u부터 LCA(u, v)까지의 정점의 개수 cnt 를 구할 수 있다면 다음과 같이 k번째 정점을 찾을 수 있습니다.

① k == cnt 인 경우, k번째 정점은 LCA(u, v)가 됩니다.

② k < cnt 인 경우, k번째 정점은 u → LCA(u, v) 경로 안에 존재한다는 점을 알 수 있습니다. 그렇다면 기존에 만든 sparse table을 이용하여 u의 k번째 부모를 log2n번 안에 찾아낼 수 있습니다.

③ k > cnt 인 경우, k번째 정점은 v → LCA(u, v) 경로 안에 존재한다는 점을 알 수 있습니다. 그렇다면 이번에는 v의 k번째 부모를 위와 같은 방식으로 log2n번 안에 찾아낼 수 있습니다.

이때 u부터 LCA(u, v)까지의 정점의 개수는 각 노드의 깊이를 알면 구할 수 있습니다. depth[u] - depth[LCA(u, v)] + 1과 같이 구하면 됩니다.


코드

#include <bits/stdc++.h>
using namespace std;
const int INF = 2e9;
const int inf = 1e9;

int parent[100001][18]; // 노드 i의 2^j 번째 부모
int depth[100001];
bool visited[100001];
long long dist[100001];
vector<pair<int, int>> adj[100001];

void go(int node) {
	visited[node] = true;
	for (int i = 0; i < adj[node].size(); i++) {
		int next = adj[node][i].first;
		if (visited[next]) continue;
		parent[next][0] = node;
		depth[next] = depth[node] + 1;
		dist[next] = dist[node] + adj[node][i].second;
		go(next);
	}
}

int lca(int a, int b) {
	if (depth[a] < depth[b]) swap(a, b);
	int diff = depth[a] - depth[b];
	for (int i = 0; i < 18; i++) {
		if (diff & 1 << i) {
			diff -= 1 << i;
			a = parent[a][i];
		}
	}
	if (a != b) {
		for (int i = 17; i >= 0; i--) {
			if (parent[a][i] != -1 && parent[a][i] != parent[b][i]) {
				a = parent[a][i];
				b = parent[b][i];
			}
		}
		a = parent[a][0];
	}
	return a;
}

int main(void) {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);

	int n;
	cin >> n;
	for (int i = 0; i < n - 1; i++) {
		int a, b, c;
		cin >> a >> b >> c;
		adj[a].push_back({ b, c });
		adj[b].push_back({ a, c });
	}
	go(1);
	for (int i = 1; i < 18; i++) {
		for (int j = 2; j <= n; j++) {
			parent[j][i] = parent[parent[j][i - 1]][i - 1];
		}
	}
	int m;
	cin >> m;
	while (m--) {
		int a, u, v, k;
		cin >> a;
		if (a == 1) {
			cin >> u >> v;
			int root = lca(u, v);
			cout << dist[u] + dist[v] - 2 * dist[root] << '\n';
		}
		else {
			cin >> u >> v >> k;
			int root = lca(u, v);
			int cnt = depth[u] - depth[root] + 1;
			if (cnt == k) cout << root << '\n';
			else if (cnt > k) {
				k--;
				int tmp = u;
				for (int i = 0; i < 18; i++) {
					if (k & 1 << i) {
						k -= 1 << i;
						tmp = parent[tmp][i];
					}
				}
				cout << tmp << '\n';
			}
			else {
				k = cnt + depth[v] - depth[root] - k + 1;
				k--;
				int tmp = v;
				for (int i = 0; i < 18; i++) {
					if (k & 1 << i) {
						k -= 1 << i;
						tmp = parent[tmp][i];
					}
				}
				cout << tmp << '\n';
			}
		}
	}
	return 0;
}

댓글