0%

dsu on tree 樹上啟發式合併

很久沒PO文了 來放個教學文好了

之前想要學 dsu on tree 所以就跑去看CF的這篇

不過看了好久還是看不懂 決定還是自己來紀錄一下DSU on Tree

參考資料:第一篇CF 第二篇CF

為什麼叫做 dsu on tree?

dsu on tree 其實跟並查集(dsu)沒有什麼關聯

但是他用到了啟發式合併 也就是 dsu 常會用到的一種降低複雜度的作法

DSU 的啟發式合併

1
2
3
4
5
6
void unite(int u, int v){
u = find(u), v = find(v);
if(u==v) return;
if(sz[u] < sz[v]) v = dsu[u], sz[v] += sz[u];
else u = dsu[v], sz[u] += sz[v];
}

樹上問題的 Naive Solution

首先 在進入 dsu on Tree 之前 先來看看一個例題

給予一棵以 $1$ 為根的一棵樹

在 $O(1)$ 的時間詢問以 $x$ 為根的子樹有幾個顏色為 $c$ 的節點?

這個問題 我們可以很輕易的想到他的作法 也就是直接 dfs 去存數量

但是這並不是一個很好的方法 因為預處理的時候 會花 $O(n^2)$ 的時間

$O(n^2)$ 的程式碼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void add(int u, int par, int x){
cnt[color[u]] += x;
for(auto v : adj[u]){
if(v!=par) add(v,u,x);
}
}

void dfs(int u, int par){
add(u,par,1);

//這裡的cnt[c]就會是在u為根的子樹上顏色為c的節點數量
ans[u][c] = cnt[c];

add(u,par,-1);
for(auto v : adj[u]){
if(v!=par) dfs(v,u);
}
}

使用 dsu on tree

要改進上面的解法 我們就可以使用 dsu on tree

dsu on tree 的概念

我們可以先考慮樹鏈剖分的概念 先做一次找出子樹的大小

1
2
3
4
5
void dfs_size(int u, int par){
sz[u] = 1;
for(auto v : adj[u])
if(v != par) dfs(v,u), sz[u] += sz[v];
}

之後要找答案再做另一次 dfs

先去對輕鏈做 dfs 找到輕鏈上的答案

之後再對重鏈去找到重鏈上的答案 將輕重鏈的答案合併

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>

#define int long long
#define fastio ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);

using namespace std;

const int N = 1e5+5;
vector<int> adj[N], vec[N];
int color[N], cnt[N], sz[N];
map<int,int> ans[N];

void dfs_size(int u, int par){
sz[u] = 1;
for(auto v : adj[u])
if(v != par) dfs_size(v,u), sz[u] += sz[v];
}

void dfs(int u, int par, bool keep){
int mx = 0, heavy = 0;
for(auto v : adj[u]){
if(v != par && sz[v] > mx){
mx = sz[v];
heavy = v;
}
}
for(auto v : adj[u]){
//DFS 輕鏈
if(v != par && v != heavy){
//Keep = 0 代表 dfs 後會清除這個子樹的答案
dfs(v,u,0);
}
}
if(heavy){
//Keep = 1 代表我們不必清除這個子樹的答案
dfs(heavy,u,1);
vec[u] = vec[heavy];
}
vec[u].push_back(u);
cnt[color[u]]++;
ans[u][color[u]] = cnt[color[u]];
for(auto v : adj[u]){
if(v != par && v != heavy){
for(auto x : vec[v]){
//把答案加上去
cnt[color[x]]++;
ans[u][color[x]] = cnt[color[x]];
vec[u].push_back(x);
}
}
}
if(!keep){
//清除輕鏈的資料
for(auto x : vec[u]){
cnt[color[x]]--;
}
}
}

signed main(){
fastio
int n;
cin >> n;
for(int i = 1;i <= n;i++) cin >> color[i];
for(int i = 1;i < n;i++){
int u,v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_size(1,-1);
dfs(1,-1,0);
int q;
cin >> q;
while(q--){
int x,c;
cin >> x >> c;
cout << ans[x][c] << "\n";
}
}

例題

Codeforces 600D - Lomsat gelral

給予一棵以 $1$ 為根的樹以及每一個節點的顏色

問以 $1$ ~ $n$ 為根的子樹內 出現最多次的顏色 $c$ 的總和為?

這題也是 dsu on tree 可以輕鬆解掉的題目

只要維護出現最多次的顏色就能解決了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>

#define int long long
#define fastio ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);

using namespace std;

const int N = 1e5+5;
vector<int> adj[N], vec[N];
int color[N], cnt[N], sz[N];
int res[N], ans[N];

void dfs_size(int u, int par){
sz[u] = 1;
for(auto v : adj[u])
if(v != par) dfs_size(v,u), sz[u] += sz[v];
}

void update(int &mx, int u, int x){
ans[cnt[color[u]]] -= color[u];
cnt[color[u]] += x;
ans[cnt[color[u]]] += color[u];
mx = max(cnt[color[u]],mx);
}

int dfs(int u, int par, bool keep){
int mx = 1, heavy = 0;
for(auto v : adj[u]){
if(v != par && sz[v] > mx){
mx = sz[v];
heavy = v;
}
}
for(auto v : adj[u]){
//DFS 輕鏈
if(v != par && v != heavy){
//Keep = 0 代表 dfs 後會清除這個子樹的答案
dfs(v,u,0);
}
}
mx = 0;
if(heavy){
mx = max(mx,dfs(heavy,u,1));
swap(vec[u],vec[heavy]);
}
vec[u].push_back(u);
update(mx,u,1);
for(auto v : adj[u]){
if(v != par && v != heavy){
for(auto x : vec[v]){
update(mx,x,1);
vec[u].push_back(x);
}
}
}
res[u] = ans[mx];
if(!keep){
for(auto x : vec[u]){
update(mx,x,-1);
}
}
return mx;
}

signed main(){
fastio
int n;
cin >> n;
for(int i = 1;i <= n;i++) cin >> color[i];
for(int i = 1,u,v;i < n;i++){
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_size(1,-1);
dfs(1,-1,0);
for(int i = 1;i <= n;i++) cout << res[i] << " ";
}