AC自动机

ZOJ 3228 Searching the String

Posted on

AC自动机吼题

虽然是很常规的模式串与匹配串的匹配计数问题,但它稍微给了一个新花样和一些坑。
写完这题,我当时就很感慨,如果我能在多校之前写过这题就好了……

题意:
先给你一个匹配串,再给你很多模式串,模式串前都有标号。标号为 0 表示计数可重叠,标号为 1 表示计数不可重叠。

思路:
写不来呀…………
看了题解,kuangbin只说了一句,加一个数组维护一下就好了…………这特么怎么维护…………

看了代码大致能理解,开了一个 last 数组记录每个位置上一次被匹配时的位置。
如果匹配串的位置与AC自动机上的相应位置的上一匹配位置差距为这个模式串的长度以上,说明可以匹配。
特别想要详细说,但又觉得没有什么好说的,所谓大道至简,原理真的很简单。

AC Code

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

const int max_n = 6e5 + 10;
const int max_c = 26;
int mp[max_n / 6], sel[max_n / 6];

struct Aho {
    int next[max_n][max_c], fail[max_n], end[max_n];
    int dep[max_n];
    int root, size;
    queue<int> que;

    int newNode()
    {
        for (int i = 0; i < max_c; i++)
            next[size][i] = 0;
        end[size++] = 0;
        return size - 1;
    }

    inline void init()
    {
        size = 1;
        root = newNode();
        dep[root] = 0;
    }

    int insert(char str[])
    {
        int len = strlen(str), now = root;
        for (int i = 0; i < len; i++) {
            int c = str[i] - 'a';
            if (!next[now][c]) {
                next[now][c] = newNode();
                dep[next[now][c]] = dep[now] + 1;
            }
            now = next[now][c];
        }
        end[now] = true;
        return now;
    }

    void build()
    {
        fail[root] = root;
        for (int i = 0; i < max_c; i++)
            if (!next[root][i])
                next[root][i] = root;
            else {
                fail[next[root][i]] = root;
                que.push(next[root][i]);
            }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            for (int i = 0; i < max_c; i++)
                if (!next[now][i])
                    next[now][i] = next[fail[now]][i];
                else {
                    fail[next[now][i]] = next[fail[now]][i];
                    que.push(next[now][i]);
                }
        }
    }

    int ans[max_n][2];
    int last[max_n];
    void query(char str[], int n)
    {
        memset(ans, 0, sizeof ans);
        memset(last, -1, sizeof last);
        int len = strlen(str), now = root;
        for (int i = 0; i < len; i++) {
            now = next[now][str[i] - 'a'];
            int cur = now;
            while (cur != root) {
                if (end[cur]) {
                    ans[cur][0]++;
                    if (i - last[cur] >= dep[cur]) {
                        last[cur] = i;
                        ans[cur][1]++;
                    }
                }
                cur = fail[cur];
            }
        }
        for (int i = 0; i < n; i++)
            printf("%d\n", ans[mp[i]][sel[i]]);
        puts("");
    }
} aho;

char str[max_n], buf[10];

int main()
{
    int n, cas = 0;
    while (scanf("%s", str) != EOF) {
        scanf("%d", &n);
        aho.init();
        for (int i = 0; i < n; i++) {
            scanf("%d %s", sel + i, buf);
            mp[i] = aho.insert(buf);
        }
        aho.build();
        printf("Case %d\n", ++cas);
        aho.query(str, n);
    }
    return 0;
}
AC自动机

HDU 2296 Ring

Posted on

AC自动机+带权字符串DP

题目本身不是很难,但在处理上有点繁琐。

题意:
给你很多模式串,每个模式串都有一个权值,再给你一个限制长度,要你在限制长度内找出最短的字符串,使得权值和最大。若有多个结果,输出字典序最小的。

思路:
连AC自动机上的最短路都有了,加个权值也不是什么奇怪的事情。
这个加了权值有点像…………背包??不是么??
想象不到的话将背包的物品替换成一个字符试试。

但是实际上都是我瞎猜罢了,我并没有按照背包的写法去写,还是很常规的按照节点和状态添加字符,到下一节点的下一状态。

稍微繁琐一点的就是要输出字符串,但所幸数据比较小,三维空间存的下。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

const int inf = 0x3f3f3f3f;
const int max_n = 100 * 20;
const int max_c = 26;
const int max_l = 55;

int val[max_l * 2];

struct Aho {
    int next[max_n][max_c], fail[max_n];
    int end[max_n];
    int root, size;
    queue<int> que;

    int newNode()
    {
        for (int i = 0; i < max_c; i++)
            next[size][i] = 0;
        end[size++] = 0;
        return size - 1;
    }

    inline void init()
    {
        size = 1;
        root = newNode();
    }

    void insert(char str[], int id)
    {
        int len = strlen(str), now = root;
        for (int i = 0; i < len; i++) {
            int c = str[i] - 'a';
            if (!next[now][c])
                next[now][c] = newNode();
            now = next[now][c];
        }
        end[now] = id;
    }

