Hanbit the Developer

[Python] 백준 2887번: 행성 터널 본문

Algorithm/백준

[Python] 백준 2887번: 행성 터널

hanbikan 2021. 9. 14. 11:48

https://www.acmicpc.net/problem/2887

 

2887번: 행성 터널

첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이

www.acmicpc.net

 

Union-find를 이용한 Kruskal 알고리즘(https://rccode.tistory.com/entry/Python-%EB%B0%B1%EC%A4%80-1197%EB%B2%88-%EC%B5%9C%EC%86%8C-%EC%8A%A4%ED%8C%A8%EB%8B%9D-%ED%8A%B8%EB%A6%AC)을 알고 있다는 가정 하에 작성하였다.

 

이 문제는 유니온 파인드를 조금 응용한 문제이다. 시간 제한이 1초인데, N의 상한이 100,000이라는 점에서 유니온 파인드를 최적화해야한다는 것을 눈치채야한다.

 

먼저 평범한 방식으로 구현했을 때의 get_info() 함수이다. info는 크루스칼에서 edge의 정보를 담당한다.

def get_info(positions, N):
    info = []

    for i in range(N):
        for j in range(i+1, N):
            info.append((get_distance(positions[i], positions[j]), i, j))
    info.sort(key=lambda x: x[0])

    return info

정말 전형적인 코드이며, 이렇게 풀면 무조건 시간초과가 난다.

 

비용이 min(|xA-xB|, |yA-yB|, |zA-zB|)이라는 점에서 시작된 발상인데, 각 좌표별로 정렬하여 인접한 행성을 간선에 추가하는 것이다.

예제 입력 1을 좌표별로 정렬하면 다음과 같은 결과가 나온다.x: [[-1, -1, -5], [10, -4, -1], [11, -15, -15], [14, -5, -15], [19, -4, 19]]

y: [[11, -15, -15], [14, -5, -15], [10, -4, -1], [19, -4, 19], [-1, -1, -5]]

z: [[11, -15, -15], [14, -5, -15], [-1, -1, -5], [10, -4, -1], [19, -4, 19]]

 

이것의 결과 코드는 다음과 같다.

 

def get_info():
    info = []

    for k in range(3):
        positions.sort(key=lambda x: x[k])

        for i in range(1, N):
            info.append([positions[i], positions[i-1],
                         abs(positions[i][k] - positions[i-1][k])])

    info.sort(key=lambda x: x[2])

    return info

나머지는 쉽다. 크루스칼을 진행해주면 된다!

 

import sys
input = sys.stdin.readline


def get_info():
    info = []

    for k in range(3):
        positions.sort(key=lambda x: x[k])

        for i in range(1, N):
            info.append([positions[i], positions[i-1],
                         abs(positions[i][k] - positions[i-1][k])])

    info.sort(key=lambda x: x[2])

    return info


def find_parent(x):
    if parents[x] == x:
        return x

    parents[x] = find_parent(parents[x])
    return parents[x]


def union(x, y):
    px, py = find_parent(x), find_parent(y)

    if sum(px) < sum(py):
        parents[py] = px
    else:
        parents[px] = py


def get_total_cost():
    info = get_info()
    total_cost = 0
    linked = 0

    for pos1, pos2, cost in info:
        if find_parent(tuple(pos1)) != find_parent(tuple(pos2)):
            union(tuple(pos1), tuple(pos2))
            total_cost += cost

            linked += 1
            if linked >= N-1:
                break

    return total_cost


if __name__ == '__main__':
    # Input
    N = int(input())
    positions = [list(map(int, input().split())) for _ in range(N)]

    # For union-find
    parents = {tuple(pos): tuple(pos) for pos in positions}

    # Solution
    total_cost = get_total_cost()
    print(total_cost)