잡글 가득 블로그
article thumbnail
Published 2022. 10. 10. 18:35
[ABC 272] F. Two Strings ★ PS 문제들

문제 링크

 

Suffix array를 이용하여 "문자열 $S$의 회전 $\le$ 문자열 $T$의 회전"을 만족하는 쌍의 개수를 세는 문제이다. (단, $N=|S|=|T|$)

회전(rotate)을 관리할 때는 원래 문자열을 두 번 연속해서 붙여주는 것이 경험적으로 유용하다. 왜냐하면 붙여 만든 새 문자열의 길이가 $N$인 각 부분 문자열들이 원래 문자열의 회전이기 때문이다.

 

붙여서 관리하겠다는 것까지는 알겠다. 이제는 효율적으로 $S$의 회전들과 $T$의 회전들을 비교해야 한다.

$S$의 회전마다 $T$의 SA에다가 이진 탐색으로 패턴 매칭을 시켜주는 것은 어떨까? 괜찮은 생각이지만 아직은 제한에 못 미친다.

 

$S$와 $T$를 전부 때려넣고 SA를 돌릴 수는 없을까?

이게 핵심이다. $\color {red}S+S+$"$a\cdots a$"$+\color {blue} T+T+$"$z\cdots z$"라는 문자열로 SA를 구축한다면 가능하다.

시작 지점이 빨간 $S$의 어느 지점인 suffix들은 $(S$의 회전$)+(S$의 회전의 어떤 prefix$)+$"$a\cdots a$"$+\cdots$이고,

시작 지점이 파란 $T$의 어느 지점인 suffix들은 $(T$의 회전$)+(T$의 회전의 어떤 prefix$)+$"$z\cdots z$"이다.

이 두 가지 경우가 아닌 suffix들은 그냥 무시하고 위 두 가지 경우들에 대해 SA로 정렬되었을 때 어떻게 되는지 확인해보자.

 

$\mathrm i.$ $S$의 회전 $\neq T$의 회전

그냥 이 안에서 잘 결정된다.

 

$\mathrm {ii}.$ $S$의 회전 $=T$의 회전

이 경우를 위해 "$a\cdots a$"와 "$z\cdots z$" 더미를 만든 것이다.

$(S$의 회전의 어떤 prefix$)+$"$a\cdots a$" $\le(T$의 회전의 어떤 prefix$)+$"$z\cdots z$"이게 되어 있다.

한 쪽 prefix가 지나가면 '$a$' 또는 '$z$'를 반드시 맞닥뜨리기 때문이다.

 

 

그럼 코딩하자.

#include <bits/stdc++.h>
using namespace std; using ii = pair<int,int>; using ll = long long;
#define rep(i,a,b) for (auto i = (a); i <= (b); ++i)
#define dbg(x) cerr << #x << ": " << x << '\n'
#define siz(x) int((x).size())
#define Mup(x,y) x = max(x,y)
#define mup(x,y) x = min(x,y)
using vi = vector<int>;

vi suffix_array(const string &s) {
    int n=siz(s);
    vi t(n), a(n), b(n);
    rep(i,0,n-1) t[i]=i, b[i]=s[i];
    for (int k=1; k<n; k*=2, b=a) {
        auto cmp=[&](int x, int y){ return ii(b[x],x+k<n ? b[x+k]:-1) < ii(b[y],y+k<n ? b[y+k]:-1); };
        sort(begin(t),end(t),cmp);
        a[t[0]]=0;
        rep(i,1,n-1) a[t[i]]=a[t[i-1]]+cmp(t[i-1],t[i]);
    }
    return t;
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n; string s, t;
    cin >> n >> s >> t;
    auto sa = suffix_array(s+s+string(n,'a')+t+t+string(n,'z'));
    ll ans = 0, cnt = 0;
    for (int i : sa) {
        if (0 <= i and i < n) cnt++;
        if (3*n <= i and i < 4*n) ans += cnt;
    }
    cout << ans;
}
profile

잡글 가득 블로그

@도훈.

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!

profile on loading

Loading...