问题概要

给定一个 \(n\) 个元素的环形序列(\(1\) 号与 \(n\) 号元素相邻),元素有正有负,求 \(m\) 段连续的子段,使子段和最大,保证至少有 \(m\) 个正数。

\(n \le 3 \times 10^5, m \le 10^5\)

\(O(n^2m)\) 暴力

破环成链,类似序列上的最大连续子段和,定义 \(dp[i][j]\) 表示以 \(i\) 为右端点选了 \(j\) 段的最大子段和,口胡状态转移: \[ dp[i][j] = a[i] + \max \begin{cases} dp[i - 1][j] \\ \max_{1 \le k < i} dp[k][j - 1] \end{cases} \] 后面的 \(\max\) 随便搞搞搞成 \(O(1)\) 的,然后整个状态转移就是 \(O(nm)\) 的,加上破环成链的代价,总的时间复杂度为 \(O(n^2m)\),实测拿 50 分。

\(O(n\log n)\) 正解

我们发现:既然我们要选一个正数,不如把这个正数所在的一整段正数全部选了;类似地,既然我们要不选一个负数,不如把这个负数所在的一整段负数排除掉。

于是乎我们就可以把整个环缩成若干段正负交替的 “段”,假设我们缩出来是 \(p\) 段正的夹着 \(p\) 段负的,那么分情况讨论:

  1. \(p \le m\),直接选所有的正数就行了,注意我们可以把这里的一段正数断成两段或更多来选,因此无论如何都能选够 \(m\) 段,此时明显最优。
  2. \(p > m\),我们不能每一段都选了,但是我们可以假设我们先选了 \(p\) 段,然后通过或消除,或合并的方式进行 “调整”。

考虑调整的方式:

  1. 选一段负数,这等价于把两边的正数段合并了,总段数 \(-1\),对答案有负贡献。
  2. 不选一段正数,这等价于把两边的负数段合并了,总段数 \(-1\),对答案也有负贡献。

注意:我们不能先选一段负数,然后再不选其相邻的正数(或反过来),因为这样不仅血亏,对总段数也没有影响。

因为调整一次总段数减少 1,以上的 “调整” 操作要进行 \(p - m\) 次。

既然对答案的贡献都是负的,我们把每一段的 “值” 变成该段数和的绝对值,然后问题就被转化为了:选择 \(p - m\) 个不相邻的数,使它们的和(对答案的负贡献)最小。

这又是另一个经典问题了,参见 CTSC 2007 数据备份,我们维护一个小根堆,每次取出堆顶元素,并在答案上减去其的贡献,然后执行 “合并”,即如果这个数的前驱是 \(a\),本身是 \(b\),右边是 \(c\),我们删除 \(a, b, c\),然后插入一个新的权值为 \(a + c - b\) 的等效节点。这样的话,如果我们下次再选到这个节点,对答案的影响等价于我们选了 \(a,c\) 而不选 \(b\)。这在思想上类似于网络流的退流。为了维护前驱 / 后继关系,我们同时需要引入一个环形链表,总的复杂度约为 \(O(n\log n)\)

我们发现这样做环形和非环形的差别就很小了,甚至是环形的更好写(不需要处理边界情况)。

代码:(注意因为人懒所以用了 priority_queue饮鸩止渴,所以用的是懒惰删除法,还有就是代码中是大根堆维护负绝对值,本质上是一样的)

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <queue>
#define register

using namespace std;
typedef long long ll;
typedef unsigned int u32;
const int INF = 0x3f3f3f3f;
const int N = 600100;
int n, m, ans, cnt;
int a[N], s[N], tot = 0;
int pre[N], nxt[N];
bool inv[N];

struct cmp {
    bool operator ()(int a, int b) {
        return s[a] < s[b];
    }
};
priority_queue<int, vector<int>, cmp> q;

#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(a) ((a) > 0 ? (a) : -(a))
#define sgn(a) ((a) >= 0 ? 1 : -1)

void merge(int x) {
    ans += s[x];
    s[++cnt] = s[pre[x]] + s[nxt[x]] - s[x];
    inv[pre[x]] = inv[nxt[x]] = true;
    pre[cnt] = pre[pre[x]];
    nxt[pre[cnt]] = cnt;
    nxt[cnt] = nxt[nxt[x]];
    pre[nxt[cnt]] = cnt;
    q.push(cnt);
}

void solve() {
    ans = 0;
    for (int i = 1; i <= n; ) {
        s[++tot] = 0;
        if (a[i] >= 0) while (i <= n && a[i] >= 0) s[tot] += a[i++];
        else while (i <= n && a[i] < 0) s[tot] += a[i++];
        if (s[i] >= 0) ans += s[i]; 
    }
    if (sgn(s[1]) == sgn(s[tot])) s[1] += s[tot--];
    if (tot / 2 <= m) return;
    for (int i = 1; i <= tot; i++) {
        pre[i] = i == 1 ? tot : i - 1;
        nxt[i] = i == tot ? 1 : i + 1;
        s[i] = -abs(s[i]);
        q.push(i);
    }
    cnt = tot;
    int left = tot / 2 - m;
    while (left--) {
        int now = q.top(); q.pop();
        if (inv[now]) left++;
        else merge(now);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    solve();
    printf("%d\n", ans);
    return 0;
}