问题概要
给定一个 \(n\) 个元素的环形序列(\(1\) 号与 \(n\) 号元素相邻),元素有正有负,求 \(m\) 段连续的子段,使子段和最大,保证至少有 \(m\) 个正数。
\(n \le 3 \times 10^5, m \le 10^5\)。
\(O(n^2m)\) 暴力
破环成链,类似序列上的最大连续子段和,定义 \(dp[i][j]\) 表示以 \(i\) 为右端点选了 \(j\) 段的最大子段和,口胡状态转移:
\[
dp[i][j] = a[i] + \max \begin{cases}
dp[i - 1][j] \\
\max_{1 \le k < i} dp[k][j - 1]
\end{cases}
\] 后面的 \(\max\) 随便搞搞搞成 \(O(1)\) 的,然后整个状态转移就是 \(O(nm)\) 的,加上破环成链的代价,总的时间复杂度为 \(O(n^2m)\),实测拿 50 分。
\(O(n\log n)\) 正解
我们发现:既然我们要选一个正数,不如把这个正数所在的一整段正数全部选了;类似地,既然我们要不选一个负数,不如把这个负数所在的一整段负数排除掉。
于是乎我们就可以把整个环缩成若干段正负交替的 “段”,假设我们缩出来是 \(p\) 段正的夹着 \(p\) 段负的,那么分情况讨论:
- \(p \le m\),直接选所有的正数就行了,注意我们可以把这里的一段正数断成两段或更多来选,因此无论如何都能选够 \(m\) 段,此时明显最优。
- \(p > m\),我们不能每一段都选了,但是我们可以假设我们先选了 \(p\) 段,然后通过或消除,或合并的方式进行 “调整”。
考虑调整的方式:
- 选一段负数,这等价于把两边的正数段合并了,总段数 \(-1\),对答案有负贡献。
- 不选一段正数,这等价于把两边的负数段合并了,总段数 \(-1\),对答案也有负贡献。
注意:我们不能先选一段负数,然后再不选其相邻的正数(或反过来),因为这样不仅血亏,对总段数也没有影响。
因为调整一次总段数减少 1,以上的 “调整” 操作要进行 \(p - m\) 次。
既然对答案的贡献都是负的,我们把每一段的 “值” 变成该段数和的绝对值,然后问题就被转化为了:选择 \(p - m\) 个不相邻的数,使它们的和(对答案的负贡献)最小。
这又是另一个经典问题了,参见 CTSC 2007 数据备份,我们维护一个小根堆,每次取出堆顶元素,并在答案上减去其的贡献,然后执行 “合并”,即如果这个数的前驱是 \(a\),本身是 \(b\),右边是 \(c\),我们删除 \(a, b, c\),然后插入一个新的权值为 \(a + c - b\) 的等效节点。这样的话,如果我们下次再选到这个节点,对答案的影响等价于我们选了 \(a,c\) 而不选 \(b\)。这在思想上类似于网络流的退流。为了维护前驱 / 后继关系,我们同时需要引入一个环形链表,总的复杂度约为 \(O(n\log n)\)。
我们发现这样做环形和非环形的差别就很小了,甚至是环形的更好写(不需要处理边界情况)。
代码:(注意因为人懒所以用了 priority_queue
饮鸩止渴,所以用的是懒惰删除法,还有就是代码中是大根堆维护负绝对值,本质上是一样的)
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <queue>
#define register
using namespace std;
typedef long long ll;
typedef unsigned int u32;
const int INF = 0x3f3f3f3f;
const int N = 600100;
int n, m, ans, cnt;
int a[N], s[N], tot = 0;
int pre[N], nxt[N];
bool inv[N];
struct cmp {
bool operator ()(int a, int b) {
return s[a] < s[b];
}
};
<int, vector<int>, cmp> q;
priority_queue
#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(a) ((a) > 0 ? (a) : -(a))
#define sgn(a) ((a) >= 0 ? 1 : -1)
void merge(int x) {
+= s[x];
ans [++cnt] = s[pre[x]] + s[nxt[x]] - s[x];
s[pre[x]] = inv[nxt[x]] = true;
inv[cnt] = pre[pre[x]];
pre[pre[cnt]] = cnt;
nxt[cnt] = nxt[nxt[x]];
nxt[nxt[cnt]] = cnt;
pre.push(cnt);
q}
void solve() {
= 0;
ans for (int i = 1; i <= n; ) {
[++tot] = 0;
sif (a[i] >= 0) while (i <= n && a[i] >= 0) s[tot] += a[i++];
else while (i <= n && a[i] < 0) s[tot] += a[i++];
if (s[i] >= 0) ans += s[i];
}
if (sgn(s[1]) == sgn(s[tot])) s[1] += s[tot--];
if (tot / 2 <= m) return;
for (int i = 1; i <= tot; i++) {
[i] = i == 1 ? tot : i - 1;
pre[i] = i == tot ? 1 : i + 1;
nxt[i] = -abs(s[i]);
s.push(i);
q}
= tot;
cnt int left = tot / 2 - m;
while (left--) {
int now = q.top(); q.pop();
if (inv[now]) left++;
else merge(now);
}
}
int main() {
("%d%d", &n, &m);
scanffor (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
();
solve("%d\n", ans);
printfreturn 0;
}