(刷新以获取数学公式)
题意
给定整数\(n\),\(m\),\(x\),求:
\[\sum_{k = 0}^{n} f(k) {n \choose k} x^k (1 - x) ^ {n - k}\]其中\(f(x)\)是一个\(m\)次函数,以点值形式给出,\(a_i\)表示\(x = i\)时的函数值
保证\(1≤n≤109, 1≤m≤20000, 0≤ai,x﹤998244353\)。
答案模998244353输出
题解
题目给出\(f(x)\)的点值,这启发我们想到,如果可以求出\(g(t) = \sum_{k = 0}^{t} f(k) {t \choose k} x^k (1 - x) ^ {t - k}\)在\(0\), \(1\), … ,\(m\)的点值,并且如果答案恰好是关于\(n\)的\(m\)次多项式的话,就可以用拉格朗日插值求值了
事实上,确实可以证明答案是关于\(n\)的\(m\)次多项式(证明搬自cly_none的博客):
我们可以将\(f(x)\)记做若干下降幂的和,第\(k\)项为\(f_k x ^ {\underline k}\),于是有:
.png)
那么我们只需要求出\(g(t)\)的几个点值就好了
.png)
从而令\(A(k) = \frac{f(k)x^k}{k!}\), \(B(k) = \frac{(1-x)^{k}}{k!}\), 就有 \(g = A * B\),用FFT加速计算即可
稍微写一下拉格朗日插值:
我们现在得到\((x_1, y_1), (x_2, y_2), ... (x_n, y_n), (x_{n + 1}, y_{n + 1})\)。我们构造\(n + 1\)个多项式,第\(i\)个多项式满足\(x = x_i\)时函数值为\(y_i\),\(x = x_j(j ≠ i)\)时函数值为\(0\),那么要求的多项式就是这\(n + 1\)个多项式之和。显然可以得到第\(i\)个多项式:
\[f_{i}(x) = y_i\frac{\prod_{j ≠ i} (x - x_j)}{\prod_{j ≠ i} (x_i - x_j)}\]不用把多项式化简,求值时将\(x = x_0\)带进去计算即可。
因为\(x_i\)给出的是\(0\), \(1\), … \(m\),求值可以做到\(O(m)\)的复杂度。
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int G = 3;
const int N = 1 << 18;
const int mod = 998244353;
int m, rev[N];
LL n, x, a[N], b[N], f[N], inv[N], fac[N], fav[N];
LL Qpow(LL x, int p) {
LL ans = 1;
while (p) {
if (p & 1) ans = ans * x % mod;
x = x * x % mod;
p >>= 1;
}
return ans;
}
LL Inv(LL x) {
return Qpow(x, mod - 2);
}
void NTT(LL *a, int lim, int type) {
for (int i = 0; i < lim; ++i) {
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int i = 2; i <= lim; i <<= 1) {
int mid = i >> 1;
LL Wn = Qpow(G, (mod - 1) / i);
if (type == -1) Wn = Inv(Wn);
for (int j = 0; j < lim; j += i) {
LL W = 1;
for (int k = 0; k < mid; ++k) {
LL x = a[j + k], y = a[j + k + mid] * W % mod;
a[j + k] = (x + y >= mod) ? (x + y - mod) : (x + y);
a[j + k + mid] = (x < y) ? (x + mod - y) : (x - y);
W = W * Wn % mod;
}
}
}
if (type == -1) {
LL tmp = Inv(lim);
for (int i = 0; i < lim; ++i)
a[i] = a[i] * tmp % mod;
}
}
void init() {
scanf("%lld%d%lld", &n, &m, &x);
for (int i = 0; i <= m; ++i)
scanf("%lld", &f[i]);
fac[0] = fav[0] = 1;
fac[1] = fav[1] = inv[1] = 1;
for (int i = 2; i < N; ++i) {
fac[i] = fac[i - 1] * i % mod;
inv[i] = (-mod / i * inv[mod % i] % mod) + mod;
fav[i] = fav[i - 1] * inv[i] % mod;
}
for (int i = 0; i <= m; ++i) {
a[i] = f[i] * Qpow(x, i) % mod * fav[i] % mod;
b[i] = Qpow(mod + 1 - x, i) * fav[i] % mod;
}
int lim = 1, l = 0;
while (lim < m + m + 2) lim <<= 1, ++l;
for (int i = 0; i < lim; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
NTT(a, lim, 1); NTT(b, lim, 1);
for (int i = 0; i < lim; ++i)
a[i] = a[i] * b[i] % mod;
NTT(a, lim, -1);
for (int i = 0; i <= m; ++i)
a[i] = a[i] * fac[i] % mod;
}
void Solve() {
LL ans = 0, tot = 1;
for (int i = 0; i <= m; ++i)
tot = tot * (n - i) % mod;
for (int i = 0; i <= m; ++i) {
LL up = tot * Inv(n - i) % mod;
up = up * a[i] % mod;
LL down = fac[i] * fac[m - i] % mod;
if ((m - i) & 1) down = mod - down;
LL tmp = up * Inv(down);
ans = (ans + tmp) % mod;
}
printf("%lld\n", ans);
}
int main() {
init();
if (n <= m) return 0 * printf("%lld\n", a[n]);
Solve();
return 0;
}