概念
众所周知,BST 在特定的插入顺序下常常会退化成链,从而使单次查找的时间复杂度退化为 \(O(n)\),这是很糟糕的,因此出现了一箩筐的平衡树,在保证 BST 性质的同时,尽量保持 BST 的平衡。
Treap 就是当中概念比较简单而且比较好写的一种真的吗?。Treap 是 Tree 和 Heap 的合成,即 “树堆”,其精髓就在于通过维护,使得一棵树在保有 BST 性质的同时,也具有大根堆的性质。为此,每个节点都会随机生成一个额外的权值用于堆性质的维护:
struct node {
int l, r; // 左右子树
int cnt, sz; // 当前键值副本树与子树(包括自身)大小
int val, w; // 键值与随机权值
} t[N];
int newnode(int val) {
[++tot].val = val;
t[tot].cnt = t[tot].sz = 1;
t[tot].l = t[tot].r = 0;
t[tot].w = rand();
treturn tot;
}
void pushup(int p) { // 维护 sz
[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
t}
为了防止讨论边界条件,我们插入两个值为 \(\pm \infty\) 的节点:
void build() {
= 0;
tot = newnode(-INF);
rt [rt].r = newnode(INF);
t(rt);
pushup}
旋转
改变 BST 的形状而不破坏其性质的操作是二叉树的旋转,分为右旋(zig)和左旋(zag)两种,以右旋为例,右旋把一个节点 \(p\) 的左子树 \(q\) 绕其 “向右(顺时针)” 旋转到原来 \(p\) 的位置,并且让 \(p\) 成为 \(q\) 的右子树,\(q\) 原来的右子树成为 \(p\) 的左子树,左旋类似:
落实在代码上,就是:
void zig(int &p) {
int q = t[p].l;
[p].l = t[q].r, t[q].r = p;
t= q; pushup(t[p].r); pushup(p);
p }
void zag(int &p) {
int q = t[p].r;
[p].r = t[q].l, t[q].l = p;
t= q; pushup(t[p].l); pushup(p);
p }
注意 pushup
的顺序。
排名相关查询
我们接下来考虑平衡树的前两个查询操作,分别是查排名和按排名查数,两个都相对简单,可以递归实现:
int rank(int p, int val) {
if (p == 0) return 0;
if (val == t[p].val) return t[t[p].l].sz + 1;
if (val < t[p].val) return rank(t[p].l, val);
return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}
int query(int p, int rk) {
if (p == 0) return INF;
if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}
注意这里查询的结果包含 \(\pm \infty\),要去掉。
前驱后继查询
接下来考虑查询 \(x\) 的前驱 / 后继的操作,以前驱为例,我们初始化答案 \(ans\) 为 \(- \infty\),有以下几种情况:
- 没有找到 \(x\) 节点:那么答案就在经过的节点上。
- 找到 \(x\) 节点,可惜是叶子:同上,答案在经过的节点上。
- 找到 \(x\) 节点,不是叶子:若 \(x\) 节点有左子树 \(p\),从 \(p\) 开始一路往右走直到不存在右子树为止,此时节点对应的值即为前驱。
代码:
int querypre(int val) {
int ans = 1; // t[1].val = -inf
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].l) {
= t[p].l;
p while (t[p].r) p = t[p].r;
= p;
ans }
break;
}
if (t[p].val < val && t[p].val > t[ans].val) ans = p;
= val < t[p].val ? t[p].l : t[p].r;
p }
return t[ans].val;
}
int querynxt(int val) {
int ans = 2; // t[2].val = inf
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].r) {
= t[p].r;
p while (t[p].l) p = t[p].l;
= p;
ans }
break;
}
if (t[p].val > val && t[p].val < t[ans].val) ans = p;
= val < t[p].val ? t[p].l : t[p].r;
p }
return t[ans].val;
}
插入
递归进行,同时注意维护大根堆性质即可:
void insert(int &p, int val) {
if (p == 0) {
= newnode(val);
p return;
}
if (val == t[p].val) {
[p].cnt++;
t(p);
pushupreturn;
}
if (val < t[p].val) {
(t[p].l, val);
insertif (t[p].w < t[t[p].l].w) zig(p);
} else {
(t[p].r, val);
insertif (t[p].w < t[t[p].r].w) zag(p);
}
(p);
pushup}
删除
我们利用旋转一路把要删的节点转到叶子的位置,然后直接删除即可,同时一路上维护堆的性质:
void remove(int &p, int val) {
if (p == 0) return;
if (val == t[p].val) {
if (t[p].cnt > 1) {
[p].cnt--;
t(p);
pushupreturn;
}
if (t[p].l || t[p].r) {
if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
(p), remove(t[p].r, val);
zigelse
(p), remove(t[p].l, val);
zag(p);
pushup} else p = 0;
return;
}
< t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
val (p);
pushup}
完整代码
以下是通过普通平衡树模板题的全部代码,共计 150 行左右,感觉平衡树这个东西代码量还是有点恐怖,而且需要注意非常多的细节,需要结合理解记忆:
#include <cstdio>
#include <cstdlib>
#include <ctime>
using namespace std;
const int N = 100010;
const int INF = 0x3f3f3f3f;
struct node {
int l, r;
int cnt, sz;
int val, w;
} t[N];
int n, tot = 0, rt;
int newnode(int val) {
[++tot].val = val;
t[tot].cnt = t[tot].sz = 1;
t[tot].l = t[tot].r = 0;
t[tot].w = rand();
treturn tot;
}
void pushup(int p) {
[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
t}
void zig(int &p) {
int q = t[p].l;
[p].l = t[q].r, t[q].r = p;
t= q; pushup(t[p].r); pushup(p);
p }
void zag(int &p) {
int q = t[p].r;
[p].r = t[q].l, t[q].l = p;
t= q; pushup(t[p].l); pushup(p);
p }
void build() {
= 0;
tot = newnode(-INF);
rt [rt].r = newnode(INF);
t(rt);
pushup}
int rank(int p, int val) {
if (p == 0) return 0;
if (val == t[p].val) return t[t[p].l].sz + 1;
if (val < t[p].val) return rank(t[p].l, val);
return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}
int query(int p, int rk) {
if (p == 0) return INF;
if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}
int querypre(int val) {
int ans = 1;
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].l) {
= t[p].l;
p while (t[p].r) p = t[p].r;
= p;
ans }
break;
}
if (t[p].val < val && t[p].val > t[ans].val) ans = p;
= val < t[p].val ? t[p].l : t[p].r;
p }
return t[ans].val;
}
int querynxt(int val) {
int ans = 2;
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].r) {
= t[p].r;
p while (t[p].l) p = t[p].l;
= p;
ans }
break;
}
if (t[p].val > val && t[p].val < t[ans].val) ans = p;
= val < t[p].val ? t[p].l : t[p].r;
p }
return t[ans].val;
}
void insert(int &p, int val) {
if (p == 0) {
= newnode(val);
p return;
}
if (val == t[p].val) {
[p].cnt++;
t(p);
pushupreturn;
}
if (val < t[p].val) {
(t[p].l, val);
insertif (t[p].w < t[t[p].l].w) zig(p);
} else {
(t[p].r, val);
insertif (t[p].w < t[t[p].r].w) zag(p);
}
(p);
pushup}
void remove(int &p, int val) {
if (p == 0) return;
if (val == t[p].val) {
if (t[p].cnt > 1) {
[p].cnt--;
t(p);
pushupreturn;
}
if (t[p].l || t[p].r) {
if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
(p), remove(t[p].r, val);
zigelse
(p), remove(t[p].l, val);
zag(p);
pushup} else p = 0;
return;
}
< t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
val (p);
pushup}
int main() {
(time(0)); build();
srand("%d", &n);
scanfwhile (n--) {
int op, x;
("%d%d", &op, &x);
scanfswitch (op) {
case 1: insert(rt, x); break;
case 2: remove(rt, x); break;
case 3: printf("%d\n", rank(rt, x) - 1); break;
case 4: printf("%d\n", query(rt, x + 1)); break;
case 5: printf("%d\n", querypre(x)); break;
case 6: printf("%d\n", querynxt(x)); break;
}
}
return 0;
}