跳转至

快速幂#
发布于2021-04-18
上次编辑2021-04-18

算法#

快速幂 (binary exponentiation) 即在 \(\mathcal{O}(\log n)\) 的时间复杂度下计算 \(a^n\) 的方法。

主要技巧是将指数 \(n\) 表示为二进制的形式来方便计算。

如计算 \(3^{13}\):

\[ \begin{aligned} 3^{13}&=3^{(1101)_2}\\ &=3^{1\times2^0+0\times2^1+1\times2^2+1\times2^3}\\ &=3^{1\times2^0}\times3^{0\times2^1}\times3^{1\times2^2}\times3^{1\times2^3} \end{aligned} \]
其中各个因子之间满足如下关系,因此每一项都可以由前一项平方得到:
\[ a^{2^b}=a^{2\times2^{b-1}}=\left(a^{2^{b-1}}\right)^2 \]

因为 \(n\) 有 \(\lfloor\log_2 n\rfloor+1\) 个二进制位,所以可以在 \(\mathcal{O}(\log n)\) 的时间内依次计算出 \(a^{2^0},a^{2^1},\cdots,a^{2^{\lfloor\log_2 n\rfloor}}\) 的值,然后将 \(n\) 的二进制表示中为 1 的位所对应的值相乘即可。

实现#

递归#

1
2
3
4
5
6
7
def bin_pow(a: int, n: int) -> int:
    """Return ``a ** n`` using recursive binary exponentiation.

    Uses a^n = (a^(n//2))^2 * (a if n is odd else 1), so both the recursion
    depth and the number of multiplications are O(log n).

    Raises:
        ValueError: if ``n`` is negative. For negative ``n`` the shift
            ``n >> 1`` never reaches 0, so the recursion would not terminate.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")
    if n == 0:
        return 1
    half = bin_pow(a, n >> 1)
    square = half * half
    # The lowest bit of n decides whether one extra factor of a is needed.
    return square * a if n & 1 else square

非递归#

1
2
3
4
5
6
7
8
def bin_pow(a: int, n: int) -> int:
    """Return ``a ** n`` using iterative binary exponentiation.

    Scans the bits of ``n`` from least to most significant. ``a`` holds
    a^(2^i) for the current bit i, and it is multiplied into the answer
    whenever that bit is set. This takes O(log n) multiplications.

    Raises:
        ValueError: if ``n`` is negative. For negative ``n`` the shift
            ``n >>= 1`` never reaches 0, so the loop would never terminate.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")
    ans = 1
    while n:
        if n & 1:
            ans *= a
        a *= a  # a now holds the base raised to the next power of two
        n >>= 1
    return ans

应用#

