[백준] 1377 버블 소트
문제 설명
문제를 살펴보면 제목이 제시하듯 버블 소트 코드가 주어지고, 출력되는 것은 for문이 실행된 횟수이다.
이에 따라 버블정렬 코드를 아래와 같이 작성해보았으나, 시간 초과가 떴다…
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import sys
input = sys.stdin.readline
def bubbleSort(arr, start):
for i in range(start, len(arr) - 1):
if(arr[i] >= arr[i+1]):
arr[i], arr[i+1] = arr[i+1], arr[i]
return arr
N = int(input())
A = []
start = 0
count = 1
for i in range(N):
A.append(int(input()))
A.append(1000001)
while A != sorted(A):
A = bubbleSort(A, start)
count += 1
start += 1
print(count)
시간 초과가 났다. 생각해보니 start를 하나씩 더하는 방식이 잘못 되었다. 버블 정렬은 맨 뒤에서부터 하나씩 정렬하는 정렬 방식이기 때문에 끝부분을 설정을 해야 하는 것이었다…
다시 코드를 작성해보았다:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import sys
input = sys.stdin.readline
def bubbleSort(arr, end):
swapped = False
for i in range(end):
if(arr[i] >= arr[i+1]):
arr[i], arr[i+1] = arr[i+1], arr[i]
swapped = True
return swapped
N = int(input())
A = [int(input()) for _ in range(N)]
count = 1
end = N - 1
while end > 0 :
if not bubbleSort(A, end):
break
end -= 1
count += 1
print(count)
근데 이번에도 시간 초과가 나는 것이었다.
첫 번째 시도에서는 내 로직 자체가 잘못되었음을 알고 수정했지만,
두 번째 코드에서의 시간 초과는 내가 접근을 잘못하고 있음을 느꼈다.
문제는 ‘버블 소트’지만 실제로 버블 정렬을 시행하면 O(N^2)의 시간 복잡도로 인해 무조건 시간 초과가 날 수 밖에 없구나…!
그렇다면 코드를 직접 실행하지 않고 몇 번 정렬 과정을 거치는지 확인할 방법이 필요하다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import sys
input = sys.stdin.readline
N = int(input())
A = [int(input()) for _ in range(N)]
count = 1
swapped = True
while swapped:
if max(A) != A[-1]:
A.remove(max(A))
count += 1
else:
swapped = False
print(count)
해당 문제의 경우에 버블 정렬은 가장 큰 값이 배열의 맨 뒤로 가는 오름차순 정렬이기 때문에 가장 큰 값이 맨 뒤에 위치하지 않는 경우 제거하는 방식으로 작성해보았다.
예시로, 1, 3, 2, 5, 9, 4
라는 배열이 주어지면, 9를 먼저 제거하고 다음 번엔 5를 제거하여 count 값이 3이 된다.
이렇게 하면 이중으로 있던 반복문을 하나로 줄일 수 있기 때문에 시간효율이 올라갈 것이라 생각하였지만…
또 다시 시간 초과가 났다. 아무래도 max를 매번 새로 찾는 것과 값을 삭제하는 것이 시간을 잡아먹는 모양이다.
그렇다면 버블 정렬의 특징 중 변경된 횟수를 구할 방법을 다시 찾아봐야 한다.
버블 정렬의 특징 중 하나는 정렬 과정 중 특정 숫자는 왼쪽으로는 한 칸씩만 이동할 수 있다는 점이다. 그리고, 실제 정렬이 이루어졌다면 적어도 한 수는 왼쪽으로 이동했을 것이다. 즉, 가장 많이 이동한 수를 찾아, 인덱스를 비교하면 되는 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
import sys
input = sys.stdin.readline
N = int(input())
A = [(int(input()), i) for i in range(N)]
prime = sorted(A)
result = []
for j in range(N):
result.append(prime[j][1]- A[j][1])
print(max(result) + 1)
처음 받는 수들을 인덱스와 같이 배열에 넣어주고, 해당 배열을 정렬한 배열을 만들어 인덱스 값들을 뺴고 그 중 max값을 구하면 된다. 코드 자체는 쉽지만, 접근 방법이 어려운 문제여서 골드2인듯 하다.