最大多段连续子段和
问题概要
给定一个个元素的环形序列(号与号元素相邻),元素有正有负,求段连续的子段,使子段和最大,保证至少有个正数。
。
暴力
破环成链,类似序列上的最大连续子段和,定义表示以为右端点选了段的最大子段和,口胡状态转移:
正解
我们发现:既然我们要选一个正数,不如把这个正数所在的一整段正数全部选了;类似地,既然我们要不选一个负数,不如把这个负数所在的一整段负数排除掉。
于是乎我们就可以把整个环缩成若干段正负交替的“段”,假设我们缩出来是段正的夹着段负的,那么分情况讨论:
- ,直接选所有的正数就行了,注意我们可以把这里的一段正数断成两段或更多来选,因此无论如何都能选够段,此时明显最优。
- ,我们不能每一段都选了,但是我们可以假设我们先选了段,然后通过或消除,或合并的方式进行“调整”。
考虑调整的方式:
- 选一段负数,这等价于把两边的正数段合并了,总段数,对答案有负贡献。
- 不选一段正数,这等价于把两边的负数段合并了,总段数,对答案也有负贡献。
注意:我们不能先选一段负数,然后再不选其相邻的正数(或反过来),因为这样不仅血亏,对总段数也没有影响。
因为调整一次总段数减少1,以上的“调整”操作要进行次。
既然对答案的贡献都是负的,我们把每一段的“值”变成该段数和的绝对值,然后问题就被转化为了:选择个不相邻的数,使它们的和(对答案的负贡献)最小。
这又是另一个经典问题了,参见CTSC 2007 数据备份,我们维护一个小根堆,每次取出堆顶元素,并在答案上减去其的贡献,然后执行“合并”,即如果这个数的前驱是,本身是,右边是,我们删除,然后插入一个新的权值为的等效节点。这样的话,如果我们下次再选到这个节点,对答案的影响等价于我们选了而不选。这在思想上类似于网络流的退流。为了维护前驱/后继关系,我们同时需要引入一个环形链表,总的复杂度约为。
我们发现这样做环形和非环形的差别就很小了,甚至是环形的更好写(不需要处理边界情况)。
代码:(注意因为人懒所以用了priority_queue
饮鸩止渴,所以用的是懒惰删除法,还有就是代码中是大根堆维护负绝对值,本质上是一样的)
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <queue>
#define register
using namespace std;
typedef long long ll;
typedef unsigned int u32;
const int INF = 0x3f3f3f3f;
const int N = 600100;
int n, m, ans, cnt;
int a[N], s[N], tot = 0;
int pre[N], nxt[N];
bool inv[N];
struct cmp {
bool operator ()(int a, int b) {
return s[a] < s[b];
}
};
priority_queue<int, vector<int>, cmp> q;
#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(a) ((a) > 0 ? (a) : -(a))
#define sgn(a) ((a) >= 0 ? 1 : -1)
void merge(int x) {
ans += s[x];
s[++cnt] = s[pre[x]] + s[nxt[x]] - s[x];
inv[pre[x]] = inv[nxt[x]] = true;
pre[cnt] = pre[pre[x]];
nxt[pre[cnt]] = cnt;
nxt[cnt] = nxt[nxt[x]];
pre[nxt[cnt]] = cnt;
q.push(cnt);
}
void solve() {
ans = 0;
for (int i = 1; i <= n; ) {
s[++tot] = 0;
if (a[i] >= 0) while (i <= n && a[i] >= 0) s[tot] += a[i++];
else while (i <= n && a[i] < 0) s[tot] += a[i++];
if (s[i] >= 0) ans += s[i];
}
if (sgn(s[1]) == sgn(s[tot])) s[1] += s[tot--];
if (tot / 2 <= m) return;
for (int i = 1; i <= tot; i++) {
pre[i] = i == 1 ? tot : i - 1;
nxt[i] = i == tot ? 1 : i + 1;
s[i] = -abs(s[i]);
q.push(i);
}
cnt = tot;
int left = tot / 2 - m;
while (left--) {
int now = q.top(); q.pop();
if (inv[now]) left++;
else merge(now);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
solve();
printf("%d\n", ans);
return 0;
}