PS

[Python] 백준 4008 - 특공대

kkigon 2025. 3. 24. 17:15

1부터 n까지 번호가 붙여진 n명의 병사들로 이루어진 군대의 지휘관이 있다. 이 지휘관은 앞으로의 전투를 위하여 n명의 병사들을 여러 개의 특공대로 나누고자 한다. 결속력과 사기를 높이기 위하여 각 특공대는 {i, i+1, ..., i+k}형태의 번호가 연속하는 병사들로 구성된다.

각 병사 i의 전투력은 xi이다. 병사들 {i, i+1, ..., i+k}로 구성된 특공대의 전투력 x는 원래는 각 병사의 전투력의 합으로 계산되었다. 달리 말하면 x = xi + xi+1 + ... + xk이었다.

그러나 여러 해의 영광스러운 승리를 통하여 특공대의 전투력을 다음과 같이 조정해야 하는 것으로 결론을 내렸다: 특공대의 조정된 전투력 x′는 등식 x′ = ax^2 + bx + c로 계산한다. 여기서 a, b, c는 알려져 있는 계수들로서 a<0이고, x는 특공대의 원래 정의된 전투력이다.

여러분이 할 일은 모든 특공대의 조정된 전투력의 합을 최대화하도록 병사들을 특공대로 나누는 것이다.

예를 들어, 4명의 병사들이 있고, 각 병사의 전투력 x1 = 2, x2 = 2, x3 = 3, x4 = 4라 하자. 특공대의 조정된 전투력 등식에 있는 계수가 a=-1, b=10, c=-20이라 하자. 이러한 경우, 최적인 해는 병사들을 다음과 같이 세 개의 특공대로 나누는 것이다: 첫 번째 특공대는 병사 1과 2로 구성하고, 두 번째 특공대는 병사 3으로 구성하고, 세 번째 특공대는 병사 4로 구성한다. 이들 세 특공대의 원래의 전투력은 각각 4, 3, 4이고 조정된 전투력은 각각 4, 1, 4이다. 이렇게 나눌 때 조정된 전체 전투력은 각 특공대의 조정된 전투력의 합인 9이며, 이보다 더 좋은 해가 없음을 알 수 있다.

입력

입력은 세 줄로 구성된다. 첫 번째 줄에 전체 병사들 수인 양의 정수 n이 주어진다. 두 번째 줄에 특공대의 조정된 전투력 계산 등식의 계수인 세 정수 a, b, c가 주어진다. 마지막 줄에 병사들 1, 2, ..., n의 전투력을 나타내는 n개의 정수 x1, x2, ..., xn이 공백을 사이에 두고 주어진다.

n ≤ 1,000,000, -5 ≤ a ≤ -1, |b| ≤ 10,000,000, |c| ≤ 30,000,000, 1 ≤ xi ≤ 100

출력

얻을 수 있는 최대의 조정된 전체 전투력을 나타내는 하나의 정수를 한 줄에 출력한다.

 


 

풀이

Convex Hull Trick의 바이블과도 같은 문제라고 한다.

CHT라는건 O(n^2)의 dp로 풀 수도 있다는 것.

 

dp[i]에 i번째 병사까지 계산을 진행하였을 때의 전투력의 최대값을 저장하면 될 것이다.

 

그러면 다음과 같이 이중 포문을 사용하여 쉽게 코딩을 할 수 있다.

 

N = int(input())
a, b, c = map(int, input().split())
arr = list(map(int, input().split()))
S = [0]
for i in range(N):
    S.append(S[-1] + arr[i])

dp = [0]*(N+1)

#원래는 이렇게 O(n**2) 안에 짤 수 있음
for i in range(1, N+1):
    tmp = -1298312798123
    for j in range(i+1):
        x = S[i] - S[j]
        tmp = max(tmp, dp[j] + a*x**2 + b*x + c)
    dp[i] = tmp
print(dp)

 

이 문제에서는 연속된 수들의 합을 빠르게 계산해야하므로

누적합 배열을 이용하였다.

 

당연히 시간초과가 나는 풀이이다. n이 백만이므로 우리는 CHT를 이용해서 시간복잡도는 O(N)까지 줄여야 할 필요가 있다.

보통 CHT에서 이분탐색을 해서 시간복잡도는 O(N log N)이 되어야하는 것 아닌가 싶은데 O(N)까지 줄어들 수 있는 원리는 이후에 설명하겠다.

 


CHT 적용을 위한 관찰

 

dp값을 업데이트 하는 것을 잘 보자.

x = S[i] - S[j]
tmp = max(tmp, dp[j] + a*x**2 + b*x + c)

 

분명 CHT는 여러 개의 일차함수에서 적용되는데 여기서는 이차식이 눈에 들어온다.

엥?

 

