做一道AGC题要用整体二分就来学了。
都到现在了还不会整体二分真是菜得可以。
整体二分用于处理多组可二分的询问。
设答案取值区间长度为 C, 询问次数为 q, 一组询问直接二分复杂度为 O(f * logC),那么整体二分就可以在 O((f + q) * logC) 的复杂度内离线完成 q 组询问。
算法过程
类似于cdq分治,将序列的初值、修改、查询都作为一组询问处理。
定义一个函数 Solve(l, r, q) 表示要计算询问数组 q 的答案,当前二分到的答案区间为 [l, r]。
如果 (l == r), 得到答案,退出递归。
否则令 mid = (l + r) » 1。
对于所有的修改操作,如果权值 <= mid, 先修改,加入到数组 q1 中;否则,不修改,加入到数组 q2 中。
然后处理所有的查询操作。因为已经加入了所有权值 <= mid 的修改,我们可以对于所有查询判断这组查询的答案是否应当 <= mid。如果应当,加入数组 q1 中;否则,加入数组 q2 中。
分组完成之后,我们应当消除掉 q1 数组中的修改。
然后可以递归处理 Solve(l, mid, q1), Solve(mid + 1, r, q2)。
不用写成两个数组的递归,可以写在同一个数组里传一个区间。
模板题
BZOJ2738. 矩阵乘法
题意
给你一个N*N的矩阵,每次询问一个子矩形的第K小数。
N<=500,Q<=60000
题解
将每个数看做一次修改,和查询放在同一个数组里。
然后就是裸的整体二分,但是计算答案是否 <= mid 的时候需要用二维树状数组优化。
注意在这份代码中修改总是在查询的前面。
#include<bits/stdc++.h>
using namespace std;
const int N = 505;
int n, m, q;
int tr[N][N], ans[N * N];
struct que {
int x1, y1, x2, y2, k, ty, idx;
}a[N * N * 2], q1[N * N * 2], q2[N * N * 2];
void Updata(int l, int r, int del) {
for (int i = l; i <= n; i += i & -i) {
for (int j = r; j <= n; j += j & -j) {
tr[i][j] += del;
}
}
}
int Query(int l, int r) {
int res = 0;
for (int i = l; i; i -= i & -i) {
for (int j = r; j; j -= j & -j) {
res += tr[i][j];
}
}
return res;
}
int getsum(int x1, int y1, int x2, int y2) {
return Query(x2, y2) + Query(x1 - 1, y1 - 1)
- Query(x1 - 1, y2) - Query(x2, y1 - 1);
}
void Solve(int l, int r, int ql, int qr) {
if (ql > qr) return;
if (l == r) {
for (int i = ql; i <= qr; ++i) {
if (a[i].ty == 1) ans[a[i].idx] = l;
}
return;
}
int mid = (l + r) >> 1, cnt1 = 0, cnt2 = 0;
for (int i = ql; i <= qr; ++i) {
if (!a[i].ty) {
if (a[i].k <= mid) {
Updata(a[i].x1, a[i].y1, 1);
q1[++cnt1] = a[i];
}
else {
q2[++cnt2] = a[i];
}
}
else {
int tmp = getsum(a[i].x1, a[i].y1, a[i].x2, a[i].y2);
if (a[i].k <= tmp) {
q1[++cnt1] = a[i];
}
else {
a[i].k -= tmp;
q2[++cnt2] = a[i];
}
}
}
for (int i = 1; i <= cnt1; ++i) {
if (!q1[i].ty) {
Updata(q1[i].x1, q1[i].y1, -1);
}
a[ql + i - 1] = q1[i];
}
for (int i = 1; i <= cnt2; ++i) {
a[ql + cnt1 + i - 1] = q2[i];
}
Solve(l, mid, ql, ql + cnt1 - 1);
Solve(mid + 1, r, ql + cnt1, qr);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
int x;
scanf("%d", &x);
a[++q] = (que) {i, j, 0, 0, x, 0, 0};
}
}
for (int i = 1; i <= m; ++i) {
int x1, x2, y1, y2, k;
scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &k);
a[++q] = (que) {x1, y1, x2, y2, k, 1, i};
}
Solve(0, 1e9, 1, q);
for (int i = 1; i <= m; ++i)
printf("%d\n", ans[i]);
return 0;
}
AGC002D Stamp Rally
题意
给定一张 n 个点 m 条边的无向图,第 i 条边权值为 i ,有 q 组询问,每次有两个人分别位于 x[i] 和 y[i],他们要走过 z[i] 个点,且走过边最大权值最小, 输出这个最大权值。
3 ≤ n, m, q ≤ 100000
题解
修改是边的插入,和查询放在同一个数组里。
每次check就是加入所有权值 <= mid 的边,问 x[i] 和 y[i] 所在联通块大小之和是否大于z[i] (注意两者在同一联通块的情况), 这可以用并查集做。
但是修改是需要撤销的。所以要打个可撤销的并查集。
两个不熟悉的算法套在一起真是调得酸爽。
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pii pair<int, int>
using namespace std;
const int N = 1e5 + 10;
int n, m, q;
int fa[N], siz[N], ans[N];
vector<pii> vec;
struct que {
int x, y, z, idx;
}a[N * 2], q1[N * 2], q2[N * 2];
int find(int x) {
return (fa[x] == x) ? x : find(fa[x]);
}
void merge(int x, int y) {
if (siz[x] > siz[y]) swap(x, y);
if (x == y) return;
fa[x] = y;
siz[y] += siz[x];
vec.push_back(mp(x, y));
}
void split(int x, int y) {
if (siz[x] > siz[y]) swap(x, y);
fa[x] = x;
siz[y] -= siz[x];
}
void Solve(int l, int r, int ql, int qr) {
if (ql > qr) return;
if (l == r) {
for (int i = ql; i <= qr; ++i) {
ans[a[i].idx] = l;
}
return;
}
vec.clear();
int mid = (l + r) >> 1, cnt1 = 0, cnt2 = 0;
for (int i = ql; i <= qr; ++i) {
if (!a[i].idx) {
if (a[i].z <= mid) {
merge(find(a[i].x), find(a[i].y));
q1[++cnt1] = a[i];
}
else {
q2[++cnt2] = a[i];
}
}
else {
int fx = find(a[i].x), fy = find(a[i].y);
if ((fx != fy && siz[fx] + siz[fy] >= a[i].z) || (siz[fx] >= a[i].z)) {
q1[++cnt1] = a[i];
}
else {
q2[++cnt2] = a[i];
}
}
}
for (int i = 1; i <= cnt1; ++i) {
a[ql + i - 1] = q1[i];
}
for (int i = 1; i <= cnt2; ++i) {
a[ql + cnt1 + i - 1] = q2[i];
}
for (int i = vec.size() - 1; i >= 0; --i) {
split(vec[i].fi, vec[i].se);
}
Solve(l, mid, ql, ql + cnt1 - 1);
for (int i = ql; i <= qr; ++i) {
if (!a[i].idx && a[i].z <= mid) {
merge(find(a[i].x), find(a[i].y));
}
}
Solve(mid + 1, r, ql + cnt1, qr);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; ++i) {
scanf("%d%d", &a[i].x, &a[i].y);
a[i].z = i;
}
scanf("%d", &q);
for (int i = 1; i <= q; ++i) {
scanf("%d%d%d", &a[i + m].x, &a[i + m].y, &a[i + m].z);
a[i + m].idx = i;
}
for (int i = 1; i <= n; ++i) {
fa[i] = i;
siz[i] = 1;
}
Solve(1, m, 1, m + q);
for (int i = 1 ; i <= q; ++i)
printf("%d\n", ans[i]);
return 0;
}