「CERC2014」Virus synthesis 题解


一道回文自动机 + DP 的题, 感觉能启发思路以及学习到一些技巧, 故记之.

题目思路

假设字符串 $S$ 的长度为 $n$, 容易想到 $O(n^3)$ 的区间 DP 做法, 但是不大可取 = =, 状态数就已经到达了 $O(n^2)$.

于是换一种思路, 设 $f(i)$ 表示构造 PAM 上第 $i$ 个节点所代表的回文串, 且 $len(i)$ 为偶数, 所需要的最少操作次数, 那么

  • 对于 $i$ 的子节点 $j$, 有 $f(j) = f(i) + 1$

    因为 PAM 的节点表示一个回文串, 那么 $j$ 代表的回文串可以视为 $i$ 代表的回文串两端多出一个字符. 因为 $i$ 是回文串, 考虑在构造字符串 $i$ 时, 一定会有一次 2 操作, 否则就不是最优解. 那么, 我们就可以在这次 2 操作之前, 在 $i$ 的某端添加一个字符, 这样从 $i$ 到 $j$ 只需要一次操作.

  • 假设 $j$ 是 $i$ 的一个回文后缀, 且满足 $2 \cdot len(j) \leq len(i)$, 有

这里将 $i$ 视为 $j$ 通过 1 操作填到一半, 再通过一次 2 操作构造而来.

然后答案很好统计了, 可以把最终的字符串视为一个回文串和多次 1 操作堆叠而成, 那么

现在的问题就是如何实现转移.

对于第一个转移, 没什么好说的, 在 PAM 偶节点从上到下 BFS 一遍就好了, (似乎有在线的做法, 在新加入节点时更新 $f(i)$ 的值).

对于第二个转移, 需要处理出满足 $i$ 满足 $2 \cdot len(j) \leq len(i)$ 的最长回文后缀, 记为 $trans(i)$ 可以通过求 fail 类似的方法维护, 只是在跳 fail 的时候多了一个限制条件, 相同的技巧也在 [SHOI2011]双倍回文 使用过.

简单实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Luogu P4762
// DeP
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MAXN = 1e5+5, SIGMA = 4, INF = 0x3f3f3f3f;

int n;
int M[128], f[MAXN];
char S[MAXN];

namespace PAM {
int ch[MAXN][SIGMA], len[MAXN], trans[MAXN], fail[MAXN], last, nidx, ptr;
char S[MAXN];

inline void init() {
memset(ch, 0, sizeof ch);
last = 0, nidx = 1;
len[0] = 0, fail[0] = 1, len[1] = -1;
S[ptr = 0] = '$';
}

inline int getfail(int u) {
while (S[ptr-len[u]-1] != S[ptr]) u = fail[u];
return u;
}

inline void insert(char c) {
S[++ptr] = c;
int val = M[(int) c], nd = getfail(last);
if (!ch[nd][val]) {
int p = ++nidx;
len[p] = len[nd] + 2;
fail[p] = ch[getfail(fail[nd])][val];
if (len[p] <= 2) trans[p] = fail[p];
else {
int u = trans[nd];
// int u = fail[nd];
// 如果写成以上写法, 复杂度就是假的 = =
while (S[ptr-len[u]-1] != S[ptr] || 2*(len[u]+2) > len[p]) u = fail[u];
trans[p] = ch[u][val];
}
ch[nd][val] = p;
}
last = ch[nd][val];
}

int solve() {
static int Q[MAXN], head, tail, ret;
memset(f, 0x3f, sizeof f);
ret = n, f[0] = 1, Q[head = tail = 1] = 0;
// 注意答案的最大值为 n, 如果不存在长度为偶数的回文串, 那么 min{f} = INF...
while (head <= tail) {
int u = Q[head++];
for (int c = 0; c < SIGMA; ++c) {
int i = ch[u][c];
if (!i) continue;
int j = trans[i];
f[i] = min(f[i], min(f[u] + 1, f[j] + 1 + len[i] / 2 - len[j]));
ret = min(ret, f[i] + n - len[i]);
Q[++tail] = i;
}
}
return ret;
}
}

