51Nod 1601 完全图的最小生成树计数

分治+字典树+MST 综合好题

标题即题意,但是不一样的是边权是点权的异或。

不可能会出什么边权也是给定的完全图最小生成树计数的啦,现在看来,也算是套路题了。
当然点数不能太多…………

不卡常才是一个好题的前提条件。写 Trie 习惯性用了类,虽然慢了一些,但也不至于太慢。

题意:
给你 n 个点的点权,边权为其异或值,要你求出最小生成树,以及最小生成树的个数。

思路:
这道题的核心思想还是分治。
我们对给定区间的点权进行处理,并假设这个区间的前面 k – 1 位全部相同。那么对于第 k 位,我们根据 01 将其划分为两个集合,这时候两个集合的前 k 位分别相同,再对这两个集合分别进行同样的处理。
其正确性不言而喻。

在处理的过程中,我们需要将两个集合之间建边,因为是MST,所以要找最小的边。将其中一个集合插入 字典树,再让另一个集合的每一个点以常数级别的复杂度去找到最小异或值,再更新答案与边数。
最小生成树的个数只要根据乘法原理将边的数量相乘即可。

AC Code

#include <bits/stdc++.h>

using namespace std;

typedef pair<int, int> pii;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const int maxn = 1e5 + 5;

int n, anscnt, val[maxn], s[maxn], t[maxn];
long long sum;

struct Trie {
    int end[maxn * 30], next[maxn * 30][2];
    int size, root;

    int newNode()
    {
        next[size][0] = next[size][1] = 0;
        end[size++] = 0;
        return size - 1;
    }

    void init()
    {
        size = 0;
        root = newNode();
    }

    void insert(int x)
    {
        int cur = root;
        for (int i = 30, y; i >= 0; i--) {
            y = (x >> i) & 1;
            if (!next[cur][y])
                next[cur][y] = newNode();
            cur = next[cur][y];
        }
        end[cur]++;
    }

    pii search(int x)
    {
        int cur = root, ans = 0;
        for (int i = 30, y; i >= 0; i--) {
            y = (x >> i) & 1;
            if (next[cur][y])
                cur = next[cur][y], ans |= (y << i);
            else
                cur = next[cur][y ^ 1], ans |= ((y ^ 1) << i);
        }
        return make_pair(ans ^ x, end[cur]);
    }
} trie;

inline void read(int& x)
{
    char ch = getchar();
    x = 0;
    while (!(ch >= '0' && ch <= '9'))
        ch = getchar();
    while (ch >= '0' && ch <= '9')
        x = x * 10 + ch - '0', ch = getchar();
}

int fastPow(int x, int y)
{
    int res = 1;
    while (y) {
        if (y & 1)
            res = 1LL * res * x % mod;
        x = 1LL * x * x % mod, y >>= 1;
    }
    return res;
}

void dac(int l, int r, int dep)
{
    if (l >= r)
        return;
    if (dep < 0) {
        if (r - l + 1 >= 2)
            anscnt = 1LL * anscnt * fastPow(r - l + 1, r - l - 1) % mod;
        return;
    }
    int lc = 0, rc = 0, ans = inf, cnt = 0;
    for (int i = l; i <= r; i++) {
        if ((val[i] >> dep) & 1)
            s[lc++] = val[i];
        else
            t[rc++] = val[i];
    }
    trie.init();
    for (int i = 0; i < lc; i++)
        val[l + i] = s[i];
    for (int i = 0; i < rc; i++) {
        val[l + lc + i] = t[i];
        trie.insert(t[i]);
    }
    pii tmp;
    for (int i = 0; i < lc; i++) {
        tmp = trie.search(s[i]);
        if (ans > tmp.first)
            ans = tmp.first, cnt = tmp.second;
        else if (ans == tmp.first)
            cnt += tmp.second;
    }
    if (cnt) {
        sum += ans;
        anscnt = 1LL * cnt * anscnt % mod;
    }
    dac(l, l + lc - 1, dep - 1);
    dac(l + lc, r, dep - 1);
}

int main()
{
    anscnt = 1;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", val + i);
    dac(1, n, 30);
    printf("%lld\n%d\n", sum, anscnt);
    return 0;
}