Hanbit the Developer

CHT(Convex Hull Trick) 알고리즘 본문

Algorithm

CHT(Convex Hull Trick) 알고리즘

hanbikan 2024. 7. 13. 18:15

정의

CHT는 특정 조건의 DP 점화식에 활용할 수 있는 최적화 기법입니다.

 

[조건]

1. dp[i] = min(a[j] * b[i] + dp[j]) (0 <= j < i)

2. 수열 a가 단조감소 한다.

 

dp[N]은 이중 포문을 돌아야 구할 수 있으므로 시간 복잡도는 O(N^2)입니다. 하지만 CHT를 활용하면 O(NlogN)으로 최적화를 할 수 있게 됩니다.

 

접근 방식

CHT에서는 dp[i]를 구할 때 '일차 함수'를 활용합니다. 아래 표에서 파란색 식을 보면 a[0] * b[i] + dp[0]의 꼴을 하고 있습니다. 여기서 b[i]를 x로 치환을 하면 a[0] * x + dp[0] 형태가 됩니다. 즉 j = 0일 때의 함수 f0 a[0] * x + dp[0]입니다. j = 1일 때의 함수 f1은 a[1] * x + dp[1]입니다.

이 함수를 통해 dp[i]의 값을 구할 수 있습니다. 가령 dp[2]를 구할 때는 f0, f1에 b[2]를 x값으로 넣은 뒤 최소값을 구하면 됩니다.

일반화하면 dp[i]를 얻기 위해선 일차함수 f0, f1, ... fi-1에 x = b[i]를 넣은 뒤 최소값을 구해주면 됩니다.

dp[i] candidates
dp[0] 0 (j = 0)
dp[1] a[0] * b[1] + dp[0]  (j = 0)
dp[2] a[0] * b[2] + dp[0]  (j = 0)
  a[1] * b[2] + dp[1]  (j = 1)
dp[3] a[0] * b[3] + dp[0]  (j = 0)
  a[1] * b[3] + dp[1]  (j = 1)
  a[2] * b[3] + dp[2]  (j = 2)


스택: 함수 저장소

코드를 어떻게 구현해야 할지를 생각해보면 함수를 저장하는 공간이 필요해보입니다. CHT는 이 저장 공간으로 스택을 활용합니다. 코드는 대략 다음과 같습니다.

stack = deque() # Python에서 stack의 역할을 함

for i in range(1, N): # 1 <= i < N
    fi = (a[i - 1], dp[i - 1]) # 일차함수 3x + 7를 (3, 7)로 표현할 수 있음
    stack.append(fi) # 스택에 이번 함수 추가
    x = b[i]
    for 0 <= j < i: # 지금까지의 함수 싹 다 돌면서 x 대입해서 최소값 찾음
        dp[i] = min(dp[i], stack[j].a * x + stack[j].b)

함수 자르기

기존에는 dp[3]을 구할 때, 아래 사진처럼 f0, f1, f2에 x = b[3]를 넣어서 가장 작은 값인 초록색 함수에서의 값을 dp[3]로 취급하게 됩니다.

하지만 아래 사진과 같이 일차 함수의 아랫 부분만을 취급하게 된다면, 여러 값을 비교할 필요 없이 곧바로 최소값을 얻을 수 있습니다.

이를 어떻게 구현할지를 생각해봅시다. 위 사진을 보면, 각 함수가 만나는 교차점을 기준으로 최소값을 갖게 되는 함수가 변경됩니다. x = 0 ~ 8에서는 f0, x = 8~16에서는 f1, x = 16~에서는 f2가 항상 최소값을 갖게 됩니다. 따라서 교차점을 기준으로 함수를 자르면 됩니다.

f1 = 0.5x + 4일 때 이 함수를 [0.5, 4, 8]과 같이 표현할 수 있습니다. 0.5는 기울기, 4는 y절편, 8은 f1이 최소가 되는 시작점의 x 좌표입니다.

*교차점 구하는 법: (f2.b - f1.b) / (f1.a - f2.a)

 

이해를 위해 함수를 추가하는 과정을 순차적으로 설명하겠습니다.

[i = 1일 때]

[i = 2일 때]

- 교차점 x 좌표가 8이므로 f1는 x = 8부터 시작한다.

[i = 3일 때]

- 추가하고자 하는 함수 f2스택에서 가장 위에 있는(가장 최근에 추가한) 함수 f1과 교차점을 구한다.

- 교차점 x 좌표가 16이므로 f2는 x = 16부터 시작한다.

 

함수를 스택에 넣는 것을 코드로 구현하면 다음과 같습니다.

f_to_add = [a[i - 1], dp[i - 1], 0] # 함수 정의. start_x는 4번째 줄에서 수정할 예정
top_f = stack[top] # 가장 최근 함수 가져옴
intersection_x = get_intersection_x(f_to_add, top_f) # 교차점 구함
f_to_add[2] = intersection_x # 이번에 추가할 함수의 시작점(x)은 그 교차점의 x 좌표임
stack.append(f_to_add) # 스택에 넣음

스택에서 최소값을 갖는 함수 찾기

이제 함수 여러 개가 스택에 있기 때문에 dp를 구하기만 하면 됩니다.

dp[3]을 구하고자 하는 상황에서 b[3] = 14라고 하면 이에 해당하는 최소 함수는 f1입니다. f1x = 8 ~ 16에서 최소값을 갖기 때문입니다.

스택에서 최소값을 갖는 함수를 찾으려면, 스택에서 start_x를 기준으로 이진 탐색을 하면 됩니다. 즉, [0, 8, 16]에서 x = b[3] = 14로 이진탐색 해서 인덱스 1을 얻으면 됩니다.

코드는 다음과 같습니다.(파이썬에선 bisect_left를 해준 뒤 값을 1 빼주면 해당 인덱스를 더욱 편하게 구할 수 있습니다.)

x = b[i]
l, r = 0, len(stack) - 1
while l <= r:
    m = (l + r) // 2
    if stack[m].start_x >= x:
    	r = m - 1
    else:
    	l = m + 1
stack_index = l - 1

중간 정리(코드)

지금까지의 내용을 코드로 작성하면 다음과 같습니다. 생각보다 간단하다는 걸 알 수 있습니다.

A, B, START_X = 0, 1, 2

def get_intersection_x(f1, f2):
    return (f2[B] - f1[B]) / (f1[A] - f2[A])

dp = [0] * n
stack = []
for i in range(1, N): # 1 <= i < N
    f_to_add = [a[i - 1], dp[i - 1], 0] # dp[i] = min(a[j] * b[i] + dp[j])

    # f_to_add의 시작점 구하기(교차점)
    top_f = stack[top] # 가장 최근 함수 가져옴
    intersection_x = get_intersection_x(f_to_add, top_f) # 교차점 구함
    f_to_add[2] = intersection_x # 이번에 추가할 함수의 시작점(x)은 그 교차점의 x 좌표임
    stack.append(f_to_add) # 스택에 넣음

    # x에 맞는 함수 찾기(이진 탐색)
    x = b[i]
    l, r = 0, len(stack) - 1
    while l <= r:
        m = (l + r) // 2
        if stack[m].start_x >= x:
        	r = m - 1
        else:
        	l = m + 1
    stack_index = l - 1

    dp[i] = stack[stack_index][A] * x + stack[stack_index][B]

스택에 함수를 추가할 때 발생하는 예외 처리

마지막으로 예외가 하나 있어서 이것만 처리해주면 구현이 끝입니다. 함수를 스택에 저장해주는 이유도 이 부분 때문입니다.

네 번째 함수 f3를 추가했는데 아래처럼 되었다고 가정하겠습니다. 이렇게 되면 f2가 의미가 없어지고 처리하기도 곤란해집니다. 만약 이대로 함수를 추가하게 되면, 각 함수들의 start_x는 [0, 8, 16, 13.6]가 됩니다.(13.6은 f3f2가 만나는 교차점 x) 정렬된 상태가 아니므로 이진탐색을 통한 최소 함수 찾기도 못하게 됩니다.

다음과 같은 과정을 통해 이 문제를 해결할 수 있습니다.

1. 추가할 함수(f3)와 스택 맨 위 함수(f2)의 교차점을 구한다.

2-1. 교차점이 스택 맨 위 함수(f2)의 start_x보다 앞에 있다면, 스택 맨 위 함수(f2)를 제거한다.(pop)

2-2. 정상적인 경우라면 반복문을 탈출한다.

3. 반복

아래 사진은 이 과정을 마친 후 스택의 모습입니다.

CHT 전체 코드

A, B, START_X = 0, 1, 2

def get_intersection_x(f1, f2):
    return (f2[B] - f1[B]) / (f1[A] - f2[A])

dp = [0] * n
stack = []
for i in range(1, N): # 1 <= i < N
    f_to_add = [a[i - 1], dp[i - 1], 0] # dp[i] = min(a[j] * b[i] + dp[j])
    # stack에서 불필요한 함수 제거
    while stack:
        top_f = stack[top]
        intersection_x = f_to_add.get_intersection_x(top_f)
        if intersection_x < top_f.start_x: # 교차점이 top_f의 시작점보다 앞에 있음
            stack.pop()
        else: # 정상
            f_to_add.start_x = intersection_x
            stack.append(f_to_add)
            break

    # x에 맞는 함수 찾기(이진 탐색)
    x = b[i]
    l, r = 0, len(stack) - 1
    while l <= r:
        m = (l + r) // 2
        if stack[m].start_x >= x:
            r = m - 1
        else:
            l = m + 1
    stack_index = l - 1

    dp[i] = stack[stack_index][A] * x + stack[stack_index][B]

 

What's more?

1. 수열 a가 단조증가하고 최대값을 구하는 경우도 적용할 수 있다.

2. 수열 b가 단조증가한다면 이진탐색 없이 O(N)으로 풀 수 있다.

 

CHT 문제 풀어보기(백준)

https://www.acmicpc.net/problemset?sort=ac_desc&algo=89