Hanbit the Developer
[Python] 백준 2887번: 행성 터널 본문
https://www.acmicpc.net/problem/2887
Union-find를 이용한 Kruskal 알고리즘(https://rccode.tistory.com/entry/Python-%EB%B0%B1%EC%A4%80-1197%EB%B2%88-%EC%B5%9C%EC%86%8C-%EC%8A%A4%ED%8C%A8%EB%8B%9D-%ED%8A%B8%EB%A6%AC)을 알고 있다는 가정 하에 작성하였다.
이 문제는 유니온 파인드를 조금 응용한 문제이다. 시간 제한이 1초인데, N의 상한이 100,000이라는 점에서 유니온 파인드를 최적화해야한다는 것을 눈치채야한다.
먼저 평범한 방식으로 구현했을 때의 get_info() 함수이다. info는 크루스칼에서 edge의 정보를 담당한다.
def get_info(positions, N):
info = []
for i in range(N):
for j in range(i+1, N):
info.append((get_distance(positions[i], positions[j]), i, j))
info.sort(key=lambda x: x[0])
return info
정말 전형적인 코드이며, 이렇게 풀면 무조건 시간초과가 난다.
비용이 min(|xA-xB|, |yA-yB|, |zA-zB|)이라는 점에서 시작된 발상인데, 각 좌표별로 정렬하여 인접한 행성을 간선에 추가하는 것이다.
예제 입력 1을 좌표별로 정렬하면 다음과 같은 결과가 나온다.x: [[-1, -1, -5], [10, -4, -1], [11, -15, -15], [14, -5, -15], [19, -4, 19]]
y: [[11, -15, -15], [14, -5, -15], [10, -4, -1], [19, -4, 19], [-1, -1, -5]]
z: [[11, -15, -15], [14, -5, -15], [-1, -1, -5], [10, -4, -1], [19, -4, 19]]
이것의 결과 코드는 다음과 같다.
def get_info():
info = []
for k in range(3):
positions.sort(key=lambda x: x[k])
for i in range(1, N):
info.append([positions[i], positions[i-1],
abs(positions[i][k] - positions[i-1][k])])
info.sort(key=lambda x: x[2])
return info
나머지는 쉽다. 크루스칼을 진행해주면 된다!
import sys
input = sys.stdin.readline
def get_info():
info = []
for k in range(3):
positions.sort(key=lambda x: x[k])
for i in range(1, N):
info.append([positions[i], positions[i-1],
abs(positions[i][k] - positions[i-1][k])])
info.sort(key=lambda x: x[2])
return info
def find_parent(x):
if parents[x] == x:
return x
parents[x] = find_parent(parents[x])
return parents[x]
def union(x, y):
px, py = find_parent(x), find_parent(y)
if sum(px) < sum(py):
parents[py] = px
else:
parents[px] = py
def get_total_cost():
info = get_info()
total_cost = 0
linked = 0
for pos1, pos2, cost in info:
if find_parent(tuple(pos1)) != find_parent(tuple(pos2)):
union(tuple(pos1), tuple(pos2))
total_cost += cost
linked += 1
if linked >= N-1:
break
return total_cost
if __name__ == '__main__':
# Input
N = int(input())
positions = [list(map(int, input().split())) for _ in range(N)]
# For union-find
parents = {tuple(pos): tuple(pos) for pos in positions}
# Solution
total_cost = get_total_cost()
print(total_cost)
'Algorithm > 백준' 카테고리의 다른 글
[Python] 백준 9527번: 1의 개수 세기 (0) | 2021.09.17 |
---|---|
[Java] 백준 2565번: 전깃줄 (0) | 2021.09.15 |
[C] 백준 2342번: Dance Dance Revolution (0) | 2021.09.13 |
[Python] 백준 13925번: 수열과 쿼리 13(다이아 문제 첫 성공) (0) | 2021.09.11 |
[Python] 백준 2162번: 선분 그룹 (0) | 2021.09.10 |