点分治

点分治

点分治,也叫重心剖分,是一种用于树形数据结构的分治思想的算法,解决树上满足条件的路径统计问题。点分治的核心思想是通过选择树的重心结点,将树划分为若干子树,然后递归地在这些子树上应用相同的策略,从而有效地解决问题。

实现步骤

  1. 确定当前树的中心,设为根结点,开始点分治算法;
  2. 枚举下一个结点,通过 dfs 更新子树贡献;
  3. 清空当前使用的统计答案的数据结构,比如数组或树状数组(不建议完全清零,可能复杂度退化),删除当前结点;
  4. 递归处理各个子树,回到步骤 11

因为每一次选的是重心,所以每一个划分的子树大小不超过 n2\frac{n}{2},所以递归的深度为 O(logn)O(\log n),每一层递归中需要 O(n)O(n) 的时间复杂度来处理当前树。即跑点分治实现的时间复杂度为 O(nlogn)O(n \log n)(不包含实现的数据结构的时间复杂度)。

例题

CF161D

CF161D - Distance in Tree | vjudge

给你一棵有 nn 个顶点的树,以及一个正整数 mm。请你计算有多少对不同的结点,它们之间的距离恰好是 mm。注意,结点对 (v,u)(v, u)(u,v)(u, v) 被视为同一对。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 5e4 + 5;
/*
n: 树的结点数
m: 目标路径长度
sz[]: 结点子树大小
cnt[]: 记录路径长度出现的次数
tmp[]: 临时存储当前子树的路径长度
tot: tmp 数组的大小
ctr: 当前树的重心结点
ans: 最终答案
*/
int n, m, sz[N], cnt[N], tmp[N], tot, ctr, ans;
bitset<N>del; // 标记结点是否被删除
vector<int>nbr[N]; // 邻接表存储树
void get_ctr(int x, int fa) // 求重心
{
sz[x] = 1;
int maxi = 0;
for (auto& nxt : nbr[x])
if (!del[nxt] && nxt != fa)
{
get_ctr(nxt, x);
if (ctr)
return;
maxi = max(maxi, sz[nxt]);
sz[x] += sz[nxt]; // 不要写在 return 之前,否则 sz[x] 的大小错误
}
maxi = max(maxi, n - sz[x]);
if (maxi <= n / 2)
{
ctr = x;
sz[fa] = n - sz[x]; // 更新父节点的大小,防止后续递归时 sz 错误
}
return;
}
void dfs(int x, int fa, int len) // 更新子树贡献
{
if (len > m)
return;
ans += cnt[m - len] + (len == m); // 统计路径数目,特判长度等于 m 的路径
tmp[++tot] = len;
for (auto& nxt : nbr[x])
if (!del[nxt] && nxt != fa)
dfs(nxt, x, len + 1);
return;
}
void work(int x) // 点分治主体
{
// 2. 枚举下一个结点,通过 dfs 更新子树贡献;
for (auto& nxt : nbr[x])
if (!del[nxt])
{
int tmp1 = tot;
dfs(nxt, x, 1);
for (int i = tmp1 + 1; i <= tot; i++) // 从 tmp1 + 1 开始,避免重复统计
cnt[tmp[i]]++;
}
// 3. 清空当前使用的统计答案的数据结构,比如数组或树状数组,删除当前结点;
for (int i = 1; i <= tot; i++)
cnt[tmp[i]]--;
tot = 0;
del[x] = true;
// 4. 递归处理各个子树,回到步骤 $1$。
for (auto& nxt : nbr[x])
if (!del[nxt])
{
n = sz[nxt]; // 更新当前的树大小
ctr = 0; // 重置重心,不要忘记
// 1. 确定当前树的中心,设为根结点,开始点分治算法;
get_ctr(nxt, 0);
work(ctr);
}
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
nbr[x].push_back(y);
nbr[y].push_back(x);
}
// 1. 确定当前树的中心,设为根结点,开始点分治算法;
get_ctr(1, 0);
work(ctr);
cout << ans;
return 0;
}

当然此题可以用动态规划,设 dpx,i\mathrm{dp}_{x,i} 表示以 xx 为根的子树中,距离 xx 恰好为 ii 的点的个数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 5e4 + 5, M = 500 + 5;
int n, m, dp[N][M], ans;
vector<int>nbr[N];
void dfs(int x, int fa)
{
dp[x][0] = 1;
for (auto& nxt : nbr[x])
if (nxt != fa)
{
dfs(nxt, x);
for (int i = 1; i <= m; i++)
{
if (m - i - 1 >= 0)
ans += dp[x][i] * dp[nxt][m - i - 1];
dp[x][i] += dp[nxt][i - 1];
}
}
ans += dp[x][m];
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
nbr[x].push_back(y);
nbr[y].push_back(x);
}
dfs(1, 0);
cout << ans;
return 0;
}

CF1101D

CF1101D - GCD Counting

