欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

洛谷P3193 [HNOI2008]GT考试(dp 矩阵乘法)

程序员文章站 2023-09-07 16:50:46
题意 "题目链接" Sol 设$f[i][j]$表示枚举到位置串的第i位,当前与未知串的第j位匹配,那么我们只要保证在转移的时候永远不会匹配即可 预处理出已知串的每个位置加上某个字符后能转移到的位置,矩阵快速幂优化一下 复杂度$O(M^3 \log n)$ cpp include using nam ......

题意

题目链接

sol

\(f[i][j]\)表示枚举到位置串的第i位,当前与未知串的第j位匹配,那么我们只要保证在转移的时候永远不会匹配即可

预处理出已知串的每个位置加上某个字符后能转移到的位置,矩阵快速幂优化一下

复杂度\(o(m^3 \log n)\)

#include<bits/stdc++.h>
using namespace std;
const int maxn = 22;
int n, m, mod,  s[maxn], trans[maxn][10], p[maxn], g[maxn], base[maxn];
char ss[maxn];
template<typename a, typename b> inline void add2(a &x, b y) {
    if(x + y < 0) x = x + y + mod;
    else x = x + y >= mod ? x + y - mod : x + y;
}
int lim;
struct ma {
    int m[maxn][maxn];
    ma() {
        memset(m, 0, sizeof(m));
    }
    void init() {
        for(int i = 0; i <= lim; i++) m[i][i] = 1;
    }
    ma operator * (const ma &rhs) const {
        ma ans;
        for(int i = 0; i <= lim; i++)
            for(int j = 0; j <= lim; j++) {
                __int128 tmp = 0;
                for(int k = 0; k <= lim; k++) tmp += 1ll * m[i][k] * rhs.m[k][j] % mod;
                ans.m[i][j] = tmp % mod;
            }
        return ans;
    }
}f;
void getnxt() {
    int j = 0;
    for(int i = 0; i <= m; i++) {
        if(i > 1) {
            while(j && s[i] != s[j + 1]) j = p[j];
            if(s[i] == s[j + 1]) j++;
            p[i] = j;
        }
        for(int t = 0; t <= 9; t++) {
            int k = i;
            while(k && t != s[k + 1]) k = p[k];
            if(t == s[k + 1]) k++;
            trans[i][t] = k;
        }
    }
}
ma mpow(ma a, int p) {
    ma base; base.init();
    while(p) {
        if(p & 1) base = base * a;
         a = a * a; p >>= 1;    
    }
    return base;
}
int main() {
    cin >> n >> m >> mod; lim = m + 1;
    scanf("%s", ss + 1);
    for(int i = 1; i <= m; i++) s[i] = ss[i] - '0';
    for(int i = 0; i <= 9; i++) g[i == s[1]]++;
    getnxt();
    for(int j = 0; j <= m; j++) 
        for(int k = 0; k <= 9; k++) 
            if(trans[j][k] != m)
                f.m[trans[j][k]][j]++;
    ma tmp = mpow(f, n - 1);
    for(int i = 0; i <= lim; i++) 
        for(int j = 0; j <= lim; j++)
            add2(base[i], 1ll * tmp.m[i][j] * g[j] % mod);
    int ans = 0;
    for(int i = 0; i <= m - 1; i++) add2(ans, base[i]);
    cout << ans;
    return 0;
}
/*
4 3 100
121
*/