「清华集训 2017」生成树计数 题解


大概是利用树的 Prufer 序列解决某些计数问题, 第一次见感觉很新鲜.

前置知识

先来考虑一个简化版的问题:

在一个 $s$ 个点的图中, 存在 $s - n$ 条边, 使图中形成了 $n$ 个连通块, 第 $i$ 个连通块中有 $a_i$ 个点. 再连接 $n - 1$ 条边, 使得图连通.

求方案数.

Prufer 序列有一个很基本的性质: 将一个有标号无根树用 $[1, n]$ 中的 $n - 2$ 个整数唯一表示.

将 $n$ 个连通块看作点, 那么这些点之间连 $n - 1$ 条边之后得出的连通图就是棵树, 对这棵树构造 Prufer 序列.

设 $n$ 个连通块的度数分别为 $d_i$, 则所有连通块度数之和为 $2n - 2$. 那么对于一组已经确定的 $d_i$, Prufer 序列的个数为

可列出式子

式子最后的乘积就是考虑连通块内部对外连边的情况.

考虑用多项式定理去化简. 多项式定理即

其中 $\sum\limits_{i = 1} ^ m n_i = n$, $0 \le n_i \le n$.

记 $c_i = d_i -1$. (此时的 $c_i$ 也可理解为第 $i$ 个点在 Prufer 序列中的出现次数). 那么有

如果直接用 ci 的组合意义去考虑, 可以直接得到这个结果.

套多项式定理, 得

最终答案为

解题思路

沿用上面的思路, 其实只是套用多项式定理化简前的思路. 那么答案可写作

将组合数拆开, 并从最后的乘积中提出来一项, 得到

整理, 得

设后半部分 OGF 为 $F(x)$, 那么答案为

前半部分为常数, 直接计算即可. 现在的问题在于如何快速求 $F(x)$. 不妨设 $A(x)$, $B(x)$ 两个 EGF, 分别为

此时的 $F(x)$ 可写作

利用对数解决这个乘积. 得

考虑到 $a_i$ 的影响, 这个好像不能直接做. 做一步转化, 在计算出 $\frac{A(x)}{B(x)}$ 及 $\ln B(x)$ 后, 对于第 $k$ 项系数乘 $\sum\limits_{i = 1} ^ n a_i ^ k$ 即可得出 $\sum\limits_{i = 1} ^ n \frac{A(a_i x)}{B(a_i x)}$ 和 $\sum\limits_{i = 1} ^ n \ln B(a_i x)$.

仅剩的问题是计算 $\sum\limits_{i = 1} ^ n a_i ^ k$. 设其 OGF 为 $G(x)$, 即

此时的 $G(x)$ 不便于快速计算. 下面利用对数和导数对 $G(x)$ 进行转化, 使其成为乘积的形式.

注意到

那么 $G(x)$ 可化作

考虑到求导的线性性, 以及对数的性质, 得到

分治 FFT 计算即可. 时间复杂度 $O(n \log ^ 2 n)$.

另外存在基于第二类 Stirling 数的, 时间复杂度为 $O(nm\log ^ 2 n)$ 的做法, 在复杂度和推导过程上都不占优势.

代码实现

注意 $n = 1$ 的情况需要特判, UOJ 上有类似的 Hack 数据.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// UOJ #335
// DeP
#include <cctype>
#include <cstdio>
#include <algorithm>
using namespace std;

namespace IO {
const int MAXSIZE = 1 << 18 | 1;
char buf[MAXSIZE], *p1, *p2;

inline int Gc() {
return p1 == p2 &&
(p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2)? EOF: *p1++;
}
template<typename T> inline void read(T& x) {
x = 0; int f = 0, ch = Gc();
while (!isdigit(ch)) f |= ch == '-', ch = Gc();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = Gc();
if (f) x = -x;
}
}
using IO::read;

const int LOG = 16, MAXN = 1 << LOG | 1, P = 998244353, G = 3;

int W[LOG][MAXN];
int inv[MAXN], fac[MAXN], ifac[MAXN];

int fpow(int base, int b) {
int ret = 1;
while (b > 0) {
if (b & 1) ret = 1LL * ret * base % P;
base = 1LL * base * base % P, b >>= 1;
}
return ret;
}

inline void OutPoly(const int* f, const int& n) {
for (int i = 0; i < n; ++i)
fprintf(stderr, "%d%c", f[i], " \n"[i == n - 1]);
}

namespace Poly {
int r[MAXN];
inline void init(const int& Lim, const int& L) {
for (int i = 1; i < Lim; ++i) r[i] = (r[i>>1] >> 1) | ((i & 1) << (L-1));
}

void NTT(int* f, const int& Lim, const int& type) {
for (int i = 1; i < Lim; ++i) if (i < r[i]) swap(f[i], f[r[i]]);
for (int k = 0, Mid = 1; Mid < Lim; ++k, Mid <<= 1) {
const int* w = W[k];
for (int i = 0; i < Lim; i += Mid << 1)
for (int j = 0; j < Mid; ++j) {
int f0 = f[i+j], f1 = 1LL * w[j] * f[i+j+Mid] % P;
f[i+j] = (f0 + f1) % P, f[i+j+Mid] = (f0 - f1 + P) % P;
}
}
if (type < 0) {
int iv = fpow(Lim, P - 2);
for (int i = 0; i < Lim; ++i) f[i] = 1LL * f[i] * iv % P;
reverse(f + 1, f + Lim);
}
}

void Inv(int* f, int* g, const int& n) {
static int A[MAXN], B[MAXN];
g[0] = fpow(f[0], P - 2);
for (int L = 0, Lim = 1, Mid = 2; Mid < 2*n; Mid <<= 1) {
while (Lim < 2*Mid) Lim <<= 1, ++L;
for (int i = 0; i < Mid; ++i) A[i] = f[i], B[i] = g[i];
for (int i = Mid; i < Lim; ++i) A[i] = B[i] = 0;
init(Lim, L), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i)
g[i] = ((B[i] + B[i]) % P - 1LL * A[i] * B[i] % P * B[i] % P + P) % P;
NTT(g, Lim, -1);
for (int i = min(n, Mid); i < Lim; ++i) g[i] = 0;
}
}

