NTT笔记
NTT的思想与FFT类似,但是避免了浮点数运算带来的精度问题。
原根
定义:设且,使得的最小称为模的阶,记作。
定理:显然,对于所有,都有。
同时根据欧拉定理我们有
因此有。定义:若则称为模意义下的原根。
存在原根当且仅当,其中为奇素数,。
原根的意义:如果存在原根,那恰能表示所有个与互素的数,且构成的简化剩余系。
原根的计算方法:原根只有个,因此只要枚举即可。
NTT
NTT中使用单位原根代替单位复数根进行计算,单位原根定义为:
其中为模素数的原根。显然有与单位负数根类似的性质:
又因为以及一些二次剩余的知识,只可能是,又因为原根的性质 同时,观察的定义,易证(折半引理),原根具备所有FFT中利用的单位复数根的性质,因此也可以使用单位原根进行FFT,得到的算法就是NTT(Number Theoretic Transform),NTT适用于整数序列的变换,且精度与数值稳定性较FFT更优。代码:
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 998244353;
const int N = 270000;
const int G = 3, INVG = 332748118;
typedef long long ll;
int n, m;
ll a[N], b[N];
ll qpow(ll x, ll y) {
ll ret = 1;
for (; y; y >>= 1) {
if (y & 1) ret = ret * x % MOD;
x = x * x % MOD;
}
return ret;
}
void ntt(int len, ll *x, int dir = 1) {
static ll X[N];
for (int i = 0, k = 0; i < len; i++) {
X[i] = x[k];
int y = len >> 1;
while (y & k) k ^= y, y >>= 1;
k |= y;
}
for (int s = 2; s <= len; s <<= 1) {
ll w_ = qpow(G, (MOD - 1) / s);
if (dir == -1) w_ = qpow(w_, MOD - 2);
for (int l = 0; l < len; l += s) {
int mid = l + (s >> 1);
ll w = 1;
for (int i = 0; i < s >> 1; i++) {
ll t = w * X[mid + i] % MOD;
X[mid + i] = (X[l + i] - t) % MOD;
X[l + i] = (X[l + i] + t) % MOD;
w = w * w_ % MOD;
}
}
}
copy(X, X + len, x);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i <= n; i++) scanf("%lld", &a[i]);
for (int i = 0; i <= m; i++) scanf("%lld", &b[i]);
int sz = 1; while (sz < n + m + 1) sz <<= 1;
ntt(sz, a); ntt(sz, b);
for (int i = 0; i < sz; i++) printf("%lld %lld\n", a[i], b[i]);
for (int i = 0; i < sz; i++) a[i] = a[i] * b[i] % MOD;
ntt(sz, a, -1); ll inv = qpow(sz, MOD - 2);
for (int i = 0; i < n + m + 1; i++)
printf("%lld ", (a[i] * inv % MOD + MOD) % MOD);
puts("");
return 0;
}
常用的NTT素数
NTT所用的素数应该满足什么条件呢?观察的定义显然我们期望对于一定范围内的(比如说)都最好有。即的形式。
常见的NTT素数为:
这些素数的原根都包含。
任意模数NTT
有些时候题目要求任意模数的NTT变换,这个时候我们可以选取三个上述列表中模数分别进行NTT,然后通过CRT进行合并。具体上来说,对于一个数,若有:
我们考虑两两合并,先求解,不难得到(逆元在模意义下)。令,再求解,此时有(逆元在模意义下)。
在计算的最后再对于指定的模数取模即可。
代码:
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll G = 3;
const ll MOD1 = 998244353, MOD2 = 1004535809, MOD3 = 469762049;
const int N = 270000;
int n, m, P;
int qpow(int x, int y, int mod) {
int ret = 1;
for (; y; y >>= 1) {
if (y & 1) ret = (ll)ret * x % mod;
x = (ll)x * x % mod;
}
return ret;
}
struct mint {
int m1, m2, m3;
mint() {}
mint(int x) : m1(x % MOD1), m2(x % MOD2), m3(x % MOD3) {}
mint(int m1, int m2, int m3) : m1(m1), m2(m2), m3(m3) {}
mint reduce() {
return mint(m1 + (m1 >> 31 & MOD1), // m1 >> 31 是符号位
m2 + (m2 >> 31 & MOD2),
m3 + (m3 >> 31 & MOD3));
}
mint operator +(const mint &x) const {
return mint(m1 + x.m1 - MOD1,
m2 + x.m2 - MOD2,
m3 + x.m3 - MOD3).reduce();
}
mint operator -(const mint &x) const {
return mint(m1 - x.m1, m2 - x.m2, m3 - x.m3).reduce();
}
mint operator *(const mint &x) const {
return mint((ll)m1 * x.m1 % MOD1,
(ll)m2 * x.m2 % MOD2,
(ll)m3 * x.m3 % MOD3);
}
mint inv() {
return mint(qpow(m1, MOD1 - 2, MOD1),
qpow(m2, MOD2 - 2, MOD2),
qpow(m3, MOD3 - 2, MOD3));
}
int get() {
const int INV_1 = qpow(MOD1, MOD2 - 2, MOD2);
const int INV_2 = qpow((ll)MOD1 * MOD2 % MOD3, MOD3 - 2, MOD3);
ll x = (ll)(m2 - m1 + MOD2) % MOD2 * INV_1 % MOD2 * MOD1 + m1;
return ((ll)(m3 - x % MOD3 + MOD3) % MOD3 * INV_2 % MOD3
* ((ll)MOD1 * MOD2 % P) % P + x) % P;
}
static mint unit(int s) {
return mint(qpow(G, (MOD1 - 1) / s, MOD1),
qpow(G, (MOD2 - 1) / s, MOD2),
qpow(G, (MOD3 - 1) / s, MOD3));
}
} a[N], b[N];
void ntt(int len, mint *x, int dir = 1) {
static mint X[N];
for (int i = 0, k = 0; i < len; i++) {
X[i] = x[k];
int y = len >> 1;
while (y & k) k ^= y, y >>= 1;
k |= y;
}
for (int s = 2; s <= len; s <<= 1) {
mint w_ = mint::unit(s);
if (dir == -1) w_ = w_.inv();
for (int l = 0; l < len; l += s) {
int mid = l + (s >> 1);
mint w = 1;
for (int i = 0; i < s >> 1; i++) {
mint t = w * X[mid + i];
X[mid + i] = X[l + i] - t;
X[l + i] = X[l + i] + t;
w = w * w_;
}
}
}
if (dir == -1) {
mint inv = mint(len, len, len).inv();
for (int i = 0; i < len; i++) X[i] = X[i] * inv;
}
copy(X, X + len, x);
}
int main() {
scanf("%d%d%d", &n, &m, &P);
for (int i = 0; i <= n; i++) { int tmp; scanf("%d", &tmp); a[i] = tmp; }
for (int i = 0; i <= m; i++) { int tmp; scanf("%d", &tmp); b[i] = tmp; }
int sz = 1; while (sz < n + m + 1) sz <<= 1;
ntt(sz, a); ntt(sz, b);
for (int i = 0; i < sz; i++) a[i] = a[i] * b[i];
ntt(sz, a, -1);
for (int i = 0; i < n + m + 1; i++) printf("%d ", a[i].get());
puts("");
return 0;
}