블로그 이름 뭐로 하지

[알고리즘] 백준 1238 - 파티 본문

알고리즘

[알고리즘] 백준 1238 - 파티

발등이 따뜻한 사람 2024. 1. 2. 18:14

문제

https://www.acmicpc.net/problem/1238

 

1238번: 파티

첫째 줄에 N(1 ≤ N ≤ 1,000), M(1 ≤ M ≤ 10,000), X가 공백으로 구분되어 입력된다. 두 번째 줄부터 M+1번째 줄까지 i번째 도로의 시작점, 끝점, 그리고 이 도로를 지나는데 필요한 소요시간 Ti가 들어

www.acmicpc.net

 

풀이

2달전 시도했다가 시간초과로 실패했던 문제였다. 당시 다익스트라 뽀개기를 했던걸로 기억하는데... 우선 오늘 다시 이 문제를 봤을 때는 플로이드 워셜이 떠올랐다. 그래서 플로이드로 풀었더니 시간 초과가 떴다. ㅎㅎ

n,m,x = map(int,input().split())

graph = [[1e9] * (n+1) for _ in range(n+1)]

for a in range(1, n + 1):
    for b in range(1, n + 1):
        if a == b:
            graph[a][b] = 0

for _ in range(m):
    s,e,t = map(int,input().split())
    graph[s][e] = t

for k in range(1, n + 1):
    for a in range(1, n + 1):
        for b in range(1, n + 1):
            graph[a][b] = min(graph[a][b], graph[a][k] + graph[k][b])


answer = -1
for i in range(1,n+1):
    time = graph[i][x] + graph[x][i]
    answer = max(answer,time)

print(answer)

인풋이 1000이고, 플로이드 워셜은 O^3의 시간복잡도를 가지니 시간초과가 날 거라곤 예상을 했지만... 슬펐다. 사실 플로이드 워셜이면 좀 날로 먹을 수 있으니까

암튼 그래서 다익스트라로 변경해서 풀었다. 다익스트라로 변경할 때 핵심은 heapq라고 생각하는데, 이 큐에 넣을때 무조건 거리가 앞에 오게끔 넣어야 한다는 점이다. 그래야 거리가 짧은 아이가 앞으로 정렬이 되고, 그렇게 되면 더욱 빠르고 불필요한 과정 없이 최단 경로를 찾을 수 있게 된다.

 

정답 코드

import sys
import heapq

def djikstra(s,e):
    distance = [int(1e9)]*(N+1)
    q = []
    # 이 부분에서 거리가 앞에 오게끔 넣어줘야 한다.
    # [s,0] 이런 식으로 노드번호가 앞에 오게 넣어주면 시간 초과가 뜬다.
    heapq.heappush(q,[0,s])
    distance[s] = 0
    while q:
        t,v = heapq.heappop(q)
        if distance[v] < t:
            continue
        for v1,t1 in graph[v]:
            if distance[v1] > distance[v] + t1:
                distance[v1] = distance[v] + t1
                heapq.heappush(q,[distance[v1],v1])
    return distance[e]


input = sys.stdin.readline
n,m,x = map(int,input().split())
graph = [[] for _ in range(n+1)]
for _ in range(m):
    s,e,t = map(int,input().split())
    graph[s].append([e,t])
result = 0
for i in range(1,n+1):
    if i == x:
        continue
    if result < djikstra(i,x) + djikstra(x,i):
        result = djikstra(i,x) + djikstra(x,i)
print(result)