矩阵快速幂

计蒜客 Coin 附带矩阵加速模板

Posted on

大概是第5次写矩阵快速幂……
这道题在DP还是矩阵上都是属于简单范畴……,但没想到用DP就很烦……

感觉矩阵加速还是挺有必要学一下的……

顺带学了一下什么叫分数取模用逆元

题意:
让你抛 k 次 硬币,求上面朝上的次数为偶数的概率。
单次正面朝上的概率为 \( \dfrac {q} {p} \)

思路:
令 dp [ n ] [ 0 ] 表示抛了 n 次后,朝上次数为偶数的概率,dp[ n ] [ 1 ] 则表示奇数的概率。
得转移方程
$ dp[n][0] = \dfrac {p-q} {p} \times dp[n-1][0] + \dfrac {q} {p} \times dp[n-1][1] $
$ dp[n][1] = \dfrac {q} {p} \times dp[n-1][0] + \dfrac {p-q} {p} \times dp[n-1][1] $

然后矩阵加速一下就解决了……

#include <bits/stdc++.h>

#define each(i, n) for (int(i) = 0; (i) < (n); (i)++)
#define reach(i, n) for (int(i) = n - 1; (i) >= 0; (i)--)
#define range(i, st, en) for (int(i) = (st); (i) <= (en); (i)++)
#define rrange(i, st, en) for (int(i) = (en); (i) >= (st); (i)--)
#define fill(ary, num) memset((ary), (num), sizeof(ary))

using namespace std;

typedef long long ll;

const int inf = 0x3f3f3f3f;

const int max_n = 5;
const int mod = 1e9 + 7;

struct Matrix {
    int mat[max_n][max_n], n;
    Matrix(int _n = 1)
    {
        n = _n;
        memset(mat, 0, sizeof mat);
    }
    Matrix operator*(const Matrix& a) const
    {
        Matrix tmp(n);
        for (int i = 1; i < n; ++i)
            for (int j = 1; j < n; ++j)
                if (mat[i][j])
                    for (int k = 1; k < n; ++k)
                        tmp.mat[i][k] = (tmp.mat[i][k] + 1LL * mat[i][j] * a.mat[j][k] % mod) % mod;
        return tmp;
    }
    Matrix Pow(ll m)
    {
        Matrix ret(n), a(*this);
        for (int i = 1; i < n; i++)
            ret.mat[i][i] = 1;
        while (m) {
            if (m & 1)
                ret = ret * a;
            a = a * a;
            m >>= 1;
        }
        return ret;
    }
};

ll fastPow(ll n, ll m)
{
    ll ret = 1;
    while (m) {
        if (m & 1)
            ret = ret * n % mod;
        m >>= 1;
        n = n * n % mod;
    }
    return ret;
}

int main()
{
    int T;
    ll p, q, k;
    scanf("%d", &T);
    while (T--) {
        scanf("%lld %lld %lld", &p, &q, &k);
        ll up = q * fastPow(p, mod - 2) % mod;
        ll down = (p - q) * fastPow(p, mod - 2) % mod;
        if (k == 1) {
            printf("%lld\n", up);
            continue;
        }
        Matrix base(3), res(3);
        base.mat[1][1] = down, base.mat[1][2] = up;
        base.mat[2][1] = up, base.mat[2][2] = down;
        res = base.Pow(k);
        printf("%d\n", res.mat[1][1]);
    }
    return 0;
}