「Luogu P2664」树上游戏


这是一道点分治好题, 至少我现在这么认为. = =

前置知识

  1. 点分治

题目思路

考虑点分治. 虽然隔壁 O(n) 做法又大又圆但是不管

点分治在处理树上路径信息的时候, 把路径分为两类: 一类不经过根节点, 一类经过. 后者为一端点是根节点的路径, 或者可以看作两条端点为根节点的路径拼接而来.

所以可以依次遍历子树, 每次单独考虑子树中的信息和其他子树信息之间的影响, 从而更新答案.

在处理这道题的时候我们沿用这个思路. 也就是说, 上面是两段废话, 和这题毫无关系.

为了表述方便, 记 C[u] 为节点 $u$ 的颜色, size[u] 为节点 $u$ 的子树大小.

可以观察到一个事实, 单独考虑根节点 $root$ 一个子树中的某一个节点 $u$, ($u$ 满足在 $root$ 到 $u$ 的路径中 C[u] 第一次出现), 那么这个节点 $u$ 可以对其他子树中的节点 $v$, ($v$ 满足 $v$ 到 $root$ 路径上不包含 C[u], 有 size[u] 的贡献.

正确性显然. 单独考虑每个颜色的贡献, 因为 $u$ 的颜色只在 $u, v$ 路径上出现过一次, 所以 $u$ 的子树都可以同 $v$ 构成一个点对, $u$ 的颜色对 $v$ 做一次贡献, 共 size[u].

考虑如何在点分治的过程中维护这个东西.

先维护满足以上条件 $u$ 的子树大小和, (也就是 $u$ 满足 $root$ 到 $u$ 的路径中 (不包括 $root$), $u$ 的颜色第一次出现), 以颜色为下标, 记作 $\operatorname{W}(c) = \sum\limits_{\operatorname{C}(u) = c} \operatorname{size}(u)$.

并记 tot 为所有 W[c] 的和.

  • 对于根节点, 自身的颜色可以给自己 size[root] 的贡献, 其他节点可以给根节点 tot - W[C[root]] 的贡献.

    其实很好理解, 因为自己的颜色已经计算过了, 自然减掉就好了.

  • 对于以 $u$ 为根子树中的点 $v$, 考虑其他子树对该子树的影响.

    num 为 $v$ 到 $root$ 路径 (不包括 $root$) 上的颜色数量.

    沿路更新 num, tot 的值: 遇到新颜色, 将 tot 减去 W[C[v]], 并将 num 增加 1.

    那么对 $v$ 的贡献为 tot + (size[root] - size[u]) * num.

    也就是满足开始所说的条件的其他子树中的点, 对 $v$ 的答案做贡献; 以及其他子树中的点经过 num 个点来到 $v$, 对 $v$ 的答案做贡献.

    注意回溯的时候把改变的 tot 以及 num 改回来. = =

那么, 为什么不会算重呢?

实际上计算贡献的方式, 每一步都对应树上不同形式的路径, 细节留给读者思考.

对于根节点的统计, 考虑的是当前分治范围内的子树对其的贡献; 对于子树中节点的统计, 考虑的是其他子树对其的影响, 而不考虑子树内部的贡献, 两者互不影响.

综上, 可以得到如下的算法流程:

  1. 以当前重心为根, DFS 整棵树, 维护 size[u], W[c].

  2. 维护根节点的答案, Ans[root] += tot - W[C[root]] + size[root].

  3. 遍历根节点子树, 将当前节点的贡献减去, 计算当前子树答案, Ans[v] += tot + (size[root] - size[v]) * num, 再加回当前节点贡献.

  4. 清空记录的信息.

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Luogu P2664
// DeP
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

namespace IO {
const int MAXSIZE = 1 << 18 | 1;
char buf[MAXSIZE], *p1, *p2;

inline int Gc() {
return p1 == p2 &&
(p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2)? EOF: *p1++;
}
template<typename T> inline void read(T& x) {
x = 0; int f = 0, ch = Gc();
while (!isdigit(ch)) f |= ch == '-', ch = Gc();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = Gc();
if (f) x = -x;
}
}
using IO::read;

typedef long long LL;
const int MAXN = 1e5+5;

int n;
int C[MAXN];
LL Ans[MAXN];

namespace Graph {
struct Edge { int nxt, to; } edges[MAXN << 1];
int head[MAXN], eidx;

inline void init() { memset(head, -1, sizeof head), eidx = 1; }
inline void AddEdge(int from, int to) {
edges[++eidx] = (Edge){ head[from], to };
head[from] = eidx;
}
}

namespace TreeDivide {
using namespace Graph;
bool vis[MAXN];
int cnt[MAXN], num;
LL W[MAXN], tot, Y;
int Balance[MAXN], size[MAXN], subsize, ct;

inline void init(const int& x) { subsize = x, Balance[ct = 0] = MAXN; }

void findCt(int u, int fa) {
Balance[u] = 0, size[u] = 1;
for (int v, i = head[u]; ~i; i = edges[i].nxt) {
if ((v = edges[i].to) == fa || vis[v]) continue;
findCt(v, u), size[u] += size[v];
Balance[u] = max(Balance[u], size[v]);
}
Balance[u] = max(Balance[u], subsize - size[u]);
if (Balance[u] < Balance[ct]) ct = u;
}

// Step 1
void dfs(int u, int fa) {
++cnt[C[u]], size[u] = 1;
for (int v, i = head[u]; ~i; i = edges[i].nxt) {
if ((v = edges[i].to) == fa || vis[v]) continue;
dfs(v, u), size[u] += size[v];
}
if (cnt[C[u]] == 1) tot += size[u], W[C[u]] += size[u];
--cnt[C[u]];
}

// Step 3
void subDfs(int u, int fa) {
++cnt[C[u]];
if (cnt[C[u]] == 1) tot -= W[C[u]], ++num;
Ans[u] += tot + Y * num;
for (int v, i = head[u]; ~i; i = edges[i].nxt)
if ((v = edges[i].to) != fa && !vis[v]) subDfs(v, u);
if (cnt[C[u]] == 1) tot += W[C[u]], --num;
--cnt[C[u]];
}

// Step 2 --> 3
void Mdy(int u, int fa, const int& type) {
++cnt[C[u]];
for (int v, i = head[u]; ~i; i = edges[i].nxt)
if ((v = edges[i].to) != fa && !vis[v]) Mdy(v, u, type);
if (cnt[C[u]] == 1) tot += 1LL * type * size[u], W[C[u]] += 1LL * type * size[u];
--cnt[C[u]];
}

// Step 4
void clear(int u, int fa) {
W[C[u]] = cnt[C[u]] = 0;
for (int v, i = head[u]; ~i; i = edges[i].nxt)
if ((v = edges[i].to) != fa && !vis[v]) clear(v, u);
}

void Divide(int u) {
vis[u] = true;
// now
// Step 1
dfs(u, -1);
// Step 2
Ans[u] += tot - W[C[u]] + size[u];
for (int v, i = head[u]; ~i; i = edges[i].nxt) {
if (vis[v = edges[i].to]) continue;
// Step 2 --> 3
++cnt[C[u]], tot -= size[v], W[C[u]] -= size[v];
Mdy(v, u, -1);
// Step 3
--cnt[C[u]], Y = size[u] - size[v];
subDfs(v, u);
++cnt[C[u]], tot += size[v], W[C[u]] += size[v];
Mdy(v, u, 1);
--cnt[C[u]];
}
// Step 4
num = tot = 0, clear(u, -1);
// nxt
for (int v, i = head[u]; ~i; i = edges[i].nxt) {
if (vis[v = edges[i].to]) continue;
init(size[v]), findCt(v, -1), Divide(ct);
}
}

inline void solve() { init(n), findCt(1, -1), Divide(ct); }
}

int main() {
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
#endif
// init
Graph::init();
// input
read(n);
for (int i = 1; i <= n; ++i) read(C[i]);
for (int u, v, i = 1; i < n; ++i)
read(u), read(v), Graph::AddEdge(u, v), Graph::AddEdge(v, u);
// solve
TreeDivide::solve();
// output
for (int i = 1; i <= n; ++i) printf("%lld\n", Ans[i]);
return 0;
}

参考资料