int main() {
#ifndef ONLINE_JUDGE
freopen("1.in", "r", stdin);
#endif
M['A'] = 0, M['C'] = 1, M['G'] = 2, M['T'] = 3;
int Ti;
scanf("%d", &Ti);
while (Ti--) {
PAM::init();
scanf("%s", S+1);
n = (int) strlen(S+1);
for (int i = 1; i <= n; ++i) PAM::insert(S[i]);
printf("%d\n", PAM::solve());
}
return 0;
}

一些技巧

上述代码成功拿到了 67 pts 的好成绩, 在第一个点跑了很久…

通过学习其他人的卡常技巧, 算是卡了过去

  1. 用到再初始化 ch 的值
  2. BFS 遍历 PAM 时, 每个节点只更新一次
  3. 给 $f(i)$ 赋初值为节点的长度 $len(i)$, 并只在 $len(i)$ 为偶数的时候更新 $f(i)$
  4. 使用 ckmin 更新最小值

话说用 static 把数组开到函数里会快一点? 不明所以.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Luogu P4762
// DeP
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MAXN = 1e5+5, SIGMA = 4, INF = 0x3f3f3f3f;

int n;
int M[128], f[MAXN];
char S[MAXN];

namespace PAM {
int ch[MAXN][SIGMA], len[MAXN], trans[MAXN], fail[MAXN], last, nidx, ptr;
char S[MAXN];

inline void init() {
// 1.
memset(ch[0], 0, sizeof ch[0]);
memset(ch[1], 0, sizeof ch[1]);
last = 0, nidx = 1;
len[0] = 0, fail[0] = 1, len[1] = -1;
S[ptr = 0] = '$';
}

inline int getfail(int u) {
while (S[ptr-len[u]-1] != S[ptr]) u = fail[u];
return u;
}

inline void insert(char c) {
S[++ptr] = c;
int val = M[(int) c], nd = getfail(last);
if (!ch[nd][val]) {
int p = ++nidx;
// 1.
memset(ch[p], 0, sizeof ch[p]);
len[p] = len[nd] + 2;
fail[p] = ch[getfail(fail[nd])][val];
if (len[p] <= 2) trans[p] = fail[p];
else {
int u = trans[nd];
while (S[ptr-len[u]-1] != S[ptr] || 2*(len[u]+2) > len[p]) u = fail[u];
trans[p] = ch[u][val];
}
ch[nd][val] = p;
}
last = ch[nd][val];
}

// 4.
template<typename T> inline void ckmin(T& x, const T& y) { if (x > y) x = y; }

int solve() {
static int vis[MAXN], Time, Q[MAXN], head, tail, ret;
++Time;
// 3.
for (int i = 2; i <= nidx; ++i) f[i] = len[i];
ret = n, f[0] = 1, Q[head = tail = 1] = 0;
while (head <= tail) {
int u = Q[head++];
for (int c = 0; c < SIGMA; ++c) {
int i = ch[u][c];
if (!i) continue;
int j = trans[i];
f[i] = f[u] + 1;
// 3.
if (len[i] % 2 == 0)
ckmin(f[i], f[j] + 1 + len[i] / 2 - len[j]);
ckmin(ret, f[i] + n - len[i]);
// 2.
if (vis[i] != Time) Q[++tail] = i, vis[i] = Time;
}
}
return ret;
}
}

int main() {
#ifndef ONLINE_JUDGE
freopen("1.in", "r", stdin);
#endif
M['A'] = 0, M['C'] = 1, M['G'] = 2, M['T'] = 3;
int Ti;
scanf("%d", &Ti);
while (Ti--) {
PAM::init();
scanf("%s", S+1);
n = (int) strlen(S+1);
for (int i = 1; i <= n; ++i) PAM::insert(S[i]);
printf("%d\n", PAM::solve());
}
return 0;
}