BZOJ 1016 [JSOI2008]最小生成树计数

BZOJ 第一题。矩阵树定理(Matrix-Tree)写起来非常繁琐,所以先写了一个暴力搜索的版本……
说是自己写的,其实基本是照着黄学长的代码敲的……

题意:
最小生成树计数。

思路:
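请查阅本人关于最小生成树计数的粗浅证明文章。简要来说:所有最小生成树中,每种权值的边的数量都相同;并且按 Kruskal 处理完某一权值的全部边之后,图的连通情况与具体选了哪些边无关。因此可以把边按权值分组,对每一组暴力枚举选边方案(题目保证相同权值的边不超过 10 条),再把各组的方案数相乘取模,即为答案。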

注: 这道题如果不手写读入函数(即下面的 scan_d),会出现原因不明的 WA。

#include <bits/stdc++.h>

// Shorthand for the 64-bit integer type (not used by the code in this listing).
#define ll long long
// Loop shorthands; each one declares its own int loop variable.
//   each(i, n)        : i = 0, 1, ..., n - 1
//   reach(i, n)       : i = n - 1, ..., 1, 0
//                       NOTE(review): `n` is not parenthesized here, so an
//                       argument like `a ? b : c` would expand incorrectly.
//   range(i, st, en)  : i = st, ..., en   (both ends inclusive)
//   rrange(i, st, en) : i = en, ..., st   (both ends inclusive, descending)
#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
// memset the whole array `ary` to the byte value `num`. Only correct for a
// true array: for a pointer, sizeof would clear just the pointer's size.
// NOTE(review): because this is a function-like macro named `fill`, any later
// call written as std::fill(first, last, value) fails to preprocess
// (three arguments passed to a two-argument macro).
#define fill(num, ary) memset((ary), (num), sizeof((ary)))

using namespace std;
// Modulus applied to the running product of per-weight-group counts in main.
const int mod = 31011;
// Capacity of the edges / segs / root arrays (main stores edges at 1..m).
const int maxn = 1005;

// An undirected weighted edge between vertices u and v with weight w.
// Edges compare by weight only, so sorting groups equal-weight edges
// into contiguous runs.
struct edge {
    int u, v, w;

    bool operator<(const edge& other) const { return other.w > w; }
};

// Edge list; main fills edges[1..m].
edge edges[maxn];

// A maximal run of equal-weight edges in the sorted edge list.
//   l, r : inclusive index range [l, r] into edges[]
//   num  : how many edges of this run the Kruskal pass in main put into its MST
struct segment {
    int l;
    int r;
    int num;
};

// Weight groups; main uses segs[1..cnt].
segment segs[maxn];

int n, m;       // number of vertices and number of edges
int cnt;        // number of weight groups (segs[1..cnt])
int tot;        // edges accepted by the Kruskal pass
int sum;        // per-group counter accumulated by dfs
int ans = 1;    // running product of per-group counts, reduced modulo mod
int root[maxn]; // union-find parent array (root[x] == x marks a root)

// Reads the next (optionally negative) decimal integer from stdin into `num`.
// Characters before the first '-' or digit are skipped, and the character
// immediately after the number is consumed.
// Returns false, leaving `num` unchanged, if EOF is reached before a number
// starts; otherwise returns true.
inline bool scan_d(int& num)
{
    // getchar() returns int. Storing it in a char would make EOF
    // indistinguishable from byte 0xFF, or never equal to EOF where char is
    // unsigned.
    int in = getchar();
    // Check EOF inside the skip loop too, so input ending in whitespace
    // terminates instead of looping forever.
    while (in != EOF && in != '-' && (in < '0' || in > '9'))
        in = getchar();
    if (in == EOF)
        return false;

    const bool negative = (in == '-');
    num = negative ? 0 : in - '0';
    while (in = getchar(), in >= '0' && in <= '9')
        num = num * 10 + (in - '0');
    if (negative)
        num = -num;
    return true;
}

// Returns the representative of x's component by following parent links in
// root[] until reaching a self-parented node.
// Deliberately performs NO path compression: dfs undoes a union by resetting
// only the two merged roots, which is only valid if lookups never rewrite
// other entries of root[].
int findRoot(int x)
{
    int node = x;
    while (root[node] != node)
        node = root[node];
    return node;
}

void dfs(int id, int cur, int num)
{
    if (cur == segs[id].r + 1) {
        if (num == segs[id].num)
            sum++;
        return;
    }
    int ru = findRoot(edges[cur].u), rv = findRoot(edges[cur].v);
    if (ru != rv) {
        root[rv] = ru;
        dfs(id, cur + 1, num + 1);
        root[rv] = rv, root[ru] = ru;
    }
    dfs(id, cur + 1, num);
}

int main()
{
    scan_d(n),scan_d(m);
    range(i, 1, m)
        scan_d(edges[i].u),scan_d(edges[i].v),scan_d(edges[i].w);
    each(i, n + 1)
        root[i] = i;
    sort(edges + 1, edges + m + 1);
    range(i, 1, m)
    {
        if (edges[i].w != edges[i - 1].w) {
            segs[cnt].r = i - 1;
            segs[++cnt].l = i;
        }
        int ru = findRoot(edges[i].u), rv = findRoot(edges[i].v);
        if (ru != rv) {
            root[rv] = ru;
            segs[cnt].num++;
            tot++;
        }
    }
    segs[cnt].r = m ;
    if (tot != n - 1) {
        puts("0");
        return 0;
    }
    each(i, n + 1)
        root[i] = i;
    range(i, 1, cnt)
    {
        sum = 0;
        dfs(i, segs[i].l, 0);
        ans = ans * sum % mod;
        range(j, segs[i].l, segs[i].r)
        {
            int ru = findRoot(edges[j].u), rv = findRoot(edges[j].v);
            if (ru != rv)
                root[rv] = ru;
        }
    }
    printf("%d\n", ans);
    return 0;
}