「CTSC 2010」性能优化 题解


这道题真是奥妙重重, 直接暴露了我 FFT 那一套单位根的东西没搞懂 = =

题意简述

给定两个序列 $a_0, a_1, \ldots, a_{n - 1}$, 以及 $b_0, b_1, \ldots, b_{n - 1}$, 定义一种运算

也就是循环卷积. 求 $a \cdot b ^ c$ 各项模 $n + 1$ 后的值.

其中 $n$ 在质因数分解后每个质因数不超过 $10$, $n + 1$ 为质数, 且 $n \le 5 \times 10 ^ 5$, $a_i, b_i, C \le 10 ^ 9$.

解题思路

先观察循环卷积是个什么东西.

根据 等比数列求和 单位根反演, 有

那么有

稍微整理一下, 得

那么对 $a$, $b$ 做 DFT, 乘起来再做一次 IDFT 即可得出 $c$.

问题在于, $n$ 不满足 FFT 的要求, 即 $n$ 不是 $2$ 的幂次. 现在要处理的事情就是混合基 FFT 的计算.

将 $n$ 质因数分解, 得到

在做朴素 FFT 时有 $n = 2 ^ c$, 每次将序列分作两部分合并. 现在, 将序列分作 $p$ 部分 ($p \in \{2 , 3, 5, 7 \}$) 部分, 然后合并.

重新推导 DFT 的式子. 将多项式的系数按照次项模 $p$ 意义下不同取值分为 $p$ 组, 记作 $F_r(x)$, 即

取 $F_r(x)$ 在 $x = \omega_{n} ^ i (0 \le i < n)$ 处的点值. 同时需要将 $p$ 组点值合并为 $pn$ 组点值. 合并后的多项式为

利用单位根的性质, 可以得到

那么有

其中 $0 \le i < pn$, 不妨将 $i$ 表示为 $an + b$. 则

此时得出了在 $O(p)$ 的时间复杂度内合并一处点值的方法.

IDFT 时类似, 根据下式可在 DFT 的基础上简单计算.

如果追求迭代实现, 那么在预处理每个元素分治时的位置需要精巧实现.

注意到 $n + 1$ 保证为质数, 且根据 $n$ 的性质可以快速计算原根. 利用原根作为单位根就好了.

再来分析时间复杂度. 分治计算的过程中, 层数 $\sum\limits_{i = 1} ^ 4 c_i = O(\log n)$, 每层合并的复杂度为 $O(pn)$ 且 $p$ 为常数, 因此总时间复杂度为 $O(n \log n)$.

代码实现

参考了 Weng_Weijie 的迭代实现. 问就是奥妙重重

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Luogu P4191
// DeP
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

#define DEBUG(args...) fprintf(stderr, ##args)

namespace IO {
const int MAXSIZE = 1 << 18 | 1;
char buf[MAXSIZE], *p1, *p2;

inline int Gc() {
return p1 == p2 &&
(p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2)? EOF: *p1++;
}
template<typename T> inline void read(T& x) {
x = 0; int f = 0, ch = Gc();
while (!isdigit(ch)) f |= ch == '-', ch = Gc();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = Gc();
if (f) x = -x;
}
}
using IO::read;

const int MAXN = 5e5 + 5, MAXM = 114514;

int P, W[MAXN];
int fact[MAXM], tot;

int fpow(int base, int b, int m = P) {
int ret = 1;
while (b > 0) {
if (b & 1) ret = 1LL * ret * base % m;
base = 1LL * base * base % m, b >>= 1;
}
return ret % m;
}

namespace Poly {
// f(w ^ {an+b}_{pn}) = sum (w ^ {an+b}_{pn}) ^ r f_r (w ^ {b}_{n})
int r[MAXN];
void Rev(int* f, const int& n) {
for (int k = tot, Mid = n; k; Mid /= fact[k--]) {
// i, l = an, j = b
for (int idx = 0, i = 0; i < n; i += Mid)
for (int j = 0; j < fact[k]; ++j)
for (int l = 0; l < Mid; l += fact[k]) r[idx++] = f[i + j + l];
for (int i = 0; i < n; ++i) f[i] = r[i];
}
}

void NTT(int* f, const int& n, const int& type) {
static int tmp[MAXN];
Rev(f, n);
for (int k = 1, Mid = 1; k <= tot; Mid *= fact[k++]) {
const int& unit = W[n / (Mid * fact[k])];
memset(tmp, 0, n * sizeof (int));
// i = an, j = b, l = b'
for (int i = 0; i < n; i += Mid * fact[k]) {
int wk = 1;
for (int j = 0; j < Mid * fact[k]; ++j) {
for (int w = 1, l = j % Mid; l < Mid * fact[k]; l += Mid) {
tmp[i + j] = (tmp[i + j] + 1LL * w * f[i + l] % P) % P;
w = 1LL * w * wk % P;
}
wk = 1LL * wk * unit % P;
}
}
memcpy(f, tmp, n * sizeof (int));
}
if (type < 0) {
// n * n = 1 (mod n + 1)
for (int i = 0; i < n; ++i) f[i] = 1LL * f[i] * n % P;
reverse(f + 1, f + n);
}
}
}

int proot(int p) {
int phi = p - 1; tot = 0;
for (int d = 2; d*d <= phi; ++d)
while (phi % d == 0) phi /= d, fact[++tot] = d;
if (phi > 1) fact[++tot] = phi;
phi = p - 1;
for (int g = 2; g < p; ++g) {
bool flag = true;
for (int i = 1; i <= tot && flag; ++i)
if (fpow(g, phi / fact[i], p) == 1) flag = false;
if (flag) return g;
}
return -1;
}

int n, C;
int A[MAXN], B[MAXN];

inline void PolyPre() {
W[0] = 1, W[1] = proot(P);
for (int i = 2; i < n; ++i) W[i] = 1LL * W[1] * W[i-1] % P;
}

int main() {
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
#endif
// input
read(n), read(C);
for (int i = 0; i < n; ++i) read(A[i]);
for (int i = 0; i < n; ++i) read(B[i]);
// solve
P = n + 1, PolyPre();
Poly::NTT(A, n, 1), Poly::NTT(B, n, 1);
for (int i = 0; i < n; ++i)
A[i] = 1LL * A[i] * fpow(B[i], C) % P;
Poly::NTT(A, n, -1);
// output
for (int i = 0; i < n; ++i) printf("%d\n", A[i]);
return 0;
}

参考资料