后缀数组 + 单调栈好题
网络赛时的后缀数组可以用这个思路解,但是我并不是很会,今天补上。
题意:
给你两个串,问你长度不小于 k 的重复子串的数量。
思路:
若只是对于一个字符串,问它长度不小于 k 的重复子串数量,那么考虑对height根据 k 分组,然后记录一下就可以了,将 height 去重后,每个 height 的贡献为 ( max ( 0, height – k +1 ) ) 个人猜测,后果不计
但这题是两个字符串,我们可以将两个字符串用一个特殊字符连接,并将后缀通过第一个字符串长度区分开来,对于同一个 height 组中,当前的 height 只能跟另一个 height 组相比求 lcp 。
一个可行的方法是,将组内不同的height都记录下来,再对于每一个当前height,我把它跟另一个height组一一求贡献。但是这个复杂度为 ( O ( n^2 ) ),显然是不行的。
最后学到的优化就是单调栈优化了,用 st[i][0]表示栈中第 i 号元素记录时候的height值,st[i][1]表示在这个height值上覆盖了st[i][1]个子串,在栈内维护height递增,实现( O( 1 ) ) 求贡献。
#include <algorithm>
#include <cstdio>
#include <cstring>
#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
#define fill(ary, num) memset((ary), (num), sizeof(ary))
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const double eps = 1e-5;
const int maxn = 2e5 + 5;
struct SuffixArray {
private:
int len;
int t1[maxn], t2[maxn], buc[maxn];
int s[maxn];
void DA(int n, int m)
{
int p, *x = t1, *y = t2;
each(i, m) buc[i] = 0;
each(i, n) buc[x[i] = s[i]]++;
range(i, 1, m - 1) buc[i] += buc[i - 1];
reach(i, n) sa[--buc[x[i]]] = i;
for (int k = 1; k <= n; k <<= 1) {
p = 0;
range(i, n - k, n - 1) y[p++] = i;
each(i, n) if (sa[i] >= k) y[p++] = sa[i] - k;
each(i, m) buc[i] = 0;
each(i, n) buc[x[i]]++;
range(i, 1, m - 1) buc[i] += buc[i - 1];
reach(i, n) sa[--buc[x[y[i]]]] = y[i];
swap(x, y);
p = 1, x[sa[0]] = 0;
range(i, 1, n - 1) x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
if (p >= n)
break;
m = p;
}
}
void getHeight(int n)
{
int j, k = 0;
each(i, n + 1) rank[sa[i]] = i;
each(i, n)
{
k ? k-- : 0;
j = sa[rank[i] - 1];
while (s[i + k] == s[j + k])
k++;
height[rank[i]] = k;
}
}
public:
int sa[maxn];
int rank[maxn], height[maxn];
void input(char* str)
{
len = strlen(str);
range(i, 0, len) s[i] = str[i];
DA(len + 1, 130);
getHeight(len);
}
int st[maxn][2];
void solve(int key, int limit)
{
ll tot = 0, top = 0, sum = 0;
range(i, 1, len) if (height[i] < limit) top = tot = 0;
else
{
int cnt = 0;
if (sa[i - 1] < key)
cnt++, tot += height[i] - limit + 1;
while (top > 0 && height[i] <= st[top - 1][0]) {
top--;
tot -= st[top][1] * (st[top][0] - height[i]);
cnt += st[top][1];
}
st[top][0] = height[i];
st[top++][1] = cnt;
if (sa[i] > key)
sum += tot;
}
range(i, 1, len) if (height[i] < limit) top = tot = 0;
else
{
int cnt = 0;
if (sa[i - 1] > key)
cnt++, tot += height[i] - limit + 1;
while (top > 0 && height[i] <= st[top - 1][0]) {
top--;
tot -= st[top][1] * (st[top][0] - height[i]);
cnt += st[top][1];
}
st[top][0] = height[i];
st[top++][1] = cnt;
if (sa[i] < key)
sum += tot;
}
printf("%lld\n", sum);
}
} suffix_array;
char buf[maxn];
int main()
{
int k;
while (scanf("%d", &k) != EOF && k > 0) {
scanf("%s", buf);
int len = strlen(buf);
buf[len] = '$';
scanf("%s", buf + len + 1);
suffix_array.input(buf);
suffix_array.solve(len, k);
}
return 0;
}