题面

题面,给定一个 \(n\) 个节点的无根树,边带权,求树上长度不超过 \(k\) 的简单路径的数量。

\(n \le 10^5\)

解法

这道题所使用到的思想是分治,假设我们现在要处理一颗根为 \(u\) 的树,很明显这棵树的路径 \(x \to y\) 有两种:

  1. 如果 \(x, y\)\(u\) 的同一棵子树里面,那么 \(x \to y\) 一定不经过 \(u\),只在某棵子树里面。
  2. 如果 \(x, y\)\(u\) 的不同子树里面,那么 \(x \to y\) 一定经过 \(u\)

很明显第一种情况我们可以递归在子树里面做,因此我们主要处理第二种情况,准确来说,我们从 \(u\) 开始 DFS 一波,求出 \(d[x]\)\(x\)\(u\) 的距离(深度),那么我们要求的答案就变成了 \(x, y\) 不属于 \(u\) 的同一棵子树,且 \(d[x] + d[y] \le k\)\((x, y)\) 的个数。

我们先忽略 “不属于同一棵子树” 这个限制,考虑后面的条件,给定 \(d​\),如何求 \(d[x] + d[y] \le k​\)\((x, y)​\) 的个数?

我们使用双指针的方法,把 \(d\) 排序,然后维护两个指针 \(i, j\),假设有 \(cnt\) 个元素,初始化 \(i = 1, j = cnt\)。很明显随着 \(i\) 的增加,\(j\) 必定单调不增!而对于任意时刻的满足 \(d[i] + d[j] \le k\)\(i, j\),显然对于 \(i < x \le j\),都有 \(d[i] + d[x] \le k\),因此我们把答案累加上 \(j - i\)。总体复杂度为 \(O(cnt \log cnt)\),代码很短:

sort(d + 1, d + 1 + cnt);
int ret = 0;
for (int i = 1, j = cnt; i < j; ) {
    if (d[i] + d[j] <= k) ret += j - i, i++;
    else j--;
}

我们接下来考虑如何加上” 不属于同一棵子树 “的限制,假设我们递归计算的函数为 \(calc(u)\),一开始初始化 \(d[u] = 0\),然后统计了一波,假设我们多统计了都在子树 \(v\) 里面的路径的数量,这是多少呢?很显然这些多统计的路径数等价于初始化 \(d[v] = w_{uv}\) 之后 \(calc(v)\) 的值,减去即可:

...
ans += calc(u);
for (int e = fst[u]; e; e = nxt[e]) {
    int v = to[e];
    if (vis[v]) continue;
    dis[v] = len[e]; ans -= calc(v); // 减去多算的
    ...
}

最后,还留下一个问题:从哪里开始分治呢?注意到如果我们出发点取得不好,选到了链的一端,那么复杂度就退化了,因此我们最好从无根树的重心开始分治,求重心是 \(O(n)\) 的:

void dfs_rt(int u, int p) { // sz [0] 是总结点数,p 是父节点
    sz[u] = 1, f[u] = 0;
    for (int e = fst[u]; e; e = nxt[e]) {
        int v = to[e];
        if (vis[v] || v == p) continue;
        dfs_rt(v, u);
        sz[u] += sz[v];
        f[u] = max(f[u], sz[v]);
    }
    f[u] = max(f[u], sz[0] - sz[u]);
    if (f[u] < f[rt]) rt = u;
}

因此,最后算法的总框架就是:

  1. 从重心 \(u\) 开始分治
  2. 处理完重心以后,删除 \(u\),递归处理其他的无根树。
void dfs(int u) {
    vis[u] = 1; dis[u] = 0;
    ans += calc(u);
    for (int e = fst[u]; e; e = nxt[e]) {
        int v = to[e];
        if (vis[v]) continue;
        dis[v] = len[e]; ans -= calc(v);
        rt = 0; sz[0] = sz[v];
        dfs_rt(v, 0); dfs(rt);
    }
}

易证递归的深度不超过 \(O(\log n)\),算上双指针中排序的时间,算法的总复杂度为 \(O(n\log ^2 n)\)

代码:

#include <cstdio>
#include <algorithm>

using namespace std;
const int N = 10010;
int n, k, fst[N], nxt[2 * N], to[2 * N], len[2 * N];
int tot, cnt, vis[N], f[N], sz[N], dis[N], d[N], rt;
int ans;

void link(int a, int b, int w) {
    nxt[++tot] = fst[a];
    fst[a] = tot; to[tot] = b;
    len[tot] = w;
}

void dfs_rt(int u, int p) {
    sz[u] = 1, f[u] = 0;
    for (int e = fst[u]; e; e = nxt[e]) {
        int v = to[e];
        if (vis[v] || v == p) continue;
        dfs_rt(v, u);
        sz[u] += sz[v];
        f[u] = max(f[u], sz[v]);
    }
    f[u] = max(f[u], sz[0] - sz[u]);
    if (f[u] < f[rt]) rt = u;
}

void dfs_d(int u, int p) {
    d[++cnt] = dis[u];
    for (int e = fst[u]; e; e = nxt[e]) {
        int v = to[e];
        if (vis[v] || v == p) continue;
        dis[v] = dis[u] + len[e];
        dfs_d(v, u);
    }
}

int calc(int u) {
    cnt = 0; dfs_d(u, 0);
    sort(d + 1, d + 1 + cnt);
    int ret = 0;
    for (int i = 1, j = cnt; i < j; ) {
        if (d[i] + d[j] <= k) ret += j - i, i++;
        else j--;
    }
    return ret;
}

void dfs(int u) {
    vis[u] = 1; dis[u] = 0;
    ans += calc(u);
    for (int e = fst[u]; e; e = nxt[e]) {
        int v = to[e];
        if (vis[v]) continue;
        dis[v] = len[e]; ans -= calc(v);
        rt = 0; sz[0] = sz[v];
        dfs_rt(v, 0); dfs(rt);
    }
}

int main() {
    f[0] = 0x3f3f3f3f;
    while (scanf("%d%d", &n, &k), n) {
        tot = 1; ans = 0;
        fill(vis + 1, vis + 1 + n, 0);
        fill(fst + 1, fst + 1 + n, 0);
        for (int i = 1; i <= n - 1; i++) {
            int a, b, w;
            scanf("%d%d%d", &a, &b, &w);
            link(a, b, w);
            link(b, a, w);
        }
        sz[0] = n; dfs_rt(1, 0);
        dfs(rt);
        printf("%d\n", ans);
    }
    return 0;
}