VEXOBEN
Vexoben
Mar 13, 2019
It takes 8 minutes to read this article.

题目链接:AGC019 Yes or No

tourist出的AGC F题,确实是道不可多得的好题。

(刷新以获取数学公式)

题意

有\(n+m\)道判断题,你知道其中\(n\)道答案是\(true\),\(m\)道答案是\(false\),但是只能通过猜测确定每一题的答案。每猜一题,会有人告诉你猜得对不对。你要采取最优策略,使得答对题目的期望尽量多。输出这个期望。

数据范围:\(1≤n,m≤500000\)

题解

记\(f_{n,m}\)为还有\(n\)道答案为\(true\),\(m\)道答案为\(false\),采取最优策略,答错题目的期望数量。

可以写出一个递推式:

\[f_{n,m} = \min \{ \frac{n}{n+m}f_{n-1,m}+\frac{m}{n+m}(f_{n,m-1}+1),\frac{m}{n+m}f_{n,m-1}+\frac{n}{n+m}(f_{n-1,m}+1)\}\] \[f_{n,m}=\frac{n}{n+m}f_{n-1,m}+\frac{m}{n+m}f_{n,m-1}+\frac{min\{n,m\}}{n+m}​\]

对于每一个\((i,j)​\),我们计算\(\frac{min\{i,j\}}{i+j}​\)对\(g_{n,m}​\)的贡献:

\((i,j)\)到\((n,m)\)的路径有\(\dbinom{n+m-i-j}{n-i}\)条。而\((i,j)\)初始的权值是\(\frac{min\{i,j\}}{i+j}\),它每走一步,对应递推式里面\(n+m\)变大\(1\),\(n\)或者\(m\)变大\(1\),最终的权值会变成:

\[\frac{\Pi_{x=i+1}^{n}x\Pi_{x=j+1}^{m}x}{\Pi_{x=i+j+1}^{n+m}} \frac{min\{i,j\}}{i+j}​\]

最终\(f_{n,m}\)会被化简成:

\[\frac{1}{\dbinom{n+m}{n}} \sum_{i=1}^{n}\sum_{j=1}^{m}\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}​\]

关键是要观察出这里的组合意义:定义\((i,j)\)的权值为\(\frac{min\{i,j\}}{i+j}\),\(\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i}\)就是从\((0,0)\)出发,只能向上或向右,经过\((i,j)\)走到\((n,m)\)的方案数,而这个式子就是求一条路径上点权和的期望。

但是这样并不好处理,我们可以继续简化这个式子来获得更好的组合意义。

当\(i≤j\):

\[\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}=\dbinom{i+j-1}{i-1}\dbinom{n+m-i-j}{n-i}\]

等式右边就是经过\((i-1,j)\)到\((i,j)​\)这条边的路径条数。

当\(i>j​\)

\[\dbinom{i+j}{i}\dbinom{n+m-i-j}{n-i} \frac{min\{i,j\}}{i+j}=\dbinom{i+j-1}{j-1}\dbinom{n+m-i-j}{n-i}\]

等式右边就是经过\((i,j-1)\)到\((i,j)​\)这条边的路径条数。

那么我们就是在网格图上,有一些边有权值\(1\),问一条路径的期望权值。

大概长这样(有点抽象……):

注意横过来的边没有\((i-1,i)\)到\((i,i)\)的那种边(最后不封口)

我们先假装这些封口边存在,大概长这样(很抽象……):

可以证明这样任意一条路径的权值都是\(\min \{n,m\}\)。那些封口边只有\(\min \{n,m\}\)条,我们枚举它们把贡献减掉就好了。

#include<bits/stdc++.h>
#define fi first
#define se second
#define R register
#define mp make_pair
#define pb push_back
#define LL long long
#define Ldb long double
#define pii pair<int, int>
using namespace std;
const int N = 1e6 + 10;
const int mod = 998244353;

template <typename T> void read(T &x) {
	int f = 0;
	R char c = getchar();
	while (c < '0' || c > '9') f |= (c == '-'), c = getchar();
	for (x = 0; c >= '0' && c <= '9'; c = getchar())
		x = x * 10 - '0' + c;
	if (f) x = -x;
}

int Qpow(int x, int p) {
	int ans = 1;
	while (p) {
		if (p & 1) ans = 1LL * ans * x % mod;
		x = 1LL * x * x % mod;
		p >>= 1;
	}
	return ans;
}

int Inv(int x) {
	return Qpow(x, mod - 2);
}

// ------------------------------------------

int n, m;
int inv[N], fac[N], fav[N];

int C(int x, int y) {
	return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}

int main() {
	read(n); read(m);
	if (n > m) swap(n, m);
	fac[0] = fav[0] = 1;
	inv[1] = fac[1] = fav[1] = 1;
	for (int i = 2; i <= n + m; ++i) {
		inv[i] = 1LL * -mod / i * inv[mod % i] % mod +mod;
		fac[i] = 1LL * fac[i - 1] * i % mod;
		fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
	}
	int ans = 0;
	for (int i = 1; i <= n; ++i) {
		int tmp = 1LL * C(i * 2 - 1, i) * C(n + m - i * 2, n - i) % mod;
		ans = (ans + tmp >= mod) ? (ans + tmp - mod) : (ans + tmp);
	}
	ans = (n - 1LL * ans * Inv(C(n + m ,n))) % mod;
	ans = n + m - ans;
	if (ans < 0) ans += mod;
	cout << ans << endl;
	return 0;
}