概念

众所周知,BST 在特定的插入顺序下常常会退化成链,从而使单次查找的时间复杂度退化为 \(O(n)\),这是很糟糕的,因此出现了一箩筐的平衡树,在保证 BST 性质的同时,尽量保持 BST 的平衡。

Treap 就是当中概念比较简单而且比较好写的一种真的吗?。Treap 是 Tree 和 Heap 的合成,即 “树堆”,其精髓就在于通过维护,使得一棵树在保有 BST 性质的同时,也具有大根堆的性质。为此,每个节点都会随机生成一个额外的权值用于堆性质的维护:

struct node {
    int l, r; // 左右子树
    int cnt, sz; // 当前键值副本树与子树(包括自身)大小
    int val, w; // 键值与随机权值
} t[N];

int newnode(int val) {
    t[++tot].val = val;
    t[tot].cnt = t[tot].sz = 1;
    t[tot].l = t[tot].r = 0;
    t[tot].w = rand();
    return tot;
}

void pushup(int p) { // 维护 sz
    t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
}

为了防止讨论边界条件,我们插入两个值为 \(\pm \infty\) 的节点:

void build() {
    tot = 0;
    rt = newnode(-INF);
    t[rt].r = newnode(INF);
    pushup(rt);
}

旋转

改变 BST 的形状而不破坏其性质的操作是二叉树的旋转,分为右旋(zig)和左旋(zag)两种,以右旋为例,右旋把一个节点 \(p\) 的左子树 \(q\) 绕其 “向右(顺时针)” 旋转到原来 \(p\) 的位置,并且让 \(p\) 成为 \(q\) 的右子树,\(q\) 原来的右子树成为 \(p\) 的左子树,左旋类似:

落实在代码上,就是:

void zig(int &p) {
    int q = t[p].l;
    t[p].l = t[q].r, t[q].r = p;
    p = q; pushup(t[p].r); pushup(p); 
}

void zag(int &p) {
    int q = t[p].r;
    t[p].r = t[q].l, t[q].l = p;
    p = q; pushup(t[p].l); pushup(p); 
}

注意 pushup 的顺序。

排名相关查询

我们接下来考虑平衡树的前两个查询操作,分别是查排名和按排名查数,两个都相对简单,可以递归实现:

int rank(int p, int val) {
    if (p == 0) return 0;
    if (val == t[p].val) return t[t[p].l].sz + 1;
    if (val < t[p].val) return rank(t[p].l, val);
    return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}

int query(int p, int rk) {
    if (p == 0) return INF;
    if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
    if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
    return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}

注意这里查询的结果包含 \(\pm \infty\),要去掉。

前驱后继查询

接下来考虑查询 \(x\) 的前驱 / 后继的操作,以前驱为例,我们初始化答案 \(ans\)\(- \infty\),有以下几种情况:

  1. 没有找到 \(x\) 节点:那么答案就在经过的节点上。
  2. 找到 \(x\) 节点,可惜是叶子:同上,答案在经过的节点上。
  3. 找到 \(x\) 节点,不是叶子:若 \(x\) 节点有左子树 \(p\),从 \(p\) 开始一路往右走直到不存在右子树为止,此时节点对应的值即为前驱。

代码:

int querypre(int val) {
    int ans = 1; // t[1].val = -inf
    int p = rt;
    while (p) {
        if (val == t[p].val) {
            if (t[p].l) {
                p = t[p].l;
                while (t[p].r) p = t[p].r;
                ans = p;
            }
            break;
        }
        if (t[p].val < val && t[p].val > t[ans].val) ans = p;
        p = val < t[p].val ? t[p].l : t[p].r;
    }
    return t[ans].val;
}

int querynxt(int val) {
    int ans = 2; // t[2].val = inf
    int p = rt;
    while (p) {
        if (val == t[p].val) {
            if (t[p].r) {
                p = t[p].r;
                while (t[p].l) p = t[p].l;
                ans = p;
            }
            break;
        }
        if (t[p].val > val && t[p].val < t[ans].val) ans = p;
        p = val < t[p].val ? t[p].l : t[p].r;
    }
    return t[ans].val;
}

插入

递归进行,同时注意维护大根堆性质即可:

void insert(int &p, int val) {
    if (p == 0) {
        p = newnode(val);
        return;
    }
    if (val == t[p].val) {
        t[p].cnt++;
        pushup(p);
        return;
    }
    if (val < t[p].val) {
        insert(t[p].l, val);
        if (t[p].w < t[t[p].l].w) zig(p);
    } else {
        insert(t[p].r, val);
        if (t[p].w < t[t[p].r].w) zag(p);
    }
    pushup(p);
}

