题面
题面,给定一个 \(n\) 个节点的无根树,边带权,求树上长度不超过 \(k\) 的简单路径的数量。
\(n \le 10^5\)
解法
这道题所使用到的思想是分治,假设我们现在要处理一颗根为 \(u\) 的树,很明显这棵树的路径 \(x \to y\) 有两种:
- 如果 \(x, y\) 在 \(u\) 的同一棵子树里面,那么 \(x \to y\) 一定不经过 \(u\),只在某棵子树里面。
- 如果 \(x, y\) 在 \(u\) 的不同子树里面,那么 \(x \to y\) 一定经过 \(u\)。
很明显第一种情况我们可以递归在子树里面做,因此我们主要处理第二种情况,准确来说,我们从 \(u\) 开始 DFS 一波,求出 \(d[x]\) 为 \(x\) 到 \(u\) 的距离(深度),那么我们要求的答案就变成了 \(x, y\) 不属于 \(u\) 的同一棵子树,且 \(d[x] + d[y] \le k\) 的 \((x, y)\) 的个数。
我们先忽略 “不属于同一棵子树” 这个限制,考虑后面的条件,给定 \(d\),如何求 \(d[x] + d[y] \le k\) 的 \((x, y)\) 的个数?
我们使用双指针的方法,把 \(d\) 排序,然后维护两个指针 \(i, j\),假设有 \(cnt\) 个元素,初始化 \(i = 1, j = cnt\)。很明显随着 \(i\) 的增加,\(j\) 必定单调不增!而对于任意时刻的满足 \(d[i] + d[j] \le k\) 的 \(i, j\),显然对于 \(i < x \le j\),都有 \(d[i] + d[x] \le k\),因此我们把答案累加上 \(j - i\)。总体复杂度为 \(O(cnt \log cnt)\),代码很短:
(d + 1, d + 1 + cnt);
sortint ret = 0;
for (int i = 1, j = cnt; i < j; ) {
if (d[i] + d[j] <= k) ret += j - i, i++;
else j--;
}
我们接下来考虑如何加上” 不属于同一棵子树 “的限制,假设我们递归计算的函数为 \(calc(u)\),一开始初始化 \(d[u] = 0\),然后统计了一波,假设我们多统计了都在子树 \(v\) 里面的路径的数量,这是多少呢?很显然这些多统计的路径数等价于初始化 \(d[v] = w_{uv}\) 之后 \(calc(v)\) 的值,减去即可:
...
+= calc(u);
ans for (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v]) continue;
[v] = len[e]; ans -= calc(v); // 减去多算的
dis...
}
最后,还留下一个问题:从哪里开始分治呢?注意到如果我们出发点取得不好,选到了链的一端,那么复杂度就退化了,因此我们最好从无根树的重心开始分治,求重心是 \(O(n)\) 的:
void dfs_rt(int u, int p) { // sz [0] 是总结点数,p 是父节点
[u] = 1, f[u] = 0;
szfor (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v] || v == p) continue;
(v, u);
dfs_rt[u] += sz[v];
sz[u] = max(f[u], sz[v]);
f}
[u] = max(f[u], sz[0] - sz[u]);
fif (f[u] < f[rt]) rt = u;
}
因此,最后算法的总框架就是:
- 从重心 \(u\) 开始分治
- 处理完重心以后,删除 \(u\),递归处理其他的无根树。
void dfs(int u) {
[u] = 1; dis[u] = 0;
vis+= calc(u);
ans for (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v]) continue;
[v] = len[e]; ans -= calc(v);
dis= 0; sz[0] = sz[v];
rt (v, 0); dfs(rt);
dfs_rt}
}
易证递归的深度不超过 \(O(\log n)\),算上双指针中排序的时间,算法的总复杂度为 \(O(n\log ^2 n)\)。
代码:
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 10010;
int n, k, fst[N], nxt[2 * N], to[2 * N], len[2 * N];
int tot, cnt, vis[N], f[N], sz[N], dis[N], d[N], rt;
int ans;
void link(int a, int b, int w) {
[++tot] = fst[a];
nxt[a] = tot; to[tot] = b;
fst[tot] = w;
len}
void dfs_rt(int u, int p) {
[u] = 1, f[u] = 0;
szfor (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v] || v == p) continue;
(v, u);
dfs_rt[u] += sz[v];
sz[u] = max(f[u], sz[v]);
f}
[u] = max(f[u], sz[0] - sz[u]);
fif (f[u] < f[rt]) rt = u;
}
void dfs_d(int u, int p) {
[++cnt] = dis[u];
dfor (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v] || v == p) continue;
[v] = dis[u] + len[e];
dis(v, u);
dfs_d}
}
int calc(int u) {
= 0; dfs_d(u, 0);
cnt (d + 1, d + 1 + cnt);
sortint ret = 0;
for (int i = 1, j = cnt; i < j; ) {
if (d[i] + d[j] <= k) ret += j - i, i++;
else j--;
}
return ret;
}
void dfs(int u) {
[u] = 1; dis[u] = 0;
vis+= calc(u);
ans for (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (vis[v]) continue;
[v] = len[e]; ans -= calc(v);
dis= 0; sz[0] = sz[v];
rt (v, 0); dfs(rt);
dfs_rt}
}
int main() {
[0] = 0x3f3f3f3f;
fwhile (scanf("%d%d", &n, &k), n) {
= 1; ans = 0;
tot (vis + 1, vis + 1 + n, 0);
fill(fst + 1, fst + 1 + n, 0);
fillfor (int i = 1; i <= n - 1; i++) {
int a, b, w;
("%d%d%d", &a, &b, &w);
scanf(a, b, w);
link(b, a, w);
link}
[0] = n; dfs_rt(1, 0);
sz(rt);
dfs("%d\n", ans);
printf}
return 0;
}