PS

[Python] 백준 2243 - 사탕상자

kkigon 2025. 4. 21. 17:10

수정이는 어린 동생을 달래기 위해서 사탕을 사용한다. 수정이는 평소에 여러 개의 사탕을 사서 사탕상자에 넣어두고, 동생이 말을 잘 들을 때면 그 안에서 사탕을 꺼내서 주곤 한다.

각각의 사탕은 그 맛의 좋고 나쁨이 1부터 1,000,000까지의 정수로 구분된다. 1이 가장 맛있는 사탕을 의미하며, 1,000,000은 가장 맛없는 사탕을 의미한다. 수정이는 동생이 말을 잘 들은 정도에 따라서, 사탕상자 안에 있는 사탕들 중 몇 번째로 맛있는 사탕을 꺼내주곤 한다. 예를 들어 말을 매우 잘 들었을 때에는 사탕상자에서 가장 맛있는 사탕을 꺼내주고, 말을 조금 잘 들었을 때에는 사탕상자에서 여섯 번째로 맛있는 사탕을 꺼내주는 식이다.

수정이가 보관하고 있는 사탕은 매우 많기 때문에 매번 사탕상자를 뒤져서 꺼낼 사탕을 골라내는 일은 매우 어렵다. 수정이를 도와주는 프로그램을 작성하시오.

입력

첫째 줄에 수정이가 사탕상자에 손을 댄 횟수 n(1 ≤ n ≤ 100,000)이 주어진다. 다음 n개의 줄에는 두 정수 A, B, 혹은 세 정수 A, B, C가 주어진다. A가 1인 경우는 사탕상자에서 사탕을 꺼내는 경우이다. 이때에는 한 정수만 주어지며, B는 꺼낼 사탕의 순위를 의미한다. 이 경우 사탕상자에서 한 개의 사탕이 꺼내지게 된다. 또, A가 2인 경우는 사탕을 넣는 경우이다. 이때에는 두 정수가 주어지는데, B는 넣을 사탕의 맛을 나타내는 정수이고 C는 그러한 사탕의 개수이다. C가 양수일 경우에는 사탕을 넣는 경우이고, 음수일 경우에는 빼는 경우이다. 맨 처음에는 빈 사탕상자에서 시작한다고 가정하며, 사탕의 총 개수는 2,000,000,000을 넘지 않는다. 또한 없는 사탕을 꺼내는 경우와 같은 잘못된 입력은 주어지지 않는다.

출력

A가 1인 모든 입력에 대해서, 꺼낼 사탕의 맛의 번호를 출력한다.

 


풀이

세그먼트 트리 응용 연습문제이다.

일단 어떻게 세그먼트 트리를 만들 것이냐?

사탕 맛에 따른 노드들을 그냥 맨 처음에 1,000,001개 만들어주고

각 노드에 들어있는 값은 현재 해당 사탕의 개수가 될 것이다.

그리고 이진 트리를 올라가면서 저장되는 값은 total 사탕의 개수가 될 것이다.

그러면 맨 위 노드(1번 노드)에는 현재 상자에 있는 전체 사탕의 개수가 되겠지???

 

문제에서 주어진 예시 입력 1을 예시로 들어 그림을 그려보자.

 

초기상태
1번사탕 2개 넣음
3번사탕 3개 넣음

여기까지 구현은 쉬울 것이다.

문제는 n번째로 맛있는 사탕을 찾는 과정이다.

 

여기서 우리는 기존에 쓰던 search() 함수를 버리고, 이분 탐색을 기반으로 한 완전히 새로운 search() 함수를 만들어주어야 할 것이다.

예를 들어 지금 이 상태에서 2번째로 맛있는 사탕을 찾는다고 하자.

 

맨 위 노드(5)부터 시작하자.

이 노드의 왼쪽, 오른쪽 노드는 각각 2, 3이라는 정보를 가지고 있다.

그러니까 우리가 찾고있는 2번째로 맛있는 사탕은 왼쪽으로 내려가야 있다.

그러면 이제 왼쪽(2) 노드를 루트 노드 삼아서 재귀를 쳐주자.

 

이런식으로 왼쪽/오른쪽으로 내려가기를 반복하다가

리프 노드에 도달하면(트리의 인덱스 비교로 알 수 있음)

해당 사탕을 찾은 것이다!

 

이제 코드를 짜보자.

맨 처음 트리를 만드는 과정, 그리고 update() 함수는 쉽고 빠르고 간편한 bottom-up 방식을 이용해주겠다(아직 블로그에 정리 안함).

import sys
input = sys.stdin.readline

N = int(input())
size = 1
while size < 1000001:
    size <<= 1
tree = [0] * (size * 2)
size -= 1

def update(idx, diff):
    idx += size
    while idx:
        tree[idx] += diff
        idx //= 2

 

이제 위에서 말한 원리를 기반으로 search() 함수를 재귀로 구현해보자(반복문으로 하는 방법도 있을 것 같은데 그냥 생각이 안나서 이렇게 bottom-up과 top-down을 혼용해서 쓰고 있다;;;)

 

def search(key, idx = 1):
    if size + 1 <= idx < 2 * (size + 1):
        return idx - size
    left = tree[idx*2]
    right = tree[idx*2+1]
    if left >= key:
        return search(key, idx*2)
    else:
        return search(key - left, idx*2+1)

사실 right = tree[idx*2+1] 부분은 필요가 없는 코드이다.

 

쿼리의 경우, 2가 주어지면 update() 함수를 한 번 적용해주면 될 것이고

1이 주어지면 search()로 사탕의 번호를 찾은 다음 출력하고, update()로 -1만큼 더해주면 될 것이다.

 

아래는 전체 코드이다.

 

import sys
input = sys.stdin.readline

N = int(input())
size = 1
while size < 1000001:
    size <<= 1
tree = [0] * (size * 2)
size -= 1

def update(idx, diff):
    idx += size
    while idx:
        tree[idx] += diff
        idx //= 2


def search(key, idx = 1):
    if size + 1 <= idx < 2 * (size + 1):
        return idx - size
    left = tree[idx*2]
    right = tree[idx*2+1]
    if left >= key:
        return search(key, idx*2)
    else:
        return search(key - left, idx*2+1)

for i in range(N):
    tmp = tuple(map(int, input().split()))
    if tmp[0] == 1:
        B = tmp[1]
        t = search(B)
        print(t)
        update(t, -1)
        #print(tree)
    else:
        B, C = tmp[1], tmp[2]
        update(B, C)
        #print(tree)