Post

baekjoon 14003:가장 긴 증가하는 부분 수열 5

baekjoon 14003 가장 긴 증가하는 부분 수열 5

14003번 가장 긴 증가하는 부분 수열 5

접근

dp로 풀려고 했더니 시간을 도저히 맞출 수 없었다.
dp로 풀면 O(n^2)에 푸는 것인데 n = 1,000,000이면 3초 무조건 넘는다.
그래서 결국 lis를 O(nlogn)에 푸는 알고리즘을 쓸 수 밖에 없었다.
이 알고리즘은 ans라는 벡터배열을 가지고 시작한다.
ans[i]는 길이 i+1의 lis의 마지막 숫자를 저장하고 있다.
만약 지금 ans[3]까지 채워져 있다면 길이 4의 lis까지 찾았다는 것이다.
이때 새로운 수를 접근했는데 ans[3]보다 크다면, 그 즉시 길이 5의 lis를 찾은 것 이므로
ans[4]에 새로운 수를 넣는다.
또한 ans는 항상 오름차순이다.
이는 길이 i의 lis 마지막 숫자가 있을 때 길이 i+1의 lis를 이루는 숫자는 길이 i의 lis 마지막 숫자보다 크기 때문이다.
(i+1숫자가 뒤에 있으면 i가 마지막일 때 뒤에 추가하는 거니까 새로운 i+1생긴 것이고
i+1숫자가 앞에 있으면 i+1숫자 앞에 i숫자가 분명히 있을 거니까 i+1은 i숫자보다 크고 지금 있는 마지막 i숫자는 그 앞에 있는 i숫자보다 작고) 그래서 ans의 마지막 수보다 크지 않은 경우에는 ans안에서 이분탐색으로 자기가 ans[2]와 ans[3]사이에 있다 하면 ans[3]에 대입하는 것이다.
ans[3]에는 자기보다 작은 수가 오는게 앞으로에 유리하고
새로온 수 입장에서는 ans[0] ans[1] ans[2]가 다 앞에 깔려 있으니 3은 보장된다.
실제 구현에서는 lower_bound를 이용하여 구했다.
그리고 실제로 lis를 찾는 과정에서는 지금까지의 마지막 숫자를 다 기록해둔 ansTrace 벡터를 이용하여
뒤에서부터 자기보다 먼저 나왔고 작은 수를 찾아나간다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <array>
#include <vector>
#include <algorithm>
using namespace std;
int n; 
int main(){
  // ios::sync_with_stdio(false);
  // cin.tie(NULL); cout.tie(NULL);

  cin >> n;
  vector<int> v(n);
  for(int i = 0; i < n; i++){
    cin >> v[i];
  }
  vector<vector<pair<int, int>>> ansTrace(1, vector<pair<int, int>>(1, make_pair(v[0], 0)));
  vector<int> ans;
  ans.push_back(v[0]);
  // for(auto& i : ans){
  //   for(auto j : i){
  //     cout << j << " ";
  //   }
  //   cout << "\n";
  // }
  for(int i = 1; i < n; i++){
    // cout << i << " ";
    if(v[i] > ans.back()){
      vector<pair<int, int>> temp(1, {v[i], i});
      ansTrace.push_back(temp);
      ans.push_back(v[i]);
    }
    else{
      int idx = lower_bound(ans.begin(), ans.end(), v[i]) - ans.begin();
      //lower_bound는 v[i]이상의 수가 ans에서 어디서 등장하는지 알려준다.  
      //lower_bound는 iterator를 return 하기 때문에 index를 알고싶으면 ans.begin()을 빼준다.
      ans[idx] = v[i];
      ansTrace[idx].push_back({v[i], i});
    }
  }
  cout << ans.size() << "\n";
  // for(auto& i : ansTrace){
  //   for(auto j : i){
  //     cout << j.first << ", " << j.second << " ";
  //   }
  //   cout << "\n";
  // }
  vector<int> lisAns;
  int prevX = 0;
  for(int i = ans.size()-1; i >= 0; i--){
    lisAns.push_back(ansTrace[i][prevX].first);
    // cout << ansTrace[i][prevX].first << " ";
    if(i==0) break;
    for(int j = 0; j < ansTrace[i-1].size(); j++){
      if(ansTrace[i-1][j].first < ansTrace[i][prevX].first && ansTrace[i-1][j].second < ansTrace[i][prevX].second){
        prevX = j;
        break;
      }
    }
  }
  sort(lisAns.begin(), lisAns.end());
  for(auto i : lisAns){
    cout << i << " ";
  }
  
  
  
  return 0;
}

배운 점

lis를 dp가 아닌 방법으로 O(nlogn)에 푸는 알고리즘을 배웠다.

This post is licensed under CC BY 4.0 by the author.