「51nod 1348」乘积之和 题解


题目链接

分治 FFT 如果使用三模数 NTT, 那么每次合并只需要记录实际模数下的答案…

题目思路

首先容易发现, 多次询问是个假的限制, 实际上我们可以预先计算出所以答案, 然后每次 $O(1)$ 回答…

对于每个数, 有选和不选两种选择, 容易构造出生成函数 $g(x)$ 如下, 选择 $k$ 个数的乘积即为 $x^k$ 对应项的系数.

现在的问题就是如何计算这个式子, 有一种无脑且直接的方法就是分治 FFT…

不过值得注意的是, 这屑题的模数为 $100003 = 2\times 3\times 7\times 2381 + 1$, 并不是友好的 NTT 模数, 于是使用三模数 NTT, (实际上这个模数很小, 用双模数 NTT 就足够了), 统计答案的时候中国剩余定理合并即可.

然后我就写出了这样的乐色代码

1
2
3
4
5
6
7
8
9
10
11
void divide(int L, int R, int* a, Num* f) {
if (L == R) return f[0] = Num(1), void( f[1] = Num(a[L]) );
int Mid = (L + R) / 2, Lim = 1, K = 0;
Num *f0 = tmp[++idx], *f1 = tmp[++idx];
divide(L, Mid, a, f0), divide(Mid+1, R, a, f1);
while (Lim <= R-L+1) Lim <<= 1, ++K;
init(Lim, K), NTT(f0, Lim, 1), NTT(f1, Lim, 1);
for (int i = 0; i < Lim; ++i) f[i] = f0[i] * f1[i], f0[i] = f1[i] = Num(0);
// 注意到这里 f 为在三模数意义下的三个值, 而不是模 100003 意义下的答案 (
NTT(f, Lim, -1), idx -= 2;
}

然后接下来就调了好长时间 = =, 终于发现了这个锅, 其实想一想也挺有道理的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void divide(int L, int R, int* a, int* f) {
if (L == R) return f[0] = 1, void( f[1] = a[L] );
int Mid = (L + R) / 2, *f0 = tmp[++idx], *f1 = tmp[++idx];
divide(L, Mid, a, f0), divide(Mid+1, R, a, f1);
int Lim = 1, K = 0;
while (Lim <= R-L+1) Lim <<= 1, ++K;
for (int i = 0; i <= Mid-L+1; ++i) A[i] = Num(f0[i]);
for (int i = Mid-L+2; i < Lim; ++i) A[i] = Num(0);
for (int i = 0; i <= R-Mid; ++i) B[i] = Num(f1[i]);
for (int i = R-Mid+1; i < Lim; ++i) B[i] = Num(0);
init(Lim, K), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i) A[i] = A[i] * B[i], f0[i] = f1[i] = 0;
NTT(A, Lim, -1), idx -= 2;
for (int i = 0; i <= R-L+1; ++i) f[i] = A[i].Merge();
// 只记录模 100003 下的答案
}

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// 51nod 1348
// DeP
#include <cctype>
#include <cstdio>
#include <algorithm>
using namespace std;

namespace IO {
const int MAXSIZE = 1 << 18 | 1;
char buf[MAXSIZE], *p1, *p2;

inline int Gc() {
return p1 == p2 &&
(p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2)? EOF: *p1++;
}
template<typename T> void read(T& x) {
x = 0; int f = 0, ch = Gc();
while (!isdigit(ch)) f |= ch == '-', ch = Gc();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = Gc();
if (f) x = -x;
}
}
using IO::read;

typedef long long LL;
const int MAXN = 3e5+5, invP1 = 669690699, invP12 = 354521948;
const int P1 = 998244353, P2 = 1004535809, P3 = 469762049, G = 3, mod = 100003;
const LL P12 = 1LL * P1 * P2;

