VEXOBEN
Vexoben
Oct 6, 2018
It takes 7 minutes to read this article.

题目链接:UOJ269 [清华集训2016]如何优雅地求和

(刷新以获取数学公式)

题意

给定整数,,,求:

其中是一个次函数,以点值形式给出,表示时的函数值

保证

答案模998244353输出

题解

题目给出的点值,这启发我们想到,如果可以求出, , … ,的点值,并且如果答案恰好是关于次多项式的话,就可以用拉格朗日插值求值了

事实上,确实可以证明答案是关于次多项式(证明搬自cly_none的博客):

我们可以将记做若干下降幂的和,第项为,于是有:

那么我们只需要求出的几个点值就好了

从而令, , 就有 ,用FFT加速计算即可

稍微写一下拉格朗日插值:

我们现在得到。我们构造个多项式,第个多项式满足时函数值为,时函数值为,那么要求的多项式就是这个多项式之和。显然可以得到第个多项式:

不用把多项式化简,求值时将带进去计算即可。

因为给出的是, , … ,求值可以做到的复杂度。

#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int G = 3;
const int N = 1 << 18;
const int mod = 998244353;

int m, rev[N];
LL n, x, a[N], b[N], f[N], inv[N], fac[N], fav[N];

LL Qpow(LL x, int p) {
	LL ans = 1;
	while (p) {
		if (p & 1) ans = ans * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return ans;
}

LL Inv(LL x) {
	return Qpow(x, mod - 2);
}

void NTT(LL *a, int lim, int type) {
	for (int i = 0; i < lim; ++i) {
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	}
	for (int i = 2; i <= lim; i <<= 1) {
		int mid = i >> 1;
		LL Wn = Qpow(G, (mod - 1) / i);
		if (type == -1) Wn = Inv(Wn);
		for (int j = 0; j < lim; j += i) {
			LL W = 1;
			for (int k = 0; k < mid; ++k) {
				LL x = a[j + k], y = a[j + k + mid] * W % mod;
				a[j + k] = (x + y >= mod) ? (x + y - mod) : (x + y);
				a[j + k + mid] = (x < y) ? (x + mod - y) : (x - y);
				W = W * Wn % mod;
			}
		}
	}
	if (type == -1) {
		LL tmp = Inv(lim);
		for (int i = 0; i < lim; ++i)
			a[i] = a[i] * tmp % mod;
	}
}

void init() {
	scanf("%lld%d%lld", &n, &m, &x);
	for (int i = 0; i <= m; ++i)
		scanf("%lld", &f[i]);
	fac[0] = fav[0] = 1;
	fac[1] = fav[1] = inv[1] = 1;
	for (int i = 2; i < N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = (-mod / i * inv[mod % i] % mod) + mod;
		fav[i] = fav[i - 1] * inv[i] % mod;
	}
	for (int i = 0; i <= m; ++i) {
		a[i] = f[i] * Qpow(x, i) % mod * fav[i] % mod;
		b[i] = Qpow(mod + 1 - x, i) * fav[i] % mod;
	}
	int lim = 1, l = 0;
	while (lim < m + m + 2) lim <<= 1, ++l;
	for (int i = 0; i < lim; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
	NTT(a, lim, 1); NTT(b, lim, 1);
	for (int i = 0; i < lim; ++i)
		a[i] = a[i] * b[i] % mod;
	NTT(a, lim, -1);
	for (int i = 0; i <= m; ++i)
		a[i] = a[i] * fac[i] % mod;
}

void Solve() {
	LL ans = 0, tot = 1;
	for (int i = 0; i <= m; ++i)
		tot = tot * (n - i) % mod;
	for (int i = 0; i <= m; ++i) {
		LL up = tot * Inv(n - i) % mod;
		up = up * a[i] % mod;
		LL down = fac[i] * fac[m - i] % mod;
		if ((m - i) & 1) down = mod - down;
		LL tmp = up * Inv(down);
		ans = (ans + tmp) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	init();
	if (n <= m) return 0 * printf("%lld\n", a[n]);
	Solve();
	return 0;
}