本地笔记整理第一发,内容比较零散,2019/12/8 上传。
\[
%dontshow
\newcommand{\dpa}{\mathrm{dp}}
\newcommand{\BigO}[1]{\mathcal{O}\left(#1\right)}
\newcommand{\ldp}{\mathrm{ldp}}
\newcommand{\lch}[1]{\mathrm{lch}(#1)}
\newcommand{\hch}[1]{\mathrm{hch}(#1)}
\]
分析
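题目要求维护带点权修改的树上最大权独立集。记 \(\dpa[u][0]\)、\(\dpa[u][1]\) 分别表示不选、选 \(u\) 时 \(u\) 子树内的最大权值和。朴素做法是每次修改后重新做一遍树形 DP,单次 \(\BigO{n}\),\(q\) 次修改共 \(\BigO{nq}\)(\(q\) 与 \(n\) 同阶时即 \(\BigO{n^2}\)),转移如下: \[ \begin{aligned} \dpa[u][0] &= \sum_{v \in \mathrm{ch}(u)} \max\left\{\dpa[v][1], \dpa[v][0]\right\} \\ \dpa[u][1] &= \mathrm{val}[u] + \sum_{v \in \mathrm{ch}(u)} \dpa[v][0] \end{aligned} \]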
但是这样每次修改后重新计算一遍是 \(\BigO{n}\) 的,考虑优化。定义 \(\mathrm{lch}(u)\) 为 \(u\) 的轻儿子集合,\(\mathrm{hch}(u)\) 为 \(u\) 的重儿子,并把上面的 DP 改成只统计轻儿子贡献的形式,定义
\[ \begin{aligned} \ldp[u][0] &= \sum_{v \in \lch{u}} \max\left\{\dpa[v][1], \dpa[v][0]\right\} \\ \ldp[u][1] &= \mathrm{val}[u] + \sum_{v \in \lch{u}} \dpa[v][0] \\ \end{aligned} \]
则有
\[ \begin{aligned} \dpa[u][0] &= \max\left\{\dpa[\hch{u}][0],\dpa[\hch{u}][1]\right\} + \ldp[u][0]\\ \dpa[u][1] &= \dpa[\hch{u}][0] + \ldp[u][1] \end{aligned} \]
使用 Max-Plus Algebra 的方式重写:
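\[ \begin{pmatrix} \dpa[u][0] \\ \dpa[u][1] \end{pmatrix} = \begin{pmatrix} \ldp[u][0] & \ldp[u][0] \\ \ldp[u][1] & -\infty \end{pmatrix} \begin{pmatrix} \dpa[\hch{u}][0] \\ \dpa[\hch{u}][1] \end{pmatrix} \] 这里的矩阵乘法定义为 \((AB)_{ij} = \max_k \left\{A_{ik} + B_{kj}\right\}\),它满足结合律,因此一条重链上的转移矩阵可以用线段树维护区间乘积。重链底端的叶子没有重儿子,右侧的向量取 \((0, 0)^{T}\) 即可。接下来只需要树链剖分维护矩阵转移就行了:修改一个点后沿重链向上跳,每到一条重链的链顶就重新查询整条链的矩阵乘积,并更新链顶父亲的 \(\ldp\)。一共会经过 \(\BigO{\log n}\) 条重链,每条重链上做 \(\BigO{\log n}\) 的线段树操作,所以单次修改的时间复杂度为 \(\BigO{\log^2n}\)。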
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacity of every per-vertex array; vertices are numbered from 1 in main().
const int N = 100010;
// Capacity of the segment-tree node array val[].
const int T = 4 * N;
// Magnitude of the "minus infinity" used by the max-plus products
// (2 * INF is still below 2^31 - 1, so -INF + -INF fits a 32-bit int).
const int INF = 0x3f3f3f3f;
// Number of vertices, read in main().
int n;
// Adjacency lists in linked form: fst[u] is the most recently added edge id
// of u, nxt[e] the next edge id in the same list, to[e] the endpoint of edge
// e. tot is the last edge id handed out by link(); id 0 ends a list.
int fst[N], nxt[2 * N], to[2 * N], tot = 1;
// Two-component integer column vector; the right-hand operand of a
// matrix-vector product.
struct vec {
    int a[2];

    // Both components start out as zero.
    vec() {
        a[0] = 0;
        a[1] = 0;
    }

    // Unchecked component access.
    int &operator[](int idx) { return a[idx]; }
    const int &operator[](int idx) const { return a[idx]; }
} eye;  // All-zero vector used as the right-most factor of a chain product.
// 2x2 integer matrix whose product is taken in the (max, +) sense:
// (A * B)[i][j] = max over k of A[i][k] + B[k][j], never below -INF.
struct mat {
    int a[2][2];

    // All four entries start out as zero.
    mat() {
        a[0][0] = a[0][1] = 0;
        a[1][0] = a[1][1] = 0;
    }

    // Row access, so m[i][j] addresses a single entry.
    int *operator[](int row) { return a[row]; }
    const int *operator[](int row) const { return a[row]; }

    // Max-plus product of two matrices (this on the left, rhs on the right).
    mat operator*(const mat &rhs) const {
        mat out;
        for (int r = 0; r < 2; ++r) {
            for (int c = 0; c < 2; ++c) {
                const int viaFirst = a[r][0] + rhs[0][c];
                const int viaSecond = a[r][1] + rhs[1][c];
                // Results are floored at -INF so that "impossible" entries
                // do not drift further down with each product.
                out[r][c] = std::max(-INF, std::max(viaFirst, viaSecond));
            }
        }
        return out;
    }

    // Max-plus product of this matrix with a column vector.
    vec operator*(const vec &rhs) const {
        vec out;
        for (int r = 0; r < 2; ++r) {
            const int viaFirst = a[r][0] + rhs[0];
            const int viaSecond = a[r][1] + rhs[1];
            out[r] = std::max(-INF, std::max(viaFirst, viaSecond));
        }
        return out;
    }
};
// Adds the directed edge u -> v at the front of u's adjacency list.
void link(int u, int v) {
    const int e = ++tot;  // fresh edge id
    to[e] = v;
    nxt[e] = fst[u];
    fst[u] = e;
}
// Heavy-light decomposition data.
//   cnt    - counter used by dfs2 to hand out positions
//   id[u]  - position of u in the segment tree (assigned by dfs2)
//   top[u] - first vertex of the chain containing u
//   dep[u] - dep[parent] + 1 (written by dfs1; not read in the code shown)
//   bot[t] - for a chain head t, the last vertex dfs2 visited on that chain
//   sz[u]  - subtree size; son[u] - child with the largest subtree (0 if none)
//   fa[u]  - parent of u (0 for the root passed to dfs1)
int cnt = 0, id[N], top[N], dep[N], bot[N], sz[N], son[N], fa[N];
// a[u] is the vertex weight. ldp[u][0/1] hold the contributions of u itself
// and of every child except son[u]; dp[u][0/1] are the subtree answers with
// u excluded / included. modify() refreshes dp only for chain heads.
int a[N], ldp[N][2], dp[N][2];
// val[] holds the segment-tree nodes (products of transfer matrices).
// init[] is not referenced anywhere in the code shown here.
mat val[T], init[N];
// First pass over the tree rooted at the initial call: records parent, depth
// and subtree size of u and picks the child with the largest subtree as
// son[u] (the first such child in adjacency order wins ties).
void dfs1(int u, int p) {
    fa[u] = p;
    dep[u] = dep[p] + 1;
    sz[u] = 1;

    int heaviest = -1;  // size of the best child seen so far
    for (int e = fst[u]; e != 0; e = nxt[e]) {
        const int child = to[e];
        if (child != p) {
            dfs1(child, u);
            sz[u] += sz[child];
            if (heaviest < sz[child]) {
                heaviest = sz[child];
                son[u] = child;
            }
        }
    }
}
// Second pass: numbers the vertices so that every chain occupies consecutive
// positions, and records for each chain its head (top) and its end (bot).
// tp is the head of the chain that u belongs to.
void dfs2(int u, int tp) {
    top[u] = tp;
    id[u] = ++cnt;
    // Chains are walked from head to end, so the last write is the chain end.
    bot[tp] = u;

    // Continue the current chain first to keep its positions contiguous.
    if (son[u] != 0) {
        dfs2(son[u], tp);
    }
    // Every other child starts a chain of its own.
    for (int e = fst[u]; e != 0; e = nxt[e]) {
        const int child = to[e];
        if (child != fa[u] && child != son[u]) {
            dfs2(child, child);
        }
    }
}
// Segment-tree shorthands. They expand in terms of the local names rt, l,
// mid and r, so they only work inside functions that declare those names.
// Index of the left child of node rt.
#define lch (rt << 1)
// Index of the right child of node rt.
#define rch (lch | 1)
// Argument triple (node, range start, range end) for the left half [l, mid].
#define larg lch, l, mid
// Argument triple (node, range start, range end) for the right half [mid + 1, r].
#define rarg rch, mid + 1, r
// Recomputes node rt as the product of its two children, left factor first.
void pushup(int rt) {
    const int leftChild = rt << 1;
    const int rightChild = leftChild | 1;
    val[rt] = val[leftChild] * val[rightChild];
}
// Fills the inner nodes of the subtree rooted at rt (covering [l, r]) from
// whatever the leaves currently hold; leaves themselves are left untouched.
void build(int rt, int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        const int leftChild = rt << 1;
        build(leftChild, l, mid);
        build(leftChild | 1, mid + 1, r);
        pushup(rt);
    }
}
// Returns the ordered product of the leaves at positions [ql, qr]; node rt
// covers [l, r].
mat query(int rt, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return val[rt];
    }
    const int mid = (l + r) >> 1;
    const int leftChild = rt << 1;
    const int rightChild = leftChild | 1;

    if (qr <= mid) {
        // Requested range lies completely in the left half.
        return query(leftChild, l, mid, ql, qr);
    } else if (ql > mid) {
        // Requested range lies completely in the right half.
        return query(rightChild, mid + 1, r, ql, qr);
    }
    // The range straddles both halves: combine them, left factor first.
    const mat leftPart = query(leftChild, l, mid, ql, qr);
    const mat rightPart = query(rightChild, mid + 1, r, ql, qr);
    return leftPart * rightPart;
}
// Overwrites the leaf at position p with v and refreshes every node on the
// path from that leaf up to rt; node rt covers [l, r].
void update(int rt, int l, int r, int p, const mat &v) {
    if (l == r) {
        val[rt] = v;
        return;
    }
    const int mid = (l + r) >> 1;
    const int leftChild = rt << 1;
    if (p > mid) {
        update(leftChild | 1, mid + 1, r, p, v);
    } else {
        update(leftChild, l, mid, p, v);
    }
    pushup(rt);
}
// Rebuilds the transfer matrix of vertex u from ldp[u] and stores it at u's
// position in the segment tree:
//   | ldp[u][0]  ldp[u][0] |
//   | ldp[u][1]    -INF    |
void updateu(int u) {
    mat m;
    m[0][0] = ldp[u][0];
    m[0][1] = ldp[u][0];
    m[1][0] = ldp[u][1];
    m[1][1] = -INF;
    update(1, 1, n, id[u], m);
}
void dfs3(int u) {
ldp[u][0] = 0, ldp[u][1] = a[u];
for (int e = fst[u]; e; e = nxt[e]) {
int v = to[e];
if (v == fa[u]) continue;
dfs3(v);
if (v == son[u]) continue;
ldp[u][0] += max(dp[v][0], dp[v][1]);
ldp[u][1] += dp[v][0];
}
updateu(u);
vec res = query(1, 1, n, id[u], id[bot[top[u]]]) * eye;
dp[u][0] = res[0]; dp[u][1] = res[1];
}
// Sets the weight of vertex u to p and repairs every cached value that
// depends on it, walking chain by chain towards vertex 1.
void modify(int u, int p) {
// Only the "u included" light value depends on a[u]; shift it by the delta.
ldp[u][1] += p - a[u]; a[u] = p;
updateu(u);
while (true) {
// Jump to the head of the current chain and recompute its dp from the
// product over the whole chain.
u = top[u];
vec res = query(1, 1, n, id[u], id[bot[u]]) * eye;
if (u != 1) {
// The head hangs off fa[u] outside fa[u]'s own chain, so it contributes to
// ldp[fa[u]]: swap the old contribution (still in dp[u]) for the new one.
ldp[fa[u]][0] += max(res[0], res[1]) - max(dp[u][0], dp[u][1]);
ldp[fa[u]][1] += res[0] - dp[u][0];
}
// Only now overwrite the cached dp of the chain head.
dp[u][0] = res[0]; dp[u][1] = res[1];
// The chain headed by vertex 1 is the last one to repair.
if (u == 1) break;
// Move to the parent and store its refreshed matrix before the next round.
updateu(u = fa[u]);
}
}
int main() {
int q; scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
int x, y; scanf("%d%d", &x, &y);
link(x, y); link(y, x);
}
dfs1(1, 0); dfs2(1, 1); build(1, 1, n); dfs3(1);
while (q--) {
int x, y; scanf("%d%d", &x, &y);
modify(x, y);
printf("%d\n", max(dp[1][0], dp[1][1]));
}
return 0;
}