删除

我们利用旋转一路把要删的节点转到叶子的位置,然后直接删除即可,同时一路上维护堆的性质:

void remove(int &p, int val) {
    if (p == 0) return;
    if (val == t[p].val) {
        if (t[p].cnt > 1) {
            t[p].cnt--;
            pushup(p);
            return;
        }
        if (t[p].l || t[p].r) {
            if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
                zig(p), remove(t[p].r, val);
            else
                zag(p), remove(t[p].l, val);
            pushup(p);
        } else p = 0;
        return;
    }
    val < t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
    pushup(p);
}

完整代码

以下是通过普通平衡树模板题的全部代码,共计 150 行左右,感觉平衡树这个东西代码量还是有点恐怖,而且需要注意非常多的细节,需要结合理解记忆:

#include <cstdio>
#include <cstdlib>
#include <ctime>

using namespace std;
const int N = 100010;
const int INF = 0x3f3f3f3f;
struct node {
    int l, r;
    int cnt, sz;
    int val, w;
} t[N];
int n, tot = 0, rt;

int newnode(int val) {
    t[++tot].val = val;
    t[tot].cnt = t[tot].sz = 1;
    t[tot].l = t[tot].r = 0;
    t[tot].w = rand();
    return tot;
}

void pushup(int p) {
    t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
}

void zig(int &p) {
    int q = t[p].l;
    t[p].l = t[q].r, t[q].r = p;
    p = q; pushup(t[p].r); pushup(p); 
}

void zag(int &p) {
    int q = t[p].r;
    t[p].r = t[q].l, t[q].l = p;
    p = q; pushup(t[p].l); pushup(p); 
}

void build() {
    tot = 0;
    rt = newnode(-INF);
    t[rt].r = newnode(INF);
    pushup(rt);
}

int rank(int p, int val) {
    if (p == 0) return 0;
    if (val == t[p].val) return t[t[p].l].sz + 1;
    if (val < t[p].val) return rank(t[p].l, val);
    return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}

int query(int p, int rk) {
    if (p == 0) return INF;
    if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
    if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
    return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}

int querypre(int val) {
    int ans = 1;
    int p = rt;
    while (p) {
        if (val == t[p].val) {
            if (t[p].l) {
                p = t[p].l;
                while (t[p].r) p = t[p].r;
                ans = p;
            }
            break;
        }
        if (t[p].val < val && t[p].val > t[ans].val) ans = p;
        p = val < t[p].val ? t[p].l : t[p].r;
    }
    return t[ans].val;
}

int querynxt(int val) {
    int ans = 2;
    int p = rt;
    while (p) {
        if (val == t[p].val) {
            if (t[p].r) {
                p = t[p].r;
                while (t[p].l) p = t[p].l;
                ans = p;
            }
            break;
        }
        if (t[p].val > val && t[p].val < t[ans].val) ans = p;
        p = val < t[p].val ? t[p].l : t[p].r;
    }
    return t[ans].val;
}

void insert(int &p, int val) {
    if (p == 0) {
        p = newnode(val);
        return;
    }
    if (val == t[p].val) {
        t[p].cnt++;
        pushup(p);
        return;
    }
    if (val < t[p].val) {
        insert(t[p].l, val);
        if (t[p].w < t[t[p].l].w) zig(p);
    } else {
        insert(t[p].r, val);
        if (t[p].w < t[t[p].r].w) zag(p);
    }
    pushup(p);
}

void remove(int &p, int val) {
    if (p == 0) return;
    if (val == t[p].val) {
        if (t[p].cnt > 1) {
            t[p].cnt--;
            pushup(p);
            return;
        }
        if (t[p].l || t[p].r) {
            if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
                zig(p), remove(t[p].r, val);
            else
                zag(p), remove(t[p].l, val);
            pushup(p);
        } else p = 0;
        return;
    }
    val < t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
    pushup(p);
}

int main() {
    srand(time(0)); build();
    scanf("%d", &n);
    while (n--) {
        int op, x;
        scanf("%d%d", &op, &x);
        switch (op) {
        case 1: insert(rt, x); break;
        case 2: remove(rt, x); break;
        case 3: printf("%d\n", rank(rt, x) - 1); break;
        case 4: printf("%d\n", query(rt, x + 1)); break;
        case 5: printf("%d\n", querypre(x)); break;
        case 6: printf("%d\n", querynxt(x)); break;
        }
    }
    return 0;
}