题意

给定一个 \(n\) 个顶点 \(m\) 条边的无向图,第 \(i\) 个节点有可靠度 \(r_i\),要求在节点中选取四个节点建设基站,其中有两个主基站,两个副基站。要求两个主基站必须和分别剩下的三个基站相连边,而两个副基站必须分别和两个主基站相连边。一个由两个主基站 \(a, b\) 和两个副基站 \(c, d\) 组成的基站的总可靠度定义为 \((r_a + 1)(r_b + b) + r_cr_d\),求最大总可靠度。

\(n \le 5 \times 10^4, m \le 2 \times 10^5\)

解法

题目要求我们寻找的四个点其实是三元环的并,主基站就是公共边的两个端点,而副基站就是公共边两侧的另外两个节点,因此,如果我们能够找出所有的三元环,并且对于每一条边维护他所在三元环相对顶点可靠度的最大和次大值,我们就可以通过枚举每一条边的方式完成最大总可靠度的计算。

因此,这个题目的核心就是寻找三元环,朴素的算法是 \(O(n^3)\) 的,明显不足以通过此题,因此我们要寻找复杂度更优的做法。

我们把顶点分为两类 —— 大点和小点,分类的依据是点的度数,如果一个顶点的度数大于 \(\sqrt m\) 就分在大点里,反之就落在小点里。我们对大点和小点分别处理。

对于大点,我们暴力枚举每个点

因为对于无向图来说,有 \[ \sum_{i = 1}^n \deg(i) = 2m \] 因此我们可以断言:大点的数量是 \(O(\sqrt m)\) 的,因此枚举大点的时间复杂度是 \(O((\sqrt m)^3) = O(m\sqrt m)\)

在这里我们处理了全部由大点构成的三元环。

对于小点,我们暴力枚举它的两条出边并判断对面两个点是否相连

因为我们在枚举一个顶点 \(u\) 的出边的时候,显然每条边都被枚举了 \(\deg(u) - 1\) 次,所以图中小点发出的每一条边都被枚举了 \(O(\sqrt m)\) 次,总共有 \(O(m)\) 条边,因此处理小点的时间复杂度还是 \(O(m \sqrt m)\)

在这里我们处理了至少包含一个小点的三元环。

合并在一起,这个算法可以找到所有的三元环,总的时间复杂度为 \(O(m \sqrt m)\),虽然对于稠密图而言因为 \(m = O(n^2)\) 因此算法的复杂度还是接近 \(O(n^3)\),但对于题目中的图大小还是非常优秀的。

代码实现也不是非常复杂,使用链式前向星存图,然后额外维护数组 grtsnd 表示一条边所在三元环中所对点可靠值最大和次大的点的编号,再首先一个哈希表用于快速查询点和点之间的相连关系就可以了,注意常数。

#include <cstdio>
#include <cmath>

using namespace std;
typedef long long ll;
const int N = 50100;
const int M = 500100;
const int HASH = 1000007;
const int INF = 0x3f3f3f3f;
int n, m;
int r[N];
int tot, fst[N], nxt[M], to[M];
int deg[N], bc, sc;
int big[N], small[N];
int grt[M], snd[M];

#define max(a, b) ((a) > (b) ? (a) : (b))
#define hashEdge(a, b) (((a) << 12) | (b))

int hfst[HASH], hnxt[M], hval[M], htot;
int hl[M], hr[M];

void link(int a, int b) {
    deg[a]++;
    nxt[++tot] = fst[a];
    fst[a] = tot; to[tot] = b;
    hl[++htot] = a; hr[htot] = b;
    hval[htot] = tot;
    int ent = hashEdge(a, b) % HASH;
    hnxt[htot] = hfst[ent];
    hfst[ent] = htot;
} 

inline int findEdge(int a, int b) { 
    int ent = hashEdge(a, b) % HASH;
    for (int i = hfst[ent]; i; i = hnxt[i])
        if (hl[i] == a && hr[i] == b) return hval[i];
    return -1; 
}

inline void updateEdge(int e, int c) {
    if (c == grt[e] || c == snd[e]) return;
    if (r[c] >= r[grt[e]]) snd[e] = grt[e], grt[e] = c;
    else if (r[c] >= r[snd[e]]) snd[e] = c;
}

inline void updateTriplet(int a, int b, int c, int ab, int ac, int bc) {
    updateEdge(ab, c);
    updateEdge(ab ^ 1, c);
    updateEdge(ac, b);
    updateEdge(ac ^ 1, b);
    updateEdge(bc, a);
    updateEdge(bc ^ 1, a);
}

int main() {    
    scanf("%d%d", &n, &m); tot = htot = 1; r[0] = -INF;
    for (int i = 1; i <= n; i++) scanf("%d", &r[i]);
    for (int i = 1; i <= m; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        link(a, b);
        link(b, a);
    }
    int s = sqrt(m);
    sc = bc = 0;
    for (int i = 1; i <= n; i++) {
        if (deg[i] < s) small[++sc] = i;
        else big[++bc] = i;
    }
    for (int i = 1; i <= sc; i++) {
        int w = small[i];
        for (int wu = fst[w]; wu; wu = nxt[wu]) {
            int u = to[wu];
            for (int wv = fst[w]; wv != wu; wv = nxt[wv]) {
                int v = to[wv];
                int uv = findEdge(u, v);
                if (uv > 0) updateTriplet(w, u, v, wu, wv, uv);
            }
        }
    }
    for (int i = 1; i <= bc; i++) {
        int w = big[i];
        for (int j = 1; j < i; j++) {
            int u = big[j];
            int wu = findEdge(w, u);
            if (wu < 0) continue;
            for (int k = 1; k < j; k++) {
                int v = big[k];
                int wv = findEdge(w, v), uv = findEdge(u, v);
                if (wv > 0 && uv > 0)
                    updateTriplet(w, u, v, wu, wv, uv);
            }
        }
    }
    int ans = 0; 
    for (int i = 2; i <= tot; i++) {
        int a = to[i], b = to[i ^ 1];
        int c = grt[i], d = snd[i];
        if (c == 0 || d == 0) continue;
        ans = max(ans, (r[a] + 1) * (r[b] + 1) + r[c] * r[d]);
    }
    printf("%d\n", ans);
    return 0;
}