分类
OI

AC 自动机 — LOJ 3089 「BJOI2019」奥术神杖

0 说在之前

我吐了,这题我写了两天……

考虑到我自己写的博客还没有 AC 自动机的,我会简单写一下

1 AC 自动机

1.0 什么是 AC 自动机

有一个说烂但是很形象的说法 Trie + KMP

AC 自动机用于多模式串匹配

就是你拿一个字符串,和一堆字符串

然后 AC 自动机可以让你快速的知道这一堆字符串中,那些是你这一个字符串的子串

1.1 构建

AC 自动机的第一步是建立 Trie 树,这并不复杂

问题在于,如果仅仅建立一个 Trie 树,想要很快的完成上面的任务还是略显艰难

跟 KMP 一样,如果我们可以使用一个类似 next 数组的指针,告诉我们失配后应该跳到哪里能最大化效率就好了

这个东西就叫做 fail 指针(失配指针)

建立完 Trie 树之后,我们可以像建立 next 数组一样建立 fail 指针

假设到遍历到某一节点时,其父亲节点及深度比他小的 fail 指针已经建立完毕

那么对于其子节点,如果存在,那么我们需要顺延现在节点的 fail 指针即可

如果不存在,那么就直接指向 fail 的对应点即可

2 LOJ 3089

2.1 思路

AC 自动机上 DP

事实上我是第一次接触这种科技

首先对所有模式串跑 AC 自动机

原来的答案带根号,取 ln

然后就变成了喜闻乐见的 01 分数规划

$$
\begin{aligned}
& \ln\sqrt[k]{\Pi_{i=1}^k maigc_i} \\
= & \frac{1}{k} \ln \Pi_{i=1}^k maigc_i \\
= & \frac{ \sum_{i=1}^k \ln maigc_i }{k}
\end{aligned}
$$

(关于为什么能这么变 Wikipedia - 对数

f[i][j] 表示到了到了第 $i$ 位第 $j$ 个节点的结果

剩下就是 dp 和记录方案了

2.2 Code

#include <cmath>
#include <cstdio>
#include <cstring>

#include <queue>
#include <algorithm>

const int N = 2100;
const double eps = 1e-7;

int n, m;
char T[N];

int node_cnt = 1;
struct ac_node {
    int nxt[20], fail;
    double val, cnt;
    ac_node() {
        memset( nxt, 0, sizeof(nxt) );
        fail = 1;
        val = cnt = 0;
    }
} tree[ N * 10 ];

void insert( char *s, double val ) {
    int cur = 1;
    for( ; *s; s ++ ) {
        int cur_char = *s - '0';
        int &nxt = tree[cur].nxt[cur_char];
        if( nxt == 0 ) {
            node_cnt ++;
            nxt = node_cnt;
        }
        cur = nxt;
    }
    tree[cur].cnt ++;
    tree[cur].val += val;
}

void build() {
    std::queue<int> q;
    for( int i = 0; i < 20; i ++ ) {
        if( tree[1].nxt[i] == 0 ) 
            tree[1].nxt[i] = 1;
        else 
            q.push( tree[1].nxt[i] );
    }
    while( !q.empty() ) {
        int cur = q.front(); q.pop();
        int fail = tree[cur].fail;
        tree[cur].cnt += tree[fail].cnt;
        tree[cur].val += tree[fail].val;
        for( int i = 0; i < 10; i ++ ) {
            if( tree[cur].nxt[i] ) {
                tree[ tree[cur].nxt[i] ].fail = tree[fail].nxt[i];
                q.push( tree[cur].nxt[i] );
            }
            else 
                tree[cur].nxt[i] = tree[fail].nxt[i];
        }
    }
}

double f[N][N];
struct last { int la, cur_char; } la[N][N];

char ans[N];
void update_ans( int id, int pos ) {
    if( id != 1 ) 
        update_ans( id - 1, la[id][pos].la );
    ans[id] = la[id][pos].cur_char + '0';
}

inline void update( double &cur, double nxt, last &la, last upd ) {
    if( nxt > cur ) {
        cur = nxt;
        la = upd;
    }
}

bool check( double mid ) {
    double INF = 1e16;
    for( int i = 0; i <= n; i ++ ) {
        for( int j = 0; j <= node_cnt; j ++ ) {
            f[i][j] = -INF;
        }
    }
    f[0][1] = 0;

    for( int i = 1; i <= n; i ++ ) {
        if( T[i] == '.' ) {
            for( int j = 1; j <= node_cnt; j ++ ) {
                for( int k = 0; k < 10; k ++ ) {
                    if( f[ i - 1 ][j] == -INF )
                        continue;
                    int nxt = tree[j].nxt[k];
                    update( f[i][nxt], f[ i - 1 ][j] + tree[nxt].val - tree[nxt].cnt * mid, 
                            la[i][nxt], (last){ j, k } );
                }
            }
        }
        else {
            int k = T[i] - '0';
            for( int j = 1; j <= node_cnt; j ++ ) {
                if( f[ i - 1 ][j] == -INF )
                    continue;
                int nxt = tree[j].nxt[k];
                update( f[i][nxt], f[ i - 1 ][j] + tree[nxt].val - tree[nxt].cnt * mid,
                        la[i][nxt], (last){ j, k } );
            }
        }
    }

    int pos = 0;
    for( int i = 1; i <= node_cnt; i ++ ) {
        if( f[n][i] > f[n][pos] ) 
            pos = i;
    }
    if( f[n][pos] > 0 ) {
        update_ans( n, pos );
        return 1;
    }
    return 0;
}

void write( int x, int pos ) {
    if( x != 1 ) 
        write( x - 1, la[x][pos].la );
    printf( "%d", la[x][pos].cur_char );
}

int main() {
#ifdef woshiluo
    freopen( "loj.3089.in", "r", stdin );
    freopen( "loj.3089.out", "w", stdout );
#endif
    scanf( "%d%d", &n, &m );
    scanf( "%s", T + 1 );
    double left = 0, rig = 0;
    for( int i = 1; i <= m; i ++ ) {
        double val;
        char str[N];
        scanf( "%s%lf", str + 1, &val );
        val = std::log( val );
        rig = std::max( rig, val );
        insert( str + 1, val );
    }

    build();

    while( left + eps <= rig ) {
        double mid = ( left + rig ) / 2;
        if( check(mid) ) {
            left = mid + eps;
        }
        else 
            rig = mid - eps;
    }

    printf( "%s\n", ans + 1 );
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注