CodeForces 855 C Helga Hufflepuff's Cup

前几天的上分场的树形DP

我敢说思路和我当时想得已经一模一样了,还有半个多小时的时候,gou bi 带鱼说把题目给他看一下,然后立马得出,树形DP解不了,肯定是树分治!
当时我在一个子问题上走不出来,于是就信了,然后开始划水……

md

题意:
给你一棵树,要你给树上的每一个节点赋值,赋值存在范围,再给你一个特殊值,除开特殊值意外的数值赋值次数不限,但是特殊值的相邻节点的值不能是特殊值,也不能是比特殊值大的值。
问方案数。

思路:
首先值得注意的是,特殊值赋值次数不超过10,这是一个重要突破口,这使得树形DP的解法存在可能性。
简单说就是在每个节点都保存当前节点为根节点的子树分别含有 [ 0, 10 ] 个特殊值的方案数。
但这还不好决策转移,我们还需要对当前值的范围作为状态进行判断。
我在比赛中只用了,当前值为特殊值,和不是特殊值两个状态,但是现在看了其他人的代码,发现三个状态更容易。
当前节点赋值小于特殊值,大于特殊值,等于特殊值。这样就很好转移,不细说,看代码都能秒懂。

我在比赛时候被卡的子问题是这样的,对于计算当前子树的特殊值数量为 n 的时候,我必须将当前根节点的每一个孩子的状态都理清楚。
比如说数量为 3 ,有三个孩子A,B,C
那么其方案数就是以下情况累加

  • ( A_0+B_0+C_3 )
  • ( A_0+B_3+C_0 )
  • ( A_3+B_0+C_0 )
  • ( A_1+B_0+C_2 )
  • ( A_0+B_1+C_2 )
    ……
    ……

:wc 好多,复杂度好高,这不是组合问题么???
现在看来真是煞笔……

AC Code

#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int mod = 1e9 + 7;
const int maxn = 1e5 + 10;

int n, m, k, x;
vector<int> tree[maxn];

ll dp[maxn][3][11];
ll tmp[3][11];

void dfs(int u, int par)
{
    dp[u][0][0] = k - 1;
    dp[u][1][0] = m - k;
    dp[u][2][1] = 1;
    for (auto v : tree[u]) {
        if (v != par) {
            dfs(v, u);
            memset(tmp, 0, sizeof tmp);
            for (int i = 0; i <= x; i++)
                for (int j = 0; i + j <= x; j++) {
                    tmp[0][i + j] = (tmp[0][i + j] + dp[u][0][i] * (dp[v][0][j] + dp[v][1][j] + dp[v][2][j])) % mod;
                    tmp[1][i + j] = (tmp[1][i + j] + dp[u][1][i] * (dp[v][0][j] + dp[v][1][j])) % mod;
                    tmp[2][i + j] = (tmp[2][i + j] + dp[u][2][i] * dp[v][0][j]) % mod;
                }
            for (int i = 0; i <= x; i++)
                for (int j = 0; j < 3; j++)
                    dp[u][j][i] = tmp[j][i];
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    int u, v;
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        tree[u].push_back(v);
        tree[v].push_back(u);
    }
    scanf("%d%d", &k, &x);
    dfs(1, 0);
    int ans = 0;
    for (int i = 0; i <= x; i++)
        for (int j = 0; j < 3; j++)
            ans = (ans + dp[1][j][i]) % mod;
    printf("%d\n", ans);
    return 0;
}