HDU 2296 Ring

AC自动机+带权字符串DP

题目本身不是很难,但在处理上有点繁琐。

题意:
给你很多模式串,每个模式串都有一个权值,再给你一个限制长度,要你在限制长度内找出最短的字符串,使得权值和最大。若有多个结果,输出字典序最小的。

思路:
连AC自动机上的最短路都有了,加个权值也不是什么奇怪的事情。
这个加了权值有点像…………背包??不是么??
想象不到的话将背包的物品替换成一个字符试试。

但是实际上都是我瞎猜罢了,我并没有按照背包的写法去写,还是很常规的按照节点和状态添加字符,到下一节点的下一状态。

稍微繁琐一点的就是要输出字符串,但所幸数据比较小,三维空间存的下。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

const int inf = 0x3f3f3f3f;
const int max_n = 100 * 20;
const int max_c = 26;
const int max_l = 55;

int val[max_l * 2];

struct Aho {
    int next[max_n][max_c], fail[max_n];
    int end[max_n];
    int root, size;
    queue<int> que;

    int newNode()
    {
        for (int i = 0; i < max_c; i++)
            next[size][i] = 0;
        end[size++] = 0;
        return size - 1;
    }

    inline void init()
    {
        size = 1;
        root = newNode();
    }

    void insert(char str[], int id)
    {
        int len = strlen(str), now = root;
        for (int i = 0; i < len; i++) {
            int c = str[i] - 'a';
            if (!next[now][c])
                next[now][c] = newNode();
            now = next[now][c];
        }
        end[now] = id;
    }

    void build()
    {
        for (int i = root; i < size; i++)
            if (end[i])
                end[i] = val[end[i]];
        fail[root] = root;
        for (int i = 0; i < max_c; i++)
            if (!next[root][i])
                next[root][i] = root;
            else {
                fail[next[root][i]] = root;
                que.push(next[root][i]);
            }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            if (end[fail[now]])
                end[now] += end[fail[now]];
            for (int i = 0; i < max_c; i++)
                if (!next[now][i])
                    next[now][i] = next[fail[now]][i];
                else {
                    fail[next[now][i]] = next[fail[now]][i];
                    que.push(next[now][i]);
                }
        }
    }

    bool cmp(const char* sa, const char* sb)
    {
        int la = strlen(sa), lb = strlen(sb);
        return la != lb ? la < lb : strcmp(sa, sb) < 0;
    }

    int dp[max_l][max_n];
    char str[max_l][max_n][max_l];
    void solve(int n)
    {
        for (int i = 0; i <= n; i++)
            for (int j = root; j < size; j++)
                dp[i][j] = -inf;
        dp[0][root] = 0;
        char curs[max_l];
        strcpy(str[0][root], "");
        for (int i = 0; i < n; i++) {
            for (int sta = root; sta < size; sta++) {
                if (dp[i][sta] >= 0) {
                    strcpy(curs, str[i][sta]);
                    int len = strlen(curs);
                    for (int j = 0; j < max_c; j++) {
                        int nxt = next[sta][j];
                        int pv = dp[i][sta];
                        curs[len] = 'a' + j;
                        curs[len + 1] = 0;
                        if (end[nxt])
                            pv += end[nxt];
                        if (dp[i + 1][nxt] < pv || (dp[i + 1][nxt] == pv && cmp(curs, str[i + 1][nxt]))) {
                            dp[i + 1][nxt] = pv;
                            strcpy(str[i + 1][nxt], curs);
                        }
                    }
                }
            }
        }
        char as[max_n] = "";
        int ans = -inf;
        for (int i = 0; i <= n; i++)
            for (int j = root; j < size; j++)
                if (ans < dp[i][j] || (ans == dp[i][j] && cmp(str[i][j], as))) {
                    strcpy(as, str[i][j]);
                    ans = dp[i][j];
                }
        printf("%s\n", as);
    }

} aho;

char buf[max_l];

int main()
{
    int T, n, m;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        aho.init();
        for (int i = 1; i <= m; i++) {
            scanf("%s", buf);
            aho.insert(buf, i);
        }
        for (int i = 1; i <= m; i++)
            scanf("%d", val + i);
        aho.build();
        aho.solve(n);
    }
    return 0;
}