真是不容易啊……
最小生成树计数,虽然说觉得以后已经基本不会考到了,但毕竟防范与未然
具体算法我已经在之前一篇算法证明上提到过了,这里是他的矩阵树实现。
这里的代码就暂时作为我的模板了,之后再找题验证一下。
题意:
最小生成树计数裸题
思路:
详见之前一篇证明文章
这里有个地方我一开始不是很明白,就是为什么要用两个并查集来维护。
纸上画了一画,稍微有点理解,简单说就是为了相同边权的边进行缩点,后续的并查集操作就用缩点后的点集,这符合算法的流程和原理。
该代码已经尝试在BZOJ 1016 上 1A ,作为本人模板使用
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
typedef long long ll;
const int maxn = 105;
const int maxm = 1005;
ll mod, ans;
int n, m;
vector<int> G[maxn];
int root[maxn], vis[maxn], scc[maxn];
int g[maxn][maxn];
ll mat[maxn][maxn];
struct node {
int u, v, w;
} edges[maxm];
void init()
{
ans = 1;
each(i, n + 1) scc[i] = root[i] = i;
memset(g, 0, sizeof g);
}
int findRoot(int x, int rootAry[])
{
return rootAry[x] == x ? x : rootAry[x] = findRoot(root[x], rootAry);
}
ll getDet(ll a[][maxn], int n)
{
range(i, 1, n) range(j, 1, n) a[i][j] = (a[i][j] + mod) % mod;
ll ret = 1;
range(i, 1, n - 1)
{
range(j, i + 1, n - 1) while (a[j][i])
{
ll t = a[i][i] / a[j][i];
range(k, i, n - 1) a[i][k] = (a[i][k] - a[j][k] * t % mod + mod) % mod;
swap(a[i], a[j]);
ret = -ret;
}
if (a[i][i] == 0)
return 0;
ret = ret * a[i][i] % mod;
}
return (ret + mod) % mod;
}
void matrixTree()
{
range(i, 1, n) if (vis[i])
{
G[findRoot(i, root)].push_back(i);
vis[i] = false;
}
range(i, 1, n) if (G[i].size() > 1)
{
int sz = G[i].size();
memset(mat, 0, sizeof mat);
each(j, sz) range(k, j + 1, sz - 1)
{
int u = G[i][j], v = G[i][k];
if (g[u][v]) {
mat[k][j] = (mat[j][k] -= g[u][v]);
mat[k][k] += g[u][v];
mat[j][j] += g[u][v];
}
}
ans = ans * getDet(mat, G[i].size()) % mod;
each(j, sz) scc[G[i][j]] = i;
}
range(i, 1, n)
{
G[i].clear();
root[i] = scc[i] = findRoot(i, scc);
}
}
void output()
{
range(i, 1, n - 1) if (scc[i] != scc[i + 1])
{
puts("0");
return;
}
printf("%lld\n", ans % mod);
}
int main()
{
while (scanf("%d %d %lld", &n, &m, &mod) != EOF && n + m + mod) {
each(i, m) scanf("%d %d %d", &edges[i].u, &edges[i].v, &edges[i].w);
sort(edges, edges + m, [](const node& a, const node& b) { return a.w < b.w; });
init();
edges[m].w = -1;
range(i, 0, m)
{
if (i && edges[i].w != edges[i - 1].w)
matrixTree();
int u = findRoot(edges[i].u, scc), v = findRoot(edges[i].v, scc); //缩点后的点
if (u != v) {
vis[u] = vis[v] = true;
root[findRoot(u, root)] = findRoot(v, root);
g[u][v]++, g[v][u]++;
}
}
output();
}
return 0;
}