1 条题解

  • 0
    @ 2024-12-24 1:27:58

    这道题的数据范围比D的大,并且不能通过递推的方式求解。

    所以我们可以用一个终极大招,矩阵加速求解递推式。这是一种很常见的求数据范围太大的递推式的方式。

    什么是矩阵加速呢。那么首先请出我们的斐波那契数列来演示一下。

    fn=fn1+fn2f_n=f_{n-1}+f_{n-2} 这是斐波那契的递推式,常规方法是 O(n)O(n) 的时间求解,在数据太大时就失效了。

    我们从线性代数的角度思考,这个 fn,fn1,fn2f_n,f_{n-1},f_{n-2} 之间的关系。可以自己假定一个关系:我们可以通过一个变换,将 fn2f_{n-2} 变到 fn1f_{n-1} 变到 fnf_n

    fn2fn1fnf_{n-2}\to f_{n-1} \to f_n

    我们都是用相同的变换,在线性代数里,变换通常是一个矩阵,那么我们可以思考

    $$\begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix} \to \begin{pmatrix} f_{n} \\ f_{n-1} \end{pmatrix} $$

    很明显,我们左乘一个系数矩阵,就可以变成右边的矩阵了,即:

    $$\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix} = \begin{pmatrix} f_{n} \\ f_{n-1} \end{pmatrix} $$

    那么即使我们要求下一项,我们也可以再在这个基础上乘系数矩阵:

    $$\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix} = \begin{pmatrix} f_{n+1} \\ f_{n} \end{pmatrix} $$

    又矩阵具有乘法结合律,所以我们可以写成这种形式:

    $$\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^2 \begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix} = \begin{pmatrix} f_{n+1} \\ f_{n} \end{pmatrix} $$

    初始条件是 f1=1,f2=1f_1=1,f_2=1,所以具体的表示式:

    $$\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n-2} \begin{pmatrix} f_{2} \\ f_{1} \end{pmatrix} = \begin{pmatrix} f_{n} \\ f_{n-1} \end{pmatrix} $$

    那么我们的时间复杂度就是计算这个矩阵乘法多次幂的时间复杂度了,又因为矩阵多次幂也可以用快速幂计算,那么时间复杂度为:O(m3logk)O(m^3\log k)

    mm 为系数矩阵维度,kk 为指数。

    那么回到这个题,我们有两个递推式,fn,gnf_n,g_n 我们需要找到他们之间的关系。

    我们先看对应关系,应该是先算 ff 再算 gg,我们可以合起来一起算。

    可以先写出来看看:

    $$\begin{pmatrix} g_{n-1} \\ f_{n} \\ f_{n-1} \\ ... \end{pmatrix} \to \begin{pmatrix} g_{n} \\ f_{n+1} \\ f_{n} \\ ... \end{pmatrix} $$

    只有这些肯定是不够的,因为还有常数 11,跟随次数变化的 nn ,我们需要添加额外的信息。

    首先看 gng_n,它需要 gn1,fn,n2g_{n-1},f_n,n^2,所以要增加一项 n2n^2

    fn+1f_{n+1},需要 fn,fn1,n+1f_n,f_{n-1},n+1,所以要增加一项 nn ,一项常数 11

    为什么不直接是 n+1n+1,因为在 n2n^2 变为 (n+1)2(n+1)^2 的平方的时候不好直接变换,将 (n+1)2=n2+2n+1(n+1)^2=n^2+2n + 1 这样就可以直接从 n2n^2 加上某些项变为 (n+1)2(n+1)^2

    所以我们的对应关系就是:

    $$\begin{pmatrix} g_{n-1} \\ f_{n} \\ f_{n-1} \\ n \\ n^2 \\ 1 \end{pmatrix} \to \begin{pmatrix} g_{n} \\ f_{n+1} \\ f_{n} \\ n + 1 \\ (n+1)^2 \\ 1 \end{pmatrix} $$

    我们现在来填系数矩阵:

    $$\begin{pmatrix} 1&1&0&0&1&0\\ 0&2&3&1&0&1\\ 0&1&0&0&0&0\\ 0&0&0&1&0&1\\ 0&0&0&2&1&1\\ 0&0&0&0&0&1\\ \end{pmatrix} \begin{pmatrix} g_{n-1} \\ f_{n} \\ f_{n-1} \\ n \\ n^2 \\ 1 \end{pmatrix} = \begin{pmatrix} g_{n} \\ f_{n+1} \\ f_{n} \\ n + 1 \\ (n+1)^2 \\ 1 \end{pmatrix} $$

    最后将初始值带入即可得到最后的表达式:

    $$\begin{pmatrix} 1&1&0&0&1&0\\ 0&2&3&1&0&1\\ 0&1&0&0&0&0\\ 0&0&0&1&0&1\\ 0&0&0&2&1&1\\ 0&0&0&0&0&1\\ \end{pmatrix}^{n-1} \begin{pmatrix} g_{1} \\ f_{2} \\ f_{1} \\ 2 \\ 4 \\ 1 \end{pmatrix} = \begin{pmatrix} g_{n} \\ f_{n+1} \\ f_{n} \\ n + 1 \\ (n+1)^2 \\ 1 \end{pmatrix} $$

    最后的答案就是系数矩阵的第一行和初始矩阵对应相乘再相加。

    时间复杂度为:O(m3logk)O(m^3\log k)

    参考代码:

    #include <bits/stdc++.h>
    
    using u32 = unsigned;
    using i64 = long long;
    using u64 = unsigned long long;
    using u128 = unsigned __int128;
    
    constexpr int mod = 998244353, N = 200010;
    
    struct Matrix {
    	std::vector<std::vector<i64>> matrix;
    	int n;
    	Matrix(int _n): n(_n) {
    		matrix.assign(n, std::vector<i64>(n, 0));
    	}
    	static Matrix getIdentity(int n) {
    		Matrix res(n);
    		for (int i = 0; i < n; i++) {
    			res.matrix[i][i] = 1;
    		}
    		return res;
    	}
    	const Matrix operator*(const Matrix &b) const {
    		Matrix res(n);
    		for (int i = 0; i < n; i++) {
    			for (int j = 0; j < n; j++) {
    				for (int k = 0; k < n; k++) {
    					res.matrix[i][j] = (res.matrix[i][j] % mod + matrix[i][k] * b.matrix[k][j] % mod) % mod;
    				}
    			}
    		}
    		return res;
    	}
    };
    
    Matrix power(Matrix base, i64 b) {
    	Matrix res = Matrix::getIdentity(base.n);
    	while (b) {
    		if (b & 1) res = res * base;
    		b >>= 1;
    		base = base * base;
    	}
    	return res;
    }
    
    void solve() {
    	int n;
    	std::cin >> n;
    	Matrix B(6);
    	B.matrix = {
    		{1, 1, 0, 0, 1, 0},  // g(n) = g(n-1) + f(n) + n^2
            {0, 2, 3, 1, 0, 1},  // f(n+1) = 2f(n) + 3f(n-1) + n+1
            {0, 1, 0, 0, 0, 0},  // f(n) = f(n)
            {0, 0, 0, 1, 0, 1},  // n = n + 1
            {0, 0, 0, 2, 1, 1},  // n^2 = n^2 + 2n + 1
            {0, 0, 0, 0, 0, 1}   // 常数 1
    	};
    
    	std::vector<i64> v = {2, 2, 1, 2, 4, 1}; // 初始状态:g(1), f(2), f(1), n=2, n^2=4, 常数1
    	B = power(B, n - 1);
    
    	i64 ans = 0;
    	for (int i = 0; i < 6; i++) {
    		ans = (ans + B.matrix[0][i] * v[i] % mod) % mod;
    	}
    	std::cout << ans << "\n";
    }
    
    int main() {
    	std::ios::sync_with_stdio(false);
    	std::cin.tie(nullptr);
    
    	int t;
    	std::cin >> t;
    	while (t--) {
    		solve();
    	}
    	
    	return 0;
    }
    
    • 1

    信息

    ID
    155
    时间
    2000ms
    内存
    256MiB
    难度
    10
    标签
    递交数
    11
    已通过
    2
    上传者