struct Num {
int a, b, c;
Num() { a = b = c = 0; }
Num(int _x): a(_x), b(_x), c(_x) { }
Num(int _a, int _b, int _c): a(_a), b(_b), c(_c) { }
Num Mod(const Num& x) const {
return Num(x.a + (x.a >> 31 & P1), x.b + (x.b >> 31 & P2), x.c + (x.c >> 31 & P3));
}
Num operator + (const Num& rhs) const { return Mod(Num(a + rhs.a - P1, b + rhs.b - P2, c + rhs.c - P3)); }
Num operator - (const Num& rhs) const { return Mod(Num(a - rhs.a, b - rhs.b, c - rhs.c)); }
Num operator * (const Num& rhs) const {
return Num(1LL * a * rhs.a % P1, 1LL * b * rhs.b % P2, 1LL * c * rhs.c % P3);
}
int Merge() {
LL x = 1LL * (b - a + P2) % P2 * invP1 % P2 * P1 + a;
return (1LL * (c - x % P3 + P3) % P3 * invP12 % P3 * (P12 % mod) % mod + x) % mod;
}
};

int fpow(int base, int b, int P) {
int ret = 1;
while (b > 0) {
if (b & 1) ret = 1LL * base * ret % P;
base = 1LL * base * base % P, b >>= 1;
}
return ret;
}

namespace Poly {
Num A[MAXN], B[MAXN];
int r[MAXN], tmp[31][MAXN], idx = -1;
inline void init(const int& Lim, const int& L) {
for (int i = 1; i < Lim; ++i) r[i] = (r[i>>1] >> 1) | ((i & 1) << (L-1));
}

void NTT(Num* f, const int& Lim, const int& type) {
for (int i = 1; i < Lim; ++i) if (i < r[i]) swap(f[i], f[r[i]]);
for (int Mid = 1; Mid < Lim; Mid <<= 1) {
Num unit( fpow(type > 0? G: fpow(G, P1-2, P1), (P1-1) / (Mid << 1), P1),
fpow(type > 0? G: fpow(G, P2-2, P2), (P2-1) / (Mid << 1), P2),
fpow(type > 0? G: fpow(G, P3-2, P3), (P3-1) / (Mid << 1), P3) );
for (int i = 0; i < Lim; i += Mid << 1) {
Num w(1);
for (int j = 0; j < Mid; ++j, w = w * unit) {
Num f0 = f[i+j], f1 = w * f[i+j+Mid];
f[i+j] = f0 + f1, f[i+j+Mid] = f0 - f1;
}
}
}
if (type < 0) {
Num inv( fpow(Lim, P1-2, P1), fpow(Lim, P2-2, P2), fpow(Lim, P3-2, P3) );
for (int i = 0; i < Lim; ++i) f[i] = f[i] * inv;
}
}

void divide(int L, int R, int* a, int* f) {
if (L == R) return f[0] = 1, void( f[1] = a[L] );
int Mid = (L + R) / 2, *f0 = tmp[++idx], *f1 = tmp[++idx];
divide(L, Mid, a, f0), divide(Mid+1, R, a, f1);
int Lim = 1, K = 0;
while (Lim <= R-L+1) Lim <<= 1, ++K;
for (int i = 0; i <= Mid-L+1; ++i) A[i] = Num(f0[i]);
for (int i = Mid-L+2; i < Lim; ++i) A[i] = Num(0);
for (int i = 0; i <= R-Mid; ++i) B[i] = Num(f1[i]);
for (int i = R-Mid+1; i < Lim; ++i) B[i] = Num(0);
init(Lim, K), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i) A[i] = A[i] * B[i], f0[i] = f1[i] = 0;
NTT(A, Lim, -1), idx -= 2;
for (int i = 0; i <= R-L+1; ++i) f[i] = A[i].Merge();
}
}

int n, q;
int A[MAXN], f[MAXN];

int main() {
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
#endif
// input
read(n), read(q);
for (int i = 1; i <= n; ++i) read(A[i]);
// solve
Poly::divide(1, n, A, f);
// output
while (q--) {
static int K;
read(K), printf("%d\n", f[K]);
}
return 0;
}