inline void Der(int* f, int* g, const int& n) {
for (int i = 1; i < n; ++i) g[i - 1] = 1LL * i * f[i] % P;
g[n - 1] = 0;
}
inline void Int(int* f, int* g, const int& n) {
for (int i = n - 1; i; --i) g[i] = 1LL * inv[i] * f[i - 1] % P;
g[0] = 0;
}

void Ln(int* f, int* g, const int& n) {
static int ivf[MAXN], df[MAXN];
Der(f, df, n), Inv(f, ivf, n);
int Lim = 1, L = 0;
while (Lim < 2*n) Lim <<= 1, ++L;
for (int i = n; i < Lim; ++i) ivf[i] = df[i] = 0;
init(Lim, L), NTT(df, Lim, 1), NTT(ivf, Lim, 1);
for (int i = 0; i < Lim; ++i) df[i] = 1LL * df[i] * ivf[i] % P;
NTT(df, Lim, -1), Int(df, g, n);
}

void Exp(int* f, int* g, const int& n) {
static int lng[MAXN], A[MAXN], B[MAXN];
g[0] = 1;
for (int L = 0, Lim = 1, Mid = 2; Mid < 2*n; Mid <<= 1) {
Ln(g, lng, Mid);
while (Lim < 2*Mid) Lim <<= 1, ++L;
for (int i = 0; i < Mid; ++i)
A[i] = (f[i] - lng[i] + P) % P, B[i] = g[i];
A[0] = (A[0] + 1) % P;
for (int i = Mid; i < Lim; ++i) A[i] = B[i] = 0;
init(Lim, L), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i) g[i] = 1LL * A[i] * B[i] % P;
NTT(g, Lim, -1);
for (int i = min(n, Mid); i < Lim; ++i) g[i] = 0;
}
}

void Mul(int* f, const int& n, int* g, const int& m, int* h) {
static int A[MAXN], B[MAXN];
int Lim = 1, L = 0;
while (Lim < n + m - 1) Lim <<= 1, ++L;
for (int i = 0; i < Lim; ++i)
A[i] = (i < n)? f[i]: 0, B[i] = (i < m)? g[i]: 0;
init(Lim, L), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i) h[i] = 1LL * A[i] * B[i] % P;
NTT(h, Lim, -1);
}
}

void PolyPre(int N) {
inv[0] = inv[1] = fac[0] = ifac[0] = 1;
for (int i = 2; i <= N; ++i)
inv[i] = 1LL * inv[P % i] * (P - P / i) % P;
for (int i = 1; i <= N; ++i) {
fac[i] = 1LL * i * fac[i - 1] % P;
ifac[i] = 1LL * inv[i] * ifac[i - 1] % P;
}
for (int w, i = 0, Mid = 1; i < LOG; ++i, Mid <<= 1) {
W[i][0] = 1, w = fpow(G, (P - 1) / (Mid << 1));
for (int j = 1; j < Mid; ++j)
W[i][j] = 1LL * w * W[i][j - 1] % P;
}
}

int n, m;
int a[MAXN];

int tmp[LOG << 1][MAXN], ptr;

int solve(int* f, int L, int R) { // [L, R)
if (R - L < 2)
return f[0] = 1, f[1] = (P - a[L]) % P, 2;
int Mid = (L + R) / 2, *f0 = tmp[ptr++], *f1 = tmp[ptr++];
int dl = solve(f0, L, Mid), dr = solve(f1, Mid, R);
Poly::Mul(f0, dl, f1, dr, f), ptr -= 2;
return dl + dr - 1;
}

int A[MAXN], B[MAXN], ivB[MAXN];
int f[MAXN], g[MAXN];

int main() {
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
#endif
// input
read(n), read(m);
for (int i = 1; i <= n; ++i) read(a[i]);
// solve
if (n == 1)
return puts((m == 0)? "1": "0"), 0;
PolyPre(n);
// A, B
for (int i = 0; i <= n; ++i) {
A[i] = 1LL * ifac[i] * fpow(i + 1, 2 * m) % P;
B[i] = 1LL * ifac[i] * fpow(i + 1, m) % P;
}
Poly::Inv(B, ivB, n + 1), Poly::Ln(B, B, n + 1);
Poly::Mul(A, n + 1, ivB, n + 1, A);
// g
solve(g, 1, n + 1);
Poly::Ln(g, g, n + 1);
for (int i = 1; i <= n; ++i)
g[i] = (P - 1LL * g[i] * i % P) % P;
g[0] = n;
// f
for (int i = 0; i <= n; ++i)
A[i] = 1LL * A[i] * g[i] % P, B[i] = 1LL * B[i] * g[i] % P;
Poly::Exp(B, f, n + 1);
Poly::Mul(f, n + 1, A, n + 1, f);
// output
int ans = 1LL * fac[n - 2] * f[n - 2] % P;
for (int i = 1; i <= n; ++i)
ans = 1LL * ans * a[i] % P;
printf("%d\n", ans);
return 0;
}

参考资料