分治+字典树+MST 综合好题
标题即题意,但是不一样的是边权是点权的异或。
不可能会出什么边权也是给定的完全图最小生成树计数的啦,现在看来,也算是套路题了。
当然点数不能太多…………
不卡常才是一个好题的前提条件。写 Trie 习惯性用了类,虽然慢了一些,但也不至于太慢。
题意:
给你 n 个点的点权,边权为其异或值,要你求出最小生成树,以及最小生成树的个数。
思路:
这道题的核心思想还是分治。
我们对给定区间的点权进行处理,并假设这个区间的前面 k – 1 位全部相同。那么对于第 k 位,我们根据 01 将其划分为两个集合,这时候两个集合的前 k 位分别相同,再对这两个集合分别进行同样的处理。
其正确性不言而喻。
在处理的过程中,我们需要将两个集合之间建边,因为是MST,所以要找最小的边。将其中一个集合插入 字典树,再让另一个集合的每一个点以常数级别的复杂度去找到最小异或值,再更新答案与边数。
最小生成树的个数只要根据乘法原理将边的数量相乘即可。
AC Code
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const int maxn = 1e5 + 5;
int n, anscnt, val[maxn], s[maxn], t[maxn];
long long sum;
struct Trie {
int end[maxn * 30], next[maxn * 30][2];
int size, root;
int newNode()
{
next[size][0] = next[size][1] = 0;
end[size++] = 0;
return size - 1;
}
void init()
{
size = 0;
root = newNode();
}
void insert(int x)
{
int cur = root;
for (int i = 30, y; i >= 0; i--) {
y = (x >> i) & 1;
if (!next[cur][y])
next[cur][y] = newNode();
cur = next[cur][y];
}
end[cur]++;
}
pii search(int x)
{
int cur = root, ans = 0;
for (int i = 30, y; i >= 0; i--) {
y = (x >> i) & 1;
if (next[cur][y])
cur = next[cur][y], ans |= (y << i);
else
cur = next[cur][y ^ 1], ans |= ((y ^ 1) << i);
}
return make_pair(ans ^ x, end[cur]);
}
} trie;
inline void read(int& x)
{
char ch = getchar();
x = 0;
while (!(ch >= '0' && ch <= '9'))
ch = getchar();
while (ch >= '0' && ch <= '9')
x = x * 10 + ch - '0', ch = getchar();
}
int fastPow(int x, int y)
{
int res = 1;
while (y) {
if (y & 1)
res = 1LL * res * x % mod;
x = 1LL * x * x % mod, y >>= 1;
}
return res;
}
void dac(int l, int r, int dep)
{
if (l >= r)
return;
if (dep < 0) {
if (r - l + 1 >= 2)
anscnt = 1LL * anscnt * fastPow(r - l + 1, r - l - 1) % mod;
return;
}
int lc = 0, rc = 0, ans = inf, cnt = 0;
for (int i = l; i <= r; i++) {
if ((val[i] >> dep) & 1)
s[lc++] = val[i];
else
t[rc++] = val[i];
}
trie.init();
for (int i = 0; i < lc; i++)
val[l + i] = s[i];
for (int i = 0; i < rc; i++) {
val[l + lc + i] = t[i];
trie.insert(t[i]);
}
pii tmp;
for (int i = 0; i < lc; i++) {
tmp = trie.search(s[i]);
if (ans > tmp.first)
ans = tmp.first, cnt = tmp.second;
else if (ans == tmp.first)
cnt += tmp.second;
}
if (cnt) {
sum += ans;
anscnt = 1LL * cnt * anscnt % mod;
}
dac(l, l + lc - 1, dep - 1);
dac(l + lc, r, dep - 1);
}
int main()
{
anscnt = 1;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", val + i);
dac(1, n, 30);
printf("%lld\n%d\n", sum, anscnt);
return 0;
}