    void build()
    {
        for (int i = root; i < size; i++)
            if (end[i])
                end[i] = val[end[i]];
        fail[root] = root;
        for (int i = 0; i < max_c; i++)
            if (!next[root][i])
                next[root][i] = root;
            else {
                fail[next[root][i]] = root;
                que.push(next[root][i]);
            }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            if (end[fail[now]])
                end[now] += end[fail[now]];
            for (int i = 0; i < max_c; i++)
                if (!next[now][i])
                    next[now][i] = next[fail[now]][i];
                else {
                    fail[next[now][i]] = next[fail[now]][i];
                    que.push(next[now][i]);
                }
        }
    }

    bool cmp(const char* sa, const char* sb)
    {
        int la = strlen(sa), lb = strlen(sb);
        return la != lb ? la < lb : strcmp(sa, sb) < 0;
    }

    int dp[max_l][max_n];
    char str[max_l][max_n][max_l];
    void solve(int n)
    {
        for (int i = 0; i <= n; i++)
            for (int j = root; j < size; j++)
                dp[i][j] = -inf;
        dp[0][root] = 0;
        char curs[max_l];
        strcpy(str[0][root], "");
        for (int i = 0; i < n; i++) {
            for (int sta = root; sta < size; sta++) {
                if (dp[i][sta] >= 0) {
                    strcpy(curs, str[i][sta]);
                    int len = strlen(curs);
                    for (int j = 0; j < max_c; j++) {
                        int nxt = next[sta][j];
                        int pv = dp[i][sta];
                        curs[len] = 'a' + j;
                        curs[len + 1] = 0;
                        if (end[nxt])
                            pv += end[nxt];
                        if (dp[i + 1][nxt] < pv || (dp[i + 1][nxt] == pv && cmp(curs, str[i + 1][nxt]))) {
                            dp[i + 1][nxt] = pv;
                            strcpy(str[i + 1][nxt], curs);
                        }
                    }
                }
            }
        }
        char as[max_n] = "";
        int ans = -inf;
        for (int i = 0; i <= n; i++)
            for (int j = root; j < size; j++)
                if (ans < dp[i][j] || (ans == dp[i][j] && cmp(str[i][j], as))) {
                    strcpy(as, str[i][j]);
                    ans = dp[i][j];
                }
        printf("%s\n", as);
    }

} aho;

char buf[max_l];

int main()
{
    int T, n, m;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        aho.init();
        for (int i = 1; i <= m; i++) {
            scanf("%s", buf);
            aho.insert(buf, i);
        }
        for (int i = 1; i <= m; i++)
            scanf("%d", val + i);
        aho.build();
        aho.solve(n);
    }
    return 0;
}
AC自动机

POJ 1699 Best Sequence

Posted on

AC自动机+spfa最短路
啊啊,明明AC自动机还有最后一题,今晚怕是写不完了,今天写的还有两道没写题解。

最后一道AC自动机与这一题有相似之处,是这一题的强力加强版本。
先把这题给解决了

题意:
给你很多碱基串,不同碱基串在前后缀相同的条件下可以重叠,要你求出最短的碱基串,包含所有给定串。

思路:
AC自动机上状压SPFA,也可考虑TSP的模型。
还是属于比较基础的,连我都可以不看题解写出来的题目。

#include <cstring>
#include <iostream>
#include <queue>
#include <set>

using namespace std;

const int max_n = 250;
const int max_c = 4;
int key[130];

struct Aho {
    int next[max_n][max_c], fail[max_n], end[max_n];
    int root, size;
    queue<int> que;

    int newNode()
    {
        for (int i = 0; i < max_c; i++)
            next[size][i] = 0;
        end[size++] = 0;
        return size - 1;
    }

    inline void init()
    {
        size = 1;
        root = newNode();
    }

    void insert(int id, const string& str)
    {
        int len = str.length(), now = root;
        for (int i = 0; i < len; i++) {
            int c = key[str[i]];
            if (!next[now][c])
                next[now][c] = newNode();
            now = next[now][c];
        }
        end[now] = 1 << id;
    }

    void build()
    {
        fail[root] = root;
        for (int i = 0; i < max_c; i++)
            if (!next[root][i])
                next[root][i] = root;
            else {
                fail[next[root][i]] = root;
                que.push(next[root][i]);
            }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            if (end[fail[now]])
                end[now] |= end[fail[now]];
            for (int i = 0; i < max_c; i++)
                if (!next[now][i])
                    next[now][i] = next[fail[now]][i];
                else {
                    fail[next[now][i]] = next[fail[now]][i];
                    que.push(next[now][i]);
                }
        }
    }

    int dp[max_n][1024];
    bool inq[max_n][1024];
    void solve(int n)
    {
        memset(dp, 0x3f, sizeof dp);
        memset(inq, false, sizeof inq);
        n = 1 << n;
        que.push(root * n);
        dp[root][0] = 0, inq[root][0] = true;
        while (!que.empty()) {
            int cur = que.front();
            que.pop();
            int a = cur / n, b = cur % n;
            inq[a][b] = false;
            for (int i = 0; i < max_c; i++) {
                int na = next[a][i], nb = b | end[next[a][i]];
                if (dp[na][nb] > dp[a][b] + 1) {
                    dp[na][nb] = dp[a][b] + 1;
                    if (!inq[na][nb]) {
                        inq[na][nb] = true;
                        que.push(na * n + nb);
                    }
                }
            }
        }
        int ans = 0x3f3f3f3f;
        for (int i = 1; i <= size; i++)
            ans = min(ans, dp[i][n - 1]);
        cout << ans << endl;
    }

} aho;

