Hanbit the Developer

[Python] 백준 10830번: 행렬 제곱 본문

Algorithm/백준

[Python] 백준 10830번: 행렬 제곱

hanbikan 2021. 7. 21. 12:23

https://www.acmicpc.net/problem/10830

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net

 

import sys
input = sys.stdin.readline


def getNSquaredMatrix(n):
    # Base Case
    if n == 1:
        return matrix

    # Recursive Case
    halfSquared = getNSquaredMatrix(n//2)
    if n % 2 == 0:
        return squareMatrix(halfSquared, halfSquared)
    else:
        return squareMatrix(squareMatrix(halfSquared, halfSquared), matrix)


def squareMatrix(matrixA, matrixB):
    return [[squareElement(i, j, matrixA, matrixB) for j in range(N)] for i in range(N)]


def squareElement(x, y, matrixA, matrixB):
    return sum(matrixA[x][i] * matrixB[i][y] for i in range(N)) % 1000


if __name__ == '__main__':
    # 입력
    N, B = map(int, input().split())
    matrix = [list(map(int, input().split())) for _ in range(N)]

    # 처리
    nSquaredMatrix = getNSquaredMatrix(B) if B >= 2 else [
        [matrix[i][j] % 1000 for j in range(N)] for i in range(N)]

    # 출력
    for row in nSquaredMatrix:
        print(*row)

 

먼저 행렬을 곱하는 것은, 모든 원소들을 순차적으로 곱하여 더하는 과정을 거친다. 아래의 행렬을 곱한다고 했을 때, 6에 해당하는 원소의 곱은 사진과 같이 행해진다.

이렇게 모든 원소에 대해 계산해주는 것이 행렬의 곱셈이다. 위의 과정을 마치면 아래와 같이 된다.

 

다음으로, 행렬 제곱을 이해하기 위해 예제 입력 2를 예시로 설명하겠다. 해당 행렬의 세제곱을 구한다는 것은 다음 사진의 연산을 의미한다.

이를 전개하면 다음과 같다.

이것을 곱하여 최종적으로 예제 출력 2와 같이 되는 것이다.

 

우선 두 행렬을 제곱하는 함수, squareMatrix()를 작성한다. 이것을 위해서 두 행렬의 특정 원소의 곱을 얻어오는 squareElement() 함수가 필요하다. 이 두 함수는 모두, 앞서 설명한 개념을 기반으로 짤 수 있다. 다만, 원소값을 구할 때 1000으로 나눈 나머지를 가져와야한다.

 

다음으로, B의 값의 범위가 너무 크므로 일일히 곱해서는 시간복잡도로 문제를 pass할 수 없다. 따라서 우리는 분할정복을 통해 문제를 해결한다.

가령, B가 6일 때는 아래와 같이 진행된다.

6제곱을 구하기 위해선 3제곱 * 3제곱을 해야한다. 여기서 중요한 것은 3제곱을 재귀를 통해 한 번 구했다는 것은, 3제곱의 메트릭스를 알게 되었다는 것이다. 따라서 3제곱을 구하는 행동을 1번만 수행하고, 그 결과인 세제곱 행렬(halfSquared)에 대해 squareMatrix()를 수행해주면 된다는 것이다.

이것을 기반으로 코드를 작성하는데, base case는 n이 1일 때이고, n이 짝수일 때와 홀수일 때를 구분하는 것만 지키면 된다.

 

마지막으로 반례를 대비해야한다. 아래 코드를 보자.

    # 처리
    nSquaredMatrix = getNSquaredMatrix(B) if B >= 2
    	else [[matrix[i][j] % 1000 for j in range(N)] for i in range(N)]

B가 1일 때, 행렬을 그대로 복사하되 1000으로 나눈 나머지를 넣는 것을 알 수 있다. getNSquaredMatrix()의 base case 덕분에 정상적인 값을 얻을 수 있음에도 불구하고 이렇게 하는 것은, 아래의 반례 때문이다.

2 1
1000 1000
1000 1000

또, B가 2 이상일 때는 squareMatrix를 1번이라도 수행하므로 1000으로 나누는 연산이 수행되기 때문에 B가 1일 때만 신경써주면 된다.