枚举质因数,然后跑 dfs 统计路径长度,这样可以确保 gcd(g(x,cur),g(cur,y))>1\gcd(g(x,cur),g(cur,y))\gt 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 2e5 + 5;
int n, a[N], b[N], c, d, sz[N], cnt[N], tmp[N], tot, ctr, ans, maxi;
bitset<N>del, vis;
vector<int>nbr[N], e[N];
void get_ctr(int x, int fa)
{
sz[x] = 1;
int maxi = 0;
for (auto& nxt : nbr[x])
if (!del[nxt] && nxt != fa)
{
get_ctr(nxt, x);
if (ctr)
return;
maxi = max(maxi, sz[nxt]);
sz[x] += sz[nxt];
}
maxi = max(maxi, n - sz[x]);
if (maxi <= n / 2)
{
ctr = x;
sz[fa] = n - sz[x];
}
return;
}
void dfs(int x, int fa, int len, int len1)
{
if (len1 == 1) // g(cur,x) 为 1,不满足条件
return;
ans = max(ans, maxi + len - 1);
tmp[++tot] = len;
for (auto& nxt : nbr[x])
if (!del[nxt] && nxt != fa)
dfs(nxt, x, len + 1, __gcd(len1, a[nxt]));
return;
}
void work(int x)
{
for (auto& j : e[a[x]])
if (!(a[x] % b[j])) // 枚举质因数,但是这一行多此一举
{
maxi = 1;
for (auto& nxt : nbr[x])
if (!del[nxt] && __gcd(__gcd(a[x], a[nxt]), b[j]) == b[j]) // 只考虑该质因数相同的子树
{
int tmp1 = tot;
dfs(nxt, x, 2, b[j]);
for (int i = tmp1 + 1; i <= tot; i++)
maxi = max(maxi, tmp[i]);
}
tot = 0;
}
del[x] = true;
for (auto& nxt : nbr[x])
if (!del[nxt])
{
n = sz[nxt];
ctr = 0;
get_ctr(nxt, 0);
work(ctr);
}
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
for (int i = 2; i <= 2e5; i++)
{
if (!vis[i])
{
b[++c] = i;
for (int j = i; j <= 2e5; j += i)
e[j].push_back(c);
}
for (int j = 1; j <= c; j++)
{
if (i * b[j] > 2e5)
break;
vis[i * b[j]] = true;
if (!(i % b[j]))
break;
}
}
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i], ans = max<int>(ans, a[i] > 1);
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
nbr[x].push_back(y);
nbr[y].push_back(x);
}
get_ctr(1, 0);
work(ctr);
cout << ans;
return 0;
}

洛谷 P2634

P2634 [国家集训队] 聪聪可可

模板题,注意答案需要化简。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 2e4 + 5;
int n, sz[N], cnt[N], tmp[N], tot, ctr, ans;
bitset<N>del;
vector<pair<int, int>>nbr[N];
void get_ctr(int x, int fa)
{
sz[x] = 1;
int maxi = 0;
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt] && nxt != fa)
{
get_ctr(nxt, x);
if (ctr)
return;
maxi = max(maxi, sz[nxt]);
sz[x] += sz[nxt];
}
}
maxi = max(maxi, n - sz[x]);
if (maxi <= n / 2)
{
ctr = x;
sz[fa] = n - sz[x];
}
return;
}
void dfs(int x, int fa, int len)
{
ans += cnt[(3 - len) % 3] + !len;
tmp[++tot] = len;
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt] && nxt != fa)
dfs(nxt, x, (len + w) % 3);
}
return;
}
void work(int x)
{
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt])
{
int tmp1 = tot;
dfs(nxt, x, w % 3);
for (int i = tmp1 + 1; i <= tot; i++)
cnt[tmp[i]]++;
}
}
for (int i = 1; i <= tot; i++)
cnt[tmp[i]]--;
tot = 0;
del[x] = true;
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt])
{
n = sz[nxt];
ctr = 0;
get_ctr(nxt, 0);
work(ctr);
}
}
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n;
int m = n;
for (int i = 1; i < n; i++)
{
int x, y, z;
cin >> x >> y >> z;
nbr[x].push_back({ y,z });
nbr[y].push_back({ x,z });
}
get_ctr(1, 0);
work(ctr);
int tmp = __gcd(ans * 2 + m, m * m);
cout << (ans * 2 + m) / tmp << '/' << m * m / tmp;
return 0;
}

洛谷 P4178

P4178 Tree

统计答案时记得因为是小于等于 mm(题目中 kk)的,所以要查询 mlen+1m-{len}+1 并且加一。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
void update(int x, int k)
{
while (x <= m + 1)
{
tree[x] += k;
x += lowbit(x);
}
return;
}
int query(int x)
{
int sum = 0;
while (x)
{
sum += tree[x];
x -= lowbit(x);
}
return sum;
}
void dfs(int x, int fa, int len)
{
if (len > m)
return;
ans += query(m - len + 1) + 1;
tmp[++tot] = len;
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt] && nxt != fa)
dfs(nxt, x, len + w);
}
return;
}
void work(int x)
{
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt])
{
int tmp1 = tot;
dfs(nxt, x, w);
for (int i = tmp1 + 1; i <= tot; i++)
update(tmp[i] + 1, 1);
}
}
for (int i = 1; i <= tot; i++)
update(tmp[i] + 1, -1);
tot = 0;
del[x] = true;
for (auto& y : nbr[x])
{
auto& nxt = y.first, & w = y.second;
if (!del[nxt])
{
n = sz[nxt];
ctr = 0;
get_ctr(nxt, 0);
work(ctr);
}
}
return;
}