「集训队作业 2018」喂鸽子 题解


个人现在感觉这是一道有启发意义的期望题.

大体是利用 Min-Max 容斥计算期望, 以及 NTT 优化时间复杂度.

我 喂 我 自 己

解题思路

显然可以使用 Min-Max 容斥.

记 $X_i$ 为第 $i$ 只鸽子喂饱的期望时间, 也可称作期望喂饱该鸽子的玉米粒个数. 同时记 $S$ 为鸽子的全集, 也就是全部的鸽子. 套用 Min-Max 容斥可得

注意到每只鸽子之间并没有区别, 所以直接枚举集合大小即可. 设 $g(m)$ 表示喂 $m$ 只鸽子, 最早喂饱一只的期望时间. 则

此处等式右侧需乘上 $\frac{n}{m}$, 表示同时喂 $n$ 只鸽子, 喂到钦定的 $m$ 只鸽子的期望次数.

现在的问题转为如何快速求出 $g(m)$. 设 $f(m, i)$ 表示喂 $m$ 只鸽子, 喂 $i$ 粒玉米, 不存在任何一只鸽子被喂饱的概率. 那么至多只能喂 $m(k-1)$ 粒, 有

此处将 $f(m, 0)$ 单独列出, 以避免讨论.

对于计算 $f(m, i)$, 有一个朴素的想法是, 用不喂饱任何一只鸽子的方案数除以总共的方案数. 设 $h(m, i)$ 表示喂 $m$ 只鸽子, 喂 $i$ 粒玉米, 不存在任何一只鸽子被喂饱的方案数, 则有

对于 $h(m, i)$, 可以枚举喂某一只鸽子的玉米粒数, 得出

是一个卷积的形式, 用 NTT 计算即可. 和式中的 $\min$ 其实很好处理, 每次计算只取后者的前 $k - 1$ 次项即可.

时间复杂度 $O(n ^ 2 k \log nk)$.

另外存在时间复杂度为 $O(n ^ 2 k)$ 的优秀做法. 已经拿去喂鸽子了, 不妨参看 https://yhx-12243.github.io/OI-transit/records/uoj449.html.

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// UOJ #449
// DeP
#include <cstdio>
#include <algorithm>
using namespace std;

const int MAXK = 1e3 + 5, MAXM = 5e1 + 5;
const int LOG = 17, MAXN = 1 << LOG | 1, P = 998244353, G = 3;

int W[LOG][MAXN];
int inv[MAXN], fac[MAXN], ifac[MAXN];

inline int C(int n, int m) {
return (n < m)? 0: 1LL * fac[n] * ifac[m] % P * ifac[n - m] % P;
}

int fpow(int base, int b) {
int ret = 1;
while (b > 0) {
if (b & 1) ret = 1LL * ret * base % P;
base = 1LL * base * base % P, b >>= 1;
}
return ret;
}

namespace Poly {
int r[MAXN];
inline void init(const int& Lim, const int& L) {
for (int i = 1; i < Lim; ++i) r[i] = (r[i>>1] >> 1) | ((i & 1) << (L-1));
}

void NTT(int* f, const int& Lim, const int& type) {
for (int i = 1; i < Lim; ++i) if (i < r[i]) swap(f[i], f[r[i]]);
for (int k = 0, Mid = 1; Mid < Lim; ++k, Mid <<= 1) {
const int* w = W[k];
for (int i = 0; i < Lim; i += Mid << 1)
for (int j = 0; j < Mid; ++j) {
int f0 = f[i+j], f1 = 1LL * w[j] * f[i+j+Mid] % P;
f[i+j] = (f0 + f1) % P, f[i+j+Mid] = (f0 - f1 + P) % P;
}
}
if (type < 0) {
int iv = fpow(Lim, P - 2);
for (int i = 0; i < Lim; ++i) f[i] = 1LL * iv * f[i] % P;
reverse(f + 1, f + Lim);
}
}

void Mul(int* f, const int& n, int* g, const int& m, int* h) {
static int A[MAXN], B[MAXN];
int Lim = 1, L = 0;
while (Lim < n + m - 1) Lim <<= 1, ++L;
for (int i = 0; i < Lim; ++i)
A[i] = (i < n)? f[i]: 0, B[i] = (i < m)? g[i]: 0;
init(Lim, L), NTT(A, Lim, 1), NTT(B, Lim, 1);
for (int i = 0; i < Lim; ++i) h[i] = 1LL * A[i] * B[i] % P;
NTT(h, Lim, -1);
}
}

void PolyPre(int N) {
inv[1] = 1;
for (int i = 2; i <= N; ++i)
inv[i] = 1LL * inv[P % i] * (P - P / i) % P;
fac[0] = ifac[0] = 1;
for (int i = 1; i <= N; ++i) {
fac[i] = 1LL * i * fac[i - 1] % P;
ifac[i] = 1LL * inv[i] * ifac[i - 1] % P;
}
for (int w, i = 0, Mid = 1; i < LOG; ++i, Mid <<= 1) {
W[i][0] = 1, w = fpow(G, (P - 1) / (Mid << 1));
for (int j = 1; j < Mid; ++j)
W[i][j] = 1LL * w * W[i][j - 1] % P;
}
}

int n, K;
int g[MAXM], f[MAXM][MAXN], h[MAXM][MAXN];

int A[MAXN], B[MAXN];

int main() {
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
#endif
scanf("%d%d", &n, &K);
PolyPre(n * K);

// h(m, i) = sum_{j = 0} ^ min(i, k - 1) C(i, j) * h(m - 1, i - j)
h[0][0] = 1;
for (int m = 1; m <= n; ++m) {
for (int i = 0; i <= (m - 1) * (K - 1); ++i)
A[i] = 1LL * h[m - 1][i] * ifac[i] % P;
for (int i = 0; i <= K - 1; ++i) // i <= K - 1
B[i] = ifac[i];
Poly::Mul(A, (m - 1) * (K - 1) + 1, B, K, A);
for (int i = 0; i <= m * (K - 1); ++i)
h[m][i] = 1LL * fac[i] * A[i] % P;
}

// f(m, i) = h(m, i) / (i ^ m)
for (int m = 1; m <= n; ++m)
for (int i = 1; i <= m * (K - 1); ++i)
f[m][i] = 1LL * h[m][i] * fpow(inv[m], i) % P;

// g(m) = sum_{i = 0} ^ m (k-1) f(m, i)
for (int m = 1; m <= n; ++m) {
g[m] = 1;
for (int i = 1; i <= m * (K - 1); ++i)
g[m] = (g[m] + f[m][i]) % P;
}

int ans = 0;
for (int k = 1; k <= n; ++k) {
int s = 1LL * C(n, k) * g[k] % P * n % P * inv[k] % P;
ans = (ans + ((k & 1)? s: P - s)) % P;
}

printf("%d\n", ans);
return 0;
}