string tmp;
set<string> ks;

int main()
{
    key['A'] = 0, key['C'] = 1, key['T'] = 2, key['G'] = 3;
    ios::sync_with_stdio(false);
    int T, n;
    cin >> T;
    while (T--) {
        ks.clear();
        cin >> n;
        while (n--) {
            cin >> tmp;
            ks.insert(tmp);
        }
        n++;
        aho.init();
        for (set<string>::iterator it = ks.begin(); it != ks.end(); it++)
            aho.insert(n++, *it);
        aho.build();
        aho.solve(n);
    }
    return 0;
}
AC自动机

HDU 3341 Lost’s revenge

Posted on

AC自动机 + 变进制状压DP

看题目的 discuss 都说会卡数据,但我还是 1500ms A了。虽然变进制的思路并不是我的……

最近有点懒得写博客了呀

题意:
还是基因碱基对背景。
给你几个模式串,再给你一个匹配串,允许你对匹配串重新排列,要求重新排列后跟模式串匹配的最大数量。

思路:
如果再加个权值,那就是DP套DP了,(误……
本质上还是一个计数DP,考虑状压,但是给定的匹配串长度上限为 40 ,如果开一个状态+节点的 5维数组,会炸内存。
考虑到将节点的 4 维压缩,因为 40 的实际状态数最多是 4 个碱基数量相同的时候。
也就是\( 10 \times 10 \times 10 \times 10 \),内存与复杂度均可以承受。

AC Code

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
#define fill(ary, num) memset((ary), (num), sizeof(ary))

using namespace std;
const int max_n = 505;
const int max_c = 4;

int num[4];
int change[130];

struct Aho {
    int next[max_n][max_c], fail[max_n], end[max_n];
    int root, size;
    queue<int> que;

    int newNode()
    {
        for (int i = 0; i < max_c; i++)
            next[size][i] = 0;
        end[size++] = 0;
        return size - 1;
    }

    inline void init()
    {
        size = 1;
        root = newNode();
    }

    void insert(char str[])
    {
        int len = strlen(str), now = root;
        for (int i = 0; i < len; i++) {
            int c = change[(int)str[i]];
            if (!next[now][c])
                next[now][c] = newNode();
            now = next[now][c];
        }
        end[now]++;
    }

    void build()
    {
        fail[root] = root;
        for (int i = 0; i < max_c; i++)
            if (!next[root][i])
                next[root][i] = root;
            else {
                fail[next[root][i]] = root;
                que.push(next[root][i]);
            }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            end[now] += end[fail[now]];
            for (int i = 0; i < max_c; i++)
                if (!next[now][i])
                    next[now][i] = next[fail[now]][i];
                else {
                    fail[next[now][i]] = next[fail[now]][i];
                    que.push(next[now][i]);
                }
        }
    }

    int dp[max_n][11 * 11 * 11 * 11 + 5];
    int query()
    {
        int res = 0, bit[4];
        memset(dp, -1, sizeof dp), dp[root][0] = 0;
        bit[3] = 1;
        rrange(i, 1, 3) bit[i - 1] = (num[i] + 1) * bit[i];
        range(A, 0, num[0]) range(T, 0, num[1]) range(C, 0, num[2]) range(G, 0, num[3])
        {
            int sta = A * bit[0] + T * bit[1] + C * bit[2] + G;
            range(cur, root, size - 1) if (dp[cur][sta] >= 0) for (int i = 0; i < 4; i++)
            {
                if (i == 0 && A == num[0]) continue;
                if (i == 1 && T == num[1]) continue;
                if (i == 2 && C == num[2]) continue;
                if (i == 3 && G == num[3]) continue;
                dp[next[cur][i]][sta + bit[i]] = max(dp[next[cur][i]][sta + bit[i]], dp[cur][sta] + end[next[cur][i]]);
            }
        }
        int fnl = 0;
        each(i,4) fnl += num[i] * bit[i];
        range(i,root,size - 1) res = max(res, dp[i][fnl]);
        return res;
    }
} aho;

char buf[50];

int main()
{
    change[(int)'A'] = 0, change[(int)'T'] = 1, change[(int)'C'] = 2, change[(int)'G'] = 3;
    int n, cas = 0;
    while (scanf("%d", &n) != EOF && n) {
        memset(num, 0, sizeof num);
        aho.init();
        for (int i = 0; i < n; i++) {
            scanf("%s", buf);
            aho.insert(buf);
        }
        aho.build();
        scanf("%s", buf);
        int len = strlen(buf);
        for (int i = 0; i < len; i++)
            num[change[(int)buf[i]]]++;
        printf("Case %d: %d\n", ++cas, aho.query());
    }
    return 0;
}