BZOJ 4197 寿司晚宴

上一场多校一道状压DP的原题。
那我肯定是去做原题啦。

多校的时候我还觉得这肯定是许学姐的锅,没想到居然能转化成状压。

题意:
中文题面,给你( [ 2, n ] ) 这几个数,要你各找两组数字,两组数字之间两两互质。
( n) 最大为500,问组合数量。

思路:
考虑数据 ( n) 非常小,在 ( \sqrt 500 )以内的质数只有8个,而大于( \sqrt 500 )的大元素只可能为一个,那么我们把相同的大元素放在一起处理,开一个二维数组表示这个大元素在两组中的哪一组,而对于小元素,可以用状态压缩来表示两组数字各自包含那几个质数。最后累加的时候让两个状态不重叠即可。

代码中 ( dp [ k ] [ a ] [ b ] [ c ] ) 表示 前( k)个元素中 第一组数包含的小质数状态为( a) ,第二组数包含的小质数状态为( b) ,大质数为在第一组( 0) ,或者在第二组( 1) 的组合数量。
dp过程中,可以通过类似背包处理,反向计算,省掉一维。

( f [ k ] [ a ] [ b ] ) 表示前(k ) 个元素中,第一组包含的小质数状态为( a),第二组包含的小质数状态为 ( b) 的组合数量。

因为( dp)数组包含了当前大质数的分配状态。所以在处理不同大质数的时候理所当然地要将 ( dp)数组重置。方法就是用( f [ a ] [ b ] )进行汇总,再分配。
在汇总的时候格外要注意的是,必须减去一个以前的( f [ a ] [ b ] ),因为 ( dp)数组在大质数分配的两个状态中都记录了以前的( f [ a ] [ b ] ),也就是在此之上没有加上任何元素的组合数。汇总的时候是多加一次的,这应该比较容易理解。

而状态转移方程是最好理解,我就不多说了 虽然不是我想出来的

AC Code

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
#define fill(ary, num) memset((ary), (num), sizeof(ary))

using namespace std;
typedef long long ll;
const int maxn = 510;
const int maxc = 1 << 8;

ll dp[maxc][maxc][2], f[maxc][maxc];
ll n, mod;
int pri[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 0, 0 };

struct node {
    int state;
    int big;
    bool operator<(const node& a) const { return big < a.big; }
} num[maxn];

void init()
{
    range(i, 2, n)
    {
        int x = i, v;
        each(j, 8) if (x % (v = pri[j]) == 0)
        {
            while (x % v == 0)
                x /= v;
            num[i].state |= (1 << j);
        }
        num[i].big = x;
    }
    sort(num + 2, num + n + 1);
    f[0][0] = 1;
}

int main()
{
    scanf("%lld %lld", &n, &mod);
    init();
    range(i, 2, n)
    {
        if (i == 2 || num[i].big == 1 || num[i].big != num[i - 1].big)
            reach(j, maxc) reach(k, maxc) dp[j][k][0] = dp[j][k][1] = f[j][k];
        reach(j, maxc) reach(k, maxc)
        {
            if ((num[i].state & j) == 0)
                dp[j][k | num[i].state][1] = (dp[j][k | num[i].state][1] + dp[j][k][1]) % mod;
            if ((num[i].state & k) == 0)
                dp[j | num[i].state][k][0] = (dp[j | num[i].state][k][0] + dp[j][k][0]) % mod;
        }
        if (i == n || num[i].big == 1 || num[i].big != num[i + 1].big)
            reach(j, maxc) reach(k, maxc) f[j][k] = (dp[j][k][0] + dp[j][k][1] - f[j][k] + mod) % mod;
    }
    ll ans = 0;
    reach(i, maxc) reach(j, maxc) if ((i & j) == 0) ans = (ans + f[i][j]) % mod;
    printf("%lld\n", ans);
    return 0;
}