VEXOBEN
Vexoben
Oct 6, 2018
It takes 9 minutes to read this article.

题目链接:UOJ269 [清华集训2016]如何优雅地求和

(刷新以获取数学公式)

题意

给定整数\(n\),\(m\),\(x\),求:

\[\sum_{k = 0}^{n} f(k) {n \choose k} x^k (1 - x) ^ {n - k}\]

其中\(f(x)\)是一个\(m\)次函数,以点值形式给出,\(a_i\)表示\(x = i\)时的函数值

保证\(1≤n≤109, 1≤m≤20000, 0≤ai,x﹤998244353\)。

答案模998244353输出

题解

题目给出\(f(x)\)的点值,这启发我们想到,如果可以求出\(g(t) = \sum_{k = 0}^{t} f(k) {t \choose k} x^k (1 - x) ^ {t - k}\)在\(0\), \(1\), … ,\(m\)的点值,并且如果答案恰好是关于\(n\)的\(m\)次多项式的话,就可以用拉格朗日插值求值了

事实上,确实可以证明答案是关于\(n\)的\(m\)次多项式(证明搬自cly_none的博客):

我们可以将\(f(x)\)记做若干下降幂的和,第\(k\)项为\(f_k x ^ {\underline k}\),于是有:

那么我们只需要求出\(g(t)\)的几个点值就好了

从而令\(A(k) = \frac{f(k)x^k}{k!}\), \(B(k) = \frac{(1-x)^{k}}{k!}\), 就有 \(g = A * B\),用FFT加速计算即可

稍微写一下拉格朗日插值:

我们现在得到\((x_1, y_1), (x_2, y_2), ... (x_n, y_n), (x_{n + 1}, y_{n + 1})\)。我们构造\(n + 1\)个多项式,第\(i\)个多项式满足\(x = x_i\)时函数值为\(y_i\),\(x = x_j(j ≠ i)\)时函数值为\(0\),那么要求的多项式就是这\(n + 1\)个多项式之和。显然可以得到第\(i\)个多项式:

\[f_{i}(x) = y_i\frac{\prod_{j ≠ i} (x - x_j)}{\prod_{j ≠ i} (x_i - x_j)}\]

不用把多项式化简,求值时将\(x = x_0\)带进去计算即可。

因为\(x_i\)给出的是\(0\), \(1\), … \(m\),求值可以做到\(O(m)\)的复杂度。

#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int G = 3;
const int N = 1 << 18;
const int mod = 998244353;

int m, rev[N];
LL n, x, a[N], b[N], f[N], inv[N], fac[N], fav[N];

LL Qpow(LL x, int p) {
	LL ans = 1;
	while (p) {
		if (p & 1) ans = ans * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return ans;
}

LL Inv(LL x) {
	return Qpow(x, mod - 2);
}

void NTT(LL *a, int lim, int type) {
	for (int i = 0; i < lim; ++i) {
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	}
	for (int i = 2; i <= lim; i <<= 1) {
		int mid = i >> 1;
		LL Wn = Qpow(G, (mod - 1) / i);
		if (type == -1) Wn = Inv(Wn);
		for (int j = 0; j < lim; j += i) {
			LL W = 1;
			for (int k = 0; k < mid; ++k) {
				LL x = a[j + k], y = a[j + k + mid] * W % mod;
				a[j + k] = (x + y >= mod) ? (x + y - mod) : (x + y);
				a[j + k + mid] = (x < y) ? (x + mod - y) : (x - y);
				W = W * Wn % mod;
			}
		}
	}
	if (type == -1) {
		LL tmp = Inv(lim);
		for (int i = 0; i < lim; ++i)
			a[i] = a[i] * tmp % mod;
	}
}

void init() {
	scanf("%lld%d%lld", &n, &m, &x);
	for (int i = 0; i <= m; ++i)
		scanf("%lld", &f[i]);
	fac[0] = fav[0] = 1;
	fac[1] = fav[1] = inv[1] = 1;
	for (int i = 2; i < N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = (-mod / i * inv[mod % i] % mod) + mod;
		fav[i] = fav[i - 1] * inv[i] % mod;
	}
	for (int i = 0; i <= m; ++i) {
		a[i] = f[i] * Qpow(x, i) % mod * fav[i] % mod;
		b[i] = Qpow(mod + 1 - x, i) * fav[i] % mod;
	}
	int lim = 1, l = 0;
	while (lim < m + m + 2) lim <<= 1, ++l;
	for (int i = 0; i < lim; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
	NTT(a, lim, 1); NTT(b, lim, 1);
	for (int i = 0; i < lim; ++i)
		a[i] = a[i] * b[i] % mod;
	NTT(a, lim, -1);
	for (int i = 0; i <= m; ++i)
		a[i] = a[i] * fac[i] % mod;
}

void Solve() {
	LL ans = 0, tot = 1;
	for (int i = 0; i <= m; ++i)
		tot = tot * (n - i) % mod;
	for (int i = 0; i <= m; ++i) {
		LL up = tot * Inv(n - i) % mod;
		up = up * a[i] % mod;
		LL down = fac[i] * fac[m - i] % mod;
		if ((m - i) & 1) down = mod - down;
		LL tmp = up * Inv(down);
		ans = (ans + tmp) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	init();
	if (n <= m) return 0 * printf("%lld\n", a[n]);
	Solve();
	return 0;
}