最大多段连续子段和

问题概要

给定一个n个元素的环形序列(1号与n号元素相邻),元素有正有负,求m段连续的子段,使子段和最大,保证至少有m个正数。

n \le 3 \times 10^5, m \le 10^5

O(n^2m)暴力

破环成链,类似序列上的最大连续子段和,定义dp[i][j]表示以i为右端点选了j段的最大子段和,口胡状态转移:

dp[i][j] = a[i] + \max \begin{cases}
    dp[i - 1][j] \\
    \max_{1 \le k < i} dp[k][j - 1]
\end{cases}
后面的\max随便搞搞搞成O(1)的,然后整个状态转移就是O(nm)的,加上破环成链的代价,总的时间复杂度为O(n^2m),实测拿50分。

O(n\log n)正解

我们发现:既然我们要选一个正数,不如把这个正数所在的一整段正数全部选了;类似地,既然我们要不选一个负数,不如把这个负数所在的一整段负数排除掉。

于是乎我们就可以把整个环缩成若干段正负交替的“段”,假设我们缩出来是p段正的夹着p段负的,那么分情况讨论:

  1. p \le m,直接选所有的正数就行了,注意我们可以把这里的一段正数断成两段或更多来选,因此无论如何都能选够m段,此时明显最优。
  2. p > m,我们不能每一段都选了,但是我们可以假设我们先选了p段,然后通过或消除,或合并的方式进行“调整”。

考虑调整的方式:

  1. 选一段负数,这等价于把两边的正数段合并了,总段数-1,对答案有负贡献。
  2. 不选一段正数,这等价于把两边的负数段合并了,总段数-1,对答案也有负贡献。

注意:我们不能先选一段负数,然后再不选其相邻的正数(或反过来),因为这样不仅血亏,对总段数也没有影响。

因为调整一次总段数减少1,以上的“调整”操作要进行p - m次。

既然对答案的贡献都是负的,我们把每一段的“值”变成该段数和的绝对值,然后问题就被转化为了:选择p - m个不相邻的数,使它们的和(对答案的负贡献)最小。

这又是另一个经典问题了,参见CTSC 2007 数据备份,我们维护一个小根堆,每次取出堆顶元素,并在答案上减去其的贡献,然后执行“合并”,即如果这个数的前驱是a,本身是b,右边是c,我们删除a, b, c,然后插入一个新的权值为a + c - b的等效节点。这样的话,如果我们下次再选到这个节点,对答案的影响等价于我们选了a,c而不选b。这在思想上类似于网络流的退流。为了维护前驱/后继关系,我们同时需要引入一个环形链表,总的复杂度约为O(n\log n)

我们发现这样做环形和非环形的差别就很小了,甚至是环形的更好写(不需要处理边界情况)。

代码:(注意因为人懒所以用了priority_queue饮鸩止渴,所以用的是懒惰删除法,还有就是代码中是大根堆维护负绝对值,本质上是一样的)

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <queue>
#define register

using namespace std;
typedef long long ll;
typedef unsigned int u32;
const int INF = 0x3f3f3f3f;
const int N = 600100;
int n, m, ans, cnt;
int a[N], s[N], tot = 0;
int pre[N], nxt[N];
bool inv[N];

struct cmp {
    bool operator ()(int a, int b) {
        return s[a] < s[b];
    }
};
priority_queue<int, vector<int>, cmp> q;

#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(a) ((a) > 0 ? (a) : -(a))
#define sgn(a) ((a) >= 0 ? 1 : -1)

void merge(int x) {
    ans += s[x];
    s[++cnt] = s[pre[x]] + s[nxt[x]] - s[x];
    inv[pre[x]] = inv[nxt[x]] = true;
    pre[cnt] = pre[pre[x]];
    nxt[pre[cnt]] = cnt;
    nxt[cnt] = nxt[nxt[x]];
    pre[nxt[cnt]] = cnt;
    q.push(cnt);
}

void solve() {
    ans = 0;
    for (int i = 1; i <= n; ) {
        s[++tot] = 0;
        if (a[i] >= 0) while (i <= n && a[i] >= 0) s[tot] += a[i++];
        else while (i <= n && a[i] < 0) s[tot] += a[i++];
        if (s[i] >= 0) ans += s[i]; 
    }
    if (sgn(s[1]) == sgn(s[tot])) s[1] += s[tot--];
    if (tot / 2 <= m) return;
    for (int i = 1; i <= tot; i++) {
        pre[i] = i == 1 ? tot : i - 1;
        nxt[i] = i == tot ? 1 : i + 1;
        s[i] = -abs(s[i]);
        q.push(i);
    }
    cnt = tot;
    int left = tot / 2 - m;
    while (left--) {
        int now = q.top(); q.pop();
        if (inv[now]) left++;
        else merge(now);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    solve();
    printf("%d\n", ans);
    return 0;
}