baekjoon 14003:가장 긴 증가하는 부분 수열 5
baekjoon 14003 가장 긴 증가하는 부분 수열 5
14003번 가장 긴 증가하는 부분 수열 5
접근
dp로 풀려고 했더니 시간을 도저히 맞출 수 없었다.
dp로 풀면 O(n^2)에 푸는 것인데 n = 1,000,000이면 3초 무조건 넘는다.
그래서 결국 lis를 O(nlogn)에 푸는 알고리즘을 쓸 수 밖에 없었다.
이 알고리즘은 ans라는 벡터배열을 가지고 시작한다.
ans[i]는 길이 i+1의 lis의 마지막 숫자를 저장하고 있다.
만약 지금 ans[3]까지 채워져 있다면 길이 4의 lis까지 찾았다는 것이다.
이때 새로운 수를 접근했는데 ans[3]보다 크다면, 그 즉시 길이 5의 lis를 찾은 것 이므로
ans[4]에 새로운 수를 넣는다.
또한 ans는 항상 오름차순이다.
이는 길이 i의 lis 마지막 숫자가 있을 때 길이 i+1의 lis를 이루는 숫자는 길이 i의 lis 마지막 숫자보다 크기 때문이다.
(i+1숫자가 뒤에 있으면 i가 마지막일 때 뒤에 추가하는 거니까 새로운 i+1생긴 것이고
i+1숫자가 앞에 있으면 i+1숫자 앞에 i숫자가 분명히 있을 거니까 i+1은 i숫자보다 크고 지금 있는 마지막 i숫자는 그 앞에 있는 i숫자보다 작고) 그래서 ans의 마지막 수보다 크지 않은 경우에는 ans안에서 이분탐색으로 자기가 ans[2]와 ans[3]사이에 있다 하면 ans[3]에 대입하는 것이다.
ans[3]에는 자기보다 작은 수가 오는게 앞으로에 유리하고
새로온 수 입장에서는 ans[0] ans[1] ans[2]가 다 앞에 깔려 있으니 3은 보장된다.
실제 구현에서는 lower_bound를 이용하여 구했다.
그리고 실제로 lis를 찾는 과정에서는 지금까지의 마지막 숫자를 다 기록해둔 ansTrace 벡터를 이용하여
뒤에서부터 자기보다 먼저 나왔고 작은 수를 찾아나간다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <array>
#include <vector>
#include <algorithm>
using namespace std;
int n;
int main(){
// ios::sync_with_stdio(false);
// cin.tie(NULL); cout.tie(NULL);
cin >> n;
vector<int> v(n);
for(int i = 0; i < n; i++){
cin >> v[i];
}
vector<vector<pair<int, int>>> ansTrace(1, vector<pair<int, int>>(1, make_pair(v[0], 0)));
vector<int> ans;
ans.push_back(v[0]);
// for(auto& i : ans){
// for(auto j : i){
// cout << j << " ";
// }
// cout << "\n";
// }
for(int i = 1; i < n; i++){
// cout << i << " ";
if(v[i] > ans.back()){
vector<pair<int, int>> temp(1, {v[i], i});
ansTrace.push_back(temp);
ans.push_back(v[i]);
}
else{
int idx = lower_bound(ans.begin(), ans.end(), v[i]) - ans.begin();
//lower_bound는 v[i]이상의 수가 ans에서 어디서 등장하는지 알려준다.
//lower_bound는 iterator를 return 하기 때문에 index를 알고싶으면 ans.begin()을 빼준다.
ans[idx] = v[i];
ansTrace[idx].push_back({v[i], i});
}
}
cout << ans.size() << "\n";
// for(auto& i : ansTrace){
// for(auto j : i){
// cout << j.first << ", " << j.second << " ";
// }
// cout << "\n";
// }
vector<int> lisAns;
int prevX = 0;
for(int i = ans.size()-1; i >= 0; i--){
lisAns.push_back(ansTrace[i][prevX].first);
// cout << ansTrace[i][prevX].first << " ";
if(i==0) break;
for(int j = 0; j < ansTrace[i-1].size(); j++){
if(ansTrace[i-1][j].first < ansTrace[i][prevX].first && ansTrace[i-1][j].second < ansTrace[i][prevX].second){
prevX = j;
break;
}
}
}
sort(lisAns.begin(), lisAns.end());
for(auto i : lisAns){
cout << i << " ";
}
return 0;
}
배운 점
lis를 dp가 아닌 방법으로 O(nlogn)에 푸는 알고리즘을 배웠다.