여기서 막힐 수 있지만 한 번 x를 직접 대입해보자.

 

복잡한 계산

약간의 복잡한 대입연산을 거치면 max 함수 안에는 일차함수 꼴만 남는다! 이게 잘 안보인다.

 

쿼리 X는 S[i]로 놓으면 되는 것이다.

 

여기서 시간복잡도에 대한 질문을 해결 할 수 있는데, S는 누적 합이므로 S[i]는 i가 증가함에 따라 계속 증가한다.

즉, 굳이 이분탐색을 해줄 필요가 없다!

그래서 함수들을 관리해주는 O(N)만 들게 되는 것이다.

 

누적함수를 잡아주었으므로 인덱스는 1 늘게 된다.

따라서 dp는 이렇게 잡는다.

 

dp = [0]*(N+1)

 

그리고 늘 그랬듯이 일차함수의 교차점을 구하는 함수를 정의한다.

 

# 일차함수의 교차점 x좌표 찾는 함수
def cross(f, g):
    return (g[1] - f[1])/(f[0] - g[0])

 

이제 함수들을 관리할 stack을 deque를 이용하여 만든다. 처음에는 초기함수 [0, 0, 0]이 들어있다.

# 함수는 [기울기, y절편, 시작점의 x좌표]
stack = deque([[0, 0, 0]])  # f들 관리하는 스택

 

그리고, i는 1부터 N+1까지 다음을 반복한다.

 

1. 쿼리 처리

2. 새로운 함수 추가

 

쿼리를 처리할 때 약간 고려해주어야 할 점이 있는데, 바로 이 문제의 메모리 제한이 64MB로 매우 작다는 것이다.

그렇기 때문에, 계속 증가하는 S[i]의 특징을 이용하여, 더이상 필요가 없는 함수들은 스택에서 pop해줌으로써 메모리를 아끼는 방법을 쓰는 것이다.

 

for i in range(1, N+1):
    X = S[i]
    while len(stack) >= 2 and stack[1][2] < X:
        stack.popleft()

    stack_index = 0
    dp[i] = X * stack[stack_index][0] + stack[stack_index][1] + a*X**2 + b*X + c

 

필요가 없는 함수들을 다 빼주고 다면 쿼리에 해당되는 함수는 스택의 첫 번째 원소가 될 것이라는 느낌이다.

 

새로운 함수를 추가하는 과정은 다른 CHT 문제들과 다를 것이 없다.

# 새로운 함수를 추가해보자.
f = [-2 * a * S[i], a * S[i] ** 2 - b * S[i] + dp[i], 0]
tmp = cross(f, stack[-1])
f[2] = tmp
stack.append(f)

# 그런데 쓸모없는 함수가 생길지도 모르니까
while len(stack) >= 3 and cross(stack[-3], stack[-2]) > cross(stack[-2], stack[-1]):
    tmp = stack.pop()
    stack.pop()
    stack.append(tmp)
stack[-1][2] = cross(stack[-1], stack[-2])

 

이제 전체 코드이다.

 

from collections import deque

N = int(input())
a, b, c = map(int, input().split())
arr = list(map(int, input().split()))
S = [0]
for i in range(N):
    S.append(S[-1] + arr[i])

dp = [0]*(N+1)

# 일차함수의 교차점 x좌표 찾는 함수
def cross(f, g):
    return (g[1] - f[1])/(f[0] - g[0])

# 함수는 [기울기, y절편, 시작점의 x좌표]
stack = deque([[0, 0, 0]])  # f들 관리하는 스택

for i in range(1, N+1):
    X = S[i]
    while len(stack) >= 2 and stack[1][2] < X:
        stack.popleft()

    stack_index = 0
    dp[i] = X * stack[stack_index][0] + stack[stack_index][1] + a*X**2 + b*X + c
    # 새로운 함수를 추가해보자.
    f = [-2 * a * S[i], a * S[i] ** 2 - b * S[i] + dp[i], 0]
    tmp = cross(f, stack[-1])
    f[2] = tmp
    stack.append(f)

    # 그런데 쓸모없는 함수가 생길지도 모르니까
    while len(stack) >= 3 and cross(stack[-3], stack[-2]) > cross(stack[-2], stack[-1]):
        tmp = stack.pop()
        stack.pop()
        stack.append(tmp)
    stack[-1][2] = cross(stack[-1], stack[-2])

print(dp[-1])

'PS' 카테고리의 다른 글

[Python] 백준 14751 - Leftmost Segment  (0) 2025.03.24
[Python] 백준 13263 - 나무 자르기  (0) 2025.03.24
[Python] 백준 9892 - Width  (0) 2025.03.23
[Python] 백준 10254 - 고속도로  (0) 2025.03.23
[Python] 백준 1708 - 볼록 껍질  (0) 2025.03.23