Vexoben
Oct 6, 2018
(刷新以获取数学公式)
题意
给定整数,,,求:
其中是一个次函数,以点值形式给出,表示时的函数值
保证。
答案模998244353输出
题解
题目给出的点值,这启发我们想到,如果可以求出在, , … ,的点值,并且如果答案恰好是关于的次多项式的话,就可以用拉格朗日插值求值了
事实上,确实可以证明答案是关于的次多项式(证明搬自cly_none的博客):
我们可以将记做若干下降幂的和,第项为,于是有:
那么我们只需要求出的几个点值就好了
从而令, , 就有 ,用FFT加速计算即可
稍微写一下拉格朗日插值:
我们现在得到。我们构造个多项式,第个多项式满足时函数值为,时函数值为,那么要求的多项式就是这个多项式之和。显然可以得到第个多项式:
不用把多项式化简,求值时将带进去计算即可。
因为给出的是, , … ,求值可以做到的复杂度。
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int G = 3;
const int N = 1 << 18;
const int mod = 998244353;
int m, rev[N];
LL n, x, a[N], b[N], f[N], inv[N], fac[N], fav[N];
LL Qpow(LL x, int p) {
LL ans = 1;
while (p) {
if (p & 1) ans = ans * x % mod;
x = x * x % mod;
p >>= 1;
}
return ans;
}
LL Inv(LL x) {
return Qpow(x, mod - 2);
}
void NTT(LL *a, int lim, int type) {
for (int i = 0; i < lim; ++i) {
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int i = 2; i <= lim; i <<= 1) {
int mid = i >> 1;
LL Wn = Qpow(G, (mod - 1) / i);
if (type == -1) Wn = Inv(Wn);
for (int j = 0; j < lim; j += i) {
LL W = 1;
for (int k = 0; k < mid; ++k) {
LL x = a[j + k], y = a[j + k + mid] * W % mod;
a[j + k] = (x + y >= mod) ? (x + y - mod) : (x + y);
a[j + k + mid] = (x < y) ? (x + mod - y) : (x - y);
W = W * Wn % mod;
}
}
}
if (type == -1) {
LL tmp = Inv(lim);
for (int i = 0; i < lim; ++i)
a[i] = a[i] * tmp % mod;
}
}
void init() {
scanf("%lld%d%lld", &n, &m, &x);
for (int i = 0; i <= m; ++i)
scanf("%lld", &f[i]);
fac[0] = fav[0] = 1;
fac[1] = fav[1] = inv[1] = 1;
for (int i = 2; i < N; ++i) {
fac[i] = fac[i - 1] * i % mod;
inv[i] = (-mod / i * inv[mod % i] % mod) + mod;
fav[i] = fav[i - 1] * inv[i] % mod;
}
for (int i = 0; i <= m; ++i) {
a[i] = f[i] * Qpow(x, i) % mod * fav[i] % mod;
b[i] = Qpow(mod + 1 - x, i) * fav[i] % mod;
}
int lim = 1, l = 0;
while (lim < m + m + 2) lim <<= 1, ++l;
for (int i = 0; i < lim; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
NTT(a, lim, 1); NTT(b, lim, 1);
for (int i = 0; i < lim; ++i)
a[i] = a[i] * b[i] % mod;
NTT(a, lim, -1);
for (int i = 0; i <= m; ++i)
a[i] = a[i] * fac[i] % mod;
}
void Solve() {
LL ans = 0, tot = 1;
for (int i = 0; i <= m; ++i)
tot = tot * (n - i) % mod;
for (int i = 0; i <= m; ++i) {
LL up = tot * Inv(n - i) % mod;
up = up * a[i] % mod;
LL down = fac[i] * fac[m - i] % mod;
if ((m - i) & 1) down = mod - down;
LL tmp = up * Inv(down);
ans = (ans + tmp) % mod;
}
printf("%lld\n", ans);
}
int main() {
init();
if (n <= m) return 0 * printf("%lld\n", a[n]);
Solve();
return 0;
}