\(a^n\bmod b\)#

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def bin_pow(a: int, n: int, b: int) -> int:
    """Return ``a ** n % b`` by binary exponentiation in O(log n) steps.

    Every intermediate product is reduced modulo ``b``, so the numbers stay
    small. For ``n >= 0`` and ``b > 0`` the result matches the built-in
    ``pow(a, n, b)``.

    Raises:
        ValueError: if ``n`` is negative (the loop would otherwise never end).
        ZeroDivisionError: if ``b`` is 0.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")

    # Start from 1 % b rather than 1: when b == 1 every result, including
    # a ** 0, is 0 modulo b.
    ans = 1 % b
    a %= b

    while n:
        if n & 1:
            ans = (ans * a) % b
        n >>= 1
        a = (a * a) % b

    return ans

矩阵快速幂#

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
using LL = long long;
using VL = std::vector<LL>;
using VLL = std::vector<VL>;

// Matrix product A * B, where A is n x d and B is d x m. Matrices are stored
// as a vector of rows.
//
// mod > 0: every entry of the result is reduced into [0, mod). The operands
//          are expected to already lie in [0, mod), and mod * mod must fit in
//          a long long (true for 1'000'000'007). Under those conditions a
//          single product A[i][k] * B[k][j] cannot overflow.
// mod == 0 (default): no reduction is done. The caller must keep the entries
//          small enough that products and sums do not overflow, because
//          signed overflow is undefined behavior.
VLL matmul(const VLL &A, const VLL &B, LL mod = 0) {
    const int n = A.size();
    const int d = B.size();  // inner dimension; avoids reading A[0] when A is empty
    const int m = B.empty() ? 0 : B[0].size();
    VLL out(n, VL(m));
    // The i-k-j loop order walks rows of B and out contiguously. Each entry
    // still accumulates its terms in ascending k, exactly as in i-j-k order.
    for (int i = 0; i < n; ++i) {
        for (int k = 0; k < d; ++k) {
            const LL a = A[i][k];
            for (int j = 0; j < m; ++j) {
                out[i][j] += a * B[k][j];
                if (mod > 0) {
                    out[i][j] %= mod;
                }
            }
        }
    }
    return out;
}

// Raises the square matrix `base` to the power m by binary exponentiation,
// using O(log m) matrix multiplications. Returns the identity when m == 0.
//
// m is expected to be non-negative. A negative m is treated like 0 instead of
// looping forever (for a negative int, m >>= 1 never reaches 0).
// mod has the same meaning as in matmul. When mod > 0 the entries of base are
// first reduced into [0, mod), so matmul's no-overflow precondition holds.
VLL bin_pow(VLL base, int m, LL mod = 0) {
    const int n = base.size();
    if (mod > 0) {
        for (VL &row : base) {
            for (LL &x : row) {
                x %= mod;
                if (x < 0) {
                    x += mod;  // C++ % keeps the sign of the dividend
                }
            }
        }
    }
    VLL out(n, VL(n));
    for (int i = 0; i < n; ++i) {
        out[i][i] = mod > 0 ? 1 % mod : 1;  // identity; all zeros when mod == 1
    }
    while (m > 0) {
        if (m & 1) {
            out = matmul(out, base, mod);
        }
        m >>= 1;
        // Square only while bits remain. Squaring after the last bit would be
        // wasted work and, without a modulus, could overflow even though the
        // result itself fits.
        if (m > 0) {
            base = matmul(base, base, mod);
        }
    }
    return out;
}

斐波那契数列:

\[ \begin{cases} F(n-1) &= 0 \cdot F(n-2) + 1 \cdot F(n-1)\\ F(n) &=1 \cdot F(n-2) + 1 \cdot F(n-1)\\ \end{cases} \]
\[ \begin{aligned} [F(n-1),F(n)] &= [F(n-2),F(n-1)] \cdot \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix} \\ &=[F(0), F(1)] \cdot \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}^{n-1} \end{aligned} \]
剑指 Offer 10- I. 斐波那契数列
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
using LL = long long;
using VL = vector<LL>;
using VLL = vector<VL>;

// Product of an n x d matrix A and a d x m matrix B. After every addition the
// running value of each entry is reduced modulo 1'000'000'007.
VLL matmul(VLL &A, VLL &B) {
    const LL kMod = 1'000'000'007;
    const int rows = A.size();
    const int inner = A[0].size();
    const int cols = B[0].size();
    VLL prod(rows, VL(cols, 0));
    for (int r = 0; r < rows; ++r) {
        for (int t = 0; t < inner; ++t) {
            for (int c = 0; c < cols; ++c) {
                prod[r][c] = (prod[r][c] + A[r][t] * B[t][c]) % kMod;
            }
        }
    }
    return prod;
}

// Computes base^m by binary exponentiation, starting from the identity and
// squaring base once per bit of m. Assumes m >= 0.
VLL bin_pow(VLL base, int m) {
    const int dim = base.size();
    VLL result(dim, VL(dim, 0));
    for (int d = 0; d < dim; ++d) {
        result[d][d] = 1;
    }
    for (; m != 0; m >>= 1) {
        if ((m & 1) != 0) {
            result = matmul(result, base);
        }
        base = matmul(base, base);
    }
    return result;
}

class Solution {
public:
    int fib(int n) {
        // [F(n-1), F(n)] = [F(0), F(1)] * M^(n-1) with F(0) = 0 and F(1) = 1,
        // so F(n) is entry (1, 1) of M^(n-1).
        if (n == 0) {
            return 0;
        }
        VLL step = {{0, 1}, {1, 1}};
        return bin_pow(step, n - 1)[1][1];
    }
};
1137. 第 N 个泰波那契数
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
using VL = vector<long long>;
using VLL = vector<VL>;

// Plain (non-modular) product of an n x d matrix A and a d x m matrix B.
VLL matmul(VLL &A, VLL &B) {
    const int rows = A.size();
    const int inner = A[0].size();
    const int cols = B[0].size();
    VLL prod(rows, VL(cols));
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            long long acc = 0;
            for (int t = 0; t < inner; ++t) {
                acc += A[r][t] * B[t][c];
            }
            prod[r][c] = acc;
        }
    }
    return prod;
}

// Computes base^m by binary exponentiation, starting from the identity and
// squaring base once per bit of m. Assumes m >= 0.
VLL bin_pow(VLL base, int m) {
    const int dim = base.size();
    VLL acc(dim, VL(dim));
    for (int d = 0; d < dim; ++d) {
        acc[d][d] = 1;
    }
    for (; m; m >>= 1) {
        if (m & 1) {
            acc = matmul(acc, base);
        }
        base = matmul(base, base);
    }
    return acc;
}

class Solution {
public:
    int tribonacci(int n) {
        if (n == 0 || n == 1) {
            return n;  // T(0) = 0, T(1) = 1
        }
        // Row vector [T(k-3), T(k-2), T(k-1)] times M gives [T(k-2), T(k-1), T(k)].
        // Therefore [T(n-2), T(n-1), T(n)] = [0, 1, 1] * M^(n-2),
        // and T(n) = P[1][2] + P[2][2].
        VLL step = {{0, 0, 1}, {1, 0, 1}, {0, 1, 1}};
        const VLL power = bin_pow(step, n - 2);
        return power[1][2] + power[2][2];
    }
};

例题#

返回顶部

在手机上阅读