题目链接:AGC019 Yes or No
tourist出的AGC F题,确实是道不可多得的好题。
(刷新以获取数学公式)
题意
有\(n+m\)道判断题,你知道其中\(n\)道答案是\(true\),\(m\)道答案是\(false\),但是只能通过猜测确定每一题的答案。每猜一题,会有人告诉你猜得对不对。你要采取最优策略,使得答对题目的期望尽量多。输出这个期望。
数据范围:\(1≤n,m≤500000\)
题解
记\(f_{n,m}\)为还有\(n\)道答案为\(true\),\(m\)道答案为\(false\),采取最优策略,答错题目的期望数量。
可以写出一个递推式:
\[f_{n,m} = \min \{ \frac{n}{n+m}f_{n-1,m}+\frac{m}{n+m}(f_{n,m-1}+1),\frac{m}{n+m}f_{n,m-1}+\frac{n}{n+m}(f_{n-1,m}+1)\}\] \[f_{n,m}=\frac{n}{n+m}f_{n-1,m}+\frac{m}{n+m}f_{n,m-1}+\frac{min\{n,m\}}{n+m}\]对于每一个\((i,j)\),我们计算\(\frac{min\{i,j\}}{i+j}\)对\(g_{n,m}\)的贡献:
\((i,j)\)到\((n,m)\)的路径有\(\dbinom{n+m-i-j}{n-i}\)条。而\((i,j)\)初始的权值是\(\frac{min\{i,j\}}{i+j}\),它每走一步,对应递推式里面\(n+m\)变大\(1\),\(n\)或者\(m\)变大\(1\),最终的权值会变成:
\[\frac{\Pi_{x=i+1}^{n}x\Pi_{x=j+1}^{m}x}{\Pi_{x=i+j+1}^{n+m}} \frac{min\{i,j\}}{i+j}\]最终\(f_{n,m}\)会被化简成:
\[\frac{1}{\dbinom{n+m}{n}} \sum_{i=1}^{n}\sum_{j=1}^{m}\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}\]关键是要观察出这里的组合意义:定义\((i,j)\)的权值为\(\frac{min\{i,j\}}{i+j}\),\(\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i}\)就是从\((0,0)\)出发,只能向上或向右,经过\((i,j)\)走到\((n,m)\)的方案数,而这个式子就是求一条路径上点权和的期望。
但是这样并不好处理,我们可以继续简化这个式子来获得更好的组合意义。
当\(i≤j\):
\[\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}=\dbinom{i+j-1}{i-1}\dbinom{n+m-i-j}{n-i}\]等式右边就是经过\((i-1,j)\)到\((i,j)\)这条边的路径条数。
当\(i>j\)
\[\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}=\dbinom{i+j-1}{j-1}\dbinom{n+m-i-j}{n-i}\]等式右边就是经过\((i,j-1)\)到\((i,j)\)这条边的路径条数。
那么我们就是在网格图上,有一些边有权值\(1\),问一条路径的期望权值。
大概长这样(有点抽象……):
.png)
注意横过来的边没有\((i-1,i)\)到\((i,i)\)的那种边(最后不封口)
我们先假装这些封口边存在,大概长这样(很抽象……):
.png)
可以证明这样任意一条路径的权值都是\(\min \{n,m\}\)。那些封口边只有\(\min \{n,m\}\)条,我们枚举它们把贡献减掉就好了。
#include<bits/stdc++.h>
#define fi first
#define se second
#define R register
#define mp make_pair
#define pb push_back
#define LL long long
#define Ldb long double
#define pii pair<int, int>
using namespace std;
const int N = 1e6 + 10;
const int mod = 998244353;
template <typename T> void read(T &x) {
int f = 0;
R char c = getchar();
while (c < '0' || c > '9') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = x * 10 - '0' + c;
if (f) x = -x;
}
int Qpow(int x, int p) {
int ans = 1;
while (p) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
p >>= 1;
}
return ans;
}
int Inv(int x) {
return Qpow(x, mod - 2);
}
// ------------------------------------------
int n, m;
int inv[N], fac[N], fav[N];
int C(int x, int y) {
return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}
int main() {
read(n); read(m);
if (n > m) swap(n, m);
fac[0] = fav[0] = 1;
inv[1] = fac[1] = fav[1] = 1;
for (int i = 2; i <= n + m; ++i) {
inv[i] = 1LL * -mod / i * inv[mod % i] % mod +mod;
fac[i] = 1LL * fac[i - 1] * i % mod;
fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
int tmp = 1LL * C(i * 2 - 1, i) * C(n + m - i * 2, n - i) % mod;
ans = (ans + tmp >= mod) ? (ans + tmp - mod) : (ans + tmp);
}
ans = (n - 1LL * ans * Inv(C(n + m ,n))) % mod;
ans = n + m - ans;
if (ans < 0) ans += mod;
cout << ans << endl;
return 0;
}