Hanbit the Developer

[Python] 백준 2042번: 구간 합 구하기 본문

Algorithm/백준

[Python] 백준 2042번: 구간 합 구하기

hanbikan 2021. 5. 13. 18:50

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

import sys
input = sys.stdin.readline


class Node:
    def __init__(self):
        self.start = None
        self.end = None
        self.left = None
        self.right = None
        self.var = None


def initializeSegmentTree(node, start, end):
    node.start = start
    node.end = end

    if start == end:  # 말단
        node.var = nums[start]
    else:
        # 자식 노드 생성하고 dfs 로직으로 초기화
        node.left = Node()
        initializeSegmentTree(node.left, start, (start+end)//2)

        node.right = Node()
        initializeSegmentTree(node.right, (start+end)//2+1, end)

        # 자식을 합함
        node.var = node.left.var + node.right.var


def modifySegmentTree(node):
    node.var += c-nums[b]  # 세그먼트 트리 값 변경

    if node.start == node.end:
        nums[b] = c  # nums도 꼭 수정을 해야함
        return

    # b가 해당 자식의 범위에 있을 시 탐색
    if node.left and node.left.start <= b <= node.left.end:
        modifySegmentTree(node.left)
    elif node.right and node.right.start <= b <= node.right.end:
        modifySegmentTree(node.right)


def setSegmentSum(node, start, end):
    global sum

    if start <= node.start and node.end <= end:
        # sum 갱신
        sum += node.var

        # 나뉘어진 새 범위에 대해서 탐색
        if not (node.start == start and node.end == end):
            if b <= node.start-1 and node.left and start <= node.left.end:
                setSegmentSum(node.left, start, node.start-1)
            if node.end+1 <= c and node.right and node.right.start <= end:
                setSegmentSum(node.right, node.end+1, end)

    else:
        if node.left and start <= node.left.end:
            setSegmentSum(node.left, start, end)
        if node.right and node.right.start <= end:
            setSegmentSum(node.right, start, end)


N, M, K = map(int, input().split())
nums = [0]
for _ in range(N):
    nums.append(int(input()))

root = Node()
initializeSegmentTree(root, 1, N)

for _ in range(M + K):
    a, b, c = map(int, input().split())

    if a == 1:
        modifySegmentTree(root)
    elif a == 2:
        sum = 0
        setSegmentSum(root, b, c)
        print(sum)

문제에서 가장 중요한 작업이 3개가 있다.

1. 세그먼트 트리 초기화

2. 세그먼트 트리 변경

3. 세그먼트 트리에서 특정 범위의 합 구하기

이 세 개의 작업을 각각 함수 하나씩, 총 3개의 함수 안에 작성하였다.

 

우선 initializeSegmentTree()이다.

크게 어렵지는 않은데, Node 클래스의 형식에 맞게 dfs로 탐색을 하면서 세그먼트 트리를 초기화 해준다.

 

다음으로 modifySegmentTree()에 대한 설명이다.

우선 첫째줄은, dfs 탐색 대상인 모든 노드에 대해 값을 변경해준다는 것이다.

그리고 말단 함수를 node.start == node.end라는 조건문으로 색출해낸다. 말단 노드에서 nums의 값을 c로 변경시켜주는데 필수적인 작업이다. 절대 세그먼트 트리만 수정해선 안 된다.

마지막으로, b가 자식의 범위에 있으면 탐색을 해준다.

 

세번째 함수, setSegmentSum()이다.

node의 범위가 start와 end의 범위 안에 속할 경우와 속하지 않을 경우를 분기한다.

전자의 경우에 수행할 연산들은 아래와 같다.

현재 세그먼트 트리의 노드의 범위가 2~3이고, start, end는 각각 1, 5이다.

찾는 범위(1~5) 안에 현재 노드(2~3)이 있으므로, 이것을 수용하고 sum을 갱신해준 뒤, 반으로 나누어 조건을 확인하고(가령, 1~4에서 노드가 2~4이면 다음 탐색은 1~1만 하면 됨) 탐색한다.

다음으로, 후자의 경우인데, 이 때는 자식 노드들의 범위를 고려하여 조건문을 걸고 탐색해주면 된다.