NTT 的思想与 FFT 类似,但是避免了浮点数运算带来的精度问题。
原根
定义:设 \(m > 1\) 且 \(\gcd(a, m) = 1\),使得 \(a^d \equiv 1 \pmod m\) 的最小 \(d\) 称为 \(a\) 模 \(m\) 的阶,记作 \(d = \delta_m(a)\)。
定理:显然,对于所有 \(a^d \equiv 1 \pmod m\),都有 \(\delta_m(a) \mid d\)。
同时根据欧拉定理我们有 \[ a^{\varphi(m)} \equiv 1 \pmod m \] 因此有 \(\delta_m(a) \mid \varphi(m)\)。
定义:若 \(\delta_m(g) = \varphi(m)\) 则称 \(g\) 为模 \(m\) 意义下的原根。
\(m\) 存在原根当且仅当 \(m = 2,4,p^n, 2p^n\),其中 \(p\) 为奇素数,\(n \in \mathbb{Z}\)。
原根的意义:如果 \(m\) 存在原根 \(g\),那 \(g^0, g^1, g^2, \cdots, g^{\varphi(m) - 1}\) 恰能表示所有 \(\varphi(m)\) 个与 \(m\) 互素的数,且构成 \(m\) 的简化剩余系。
原根的计算方法:原根只有 \(\varphi(\varphi(m))\) 个,因此只要枚举即可。
NTT
NTT 中使用单位原根代替单位复数根进行计算,单位原根定义为: \[ g_n = g^{\varphi(p) / n} = g^{(p - 1)/n} \] 其中 \(g\) 为模素数 \(p\) 的原根。
显然 \(g_n\) 有与单位负数根类似的性质: \[ g_n^n = g^{p - 1} \equiv 1 \pmod p \] 又因为 \(\left(g_n^{n / 2}\right)^2 \equiv 1 \pmod p\) 以及一些二次剩余的知识,\(g_n^{n / 2} \bmod p\) 只可能是 \(\pm 1\),又因为原根的性质 \[ g_n^{n / 2} \equiv -1 \pmod p \] 同时,观察 \(g_n\) 的定义,易证 \(g_{2n}^{2k} = g_n^k\)(折半引理),原根具备所有 FFT 中利用的单位复数根的性质,因此也可以使用单位原根进行 FFT,得到的算法就是 NTT(Number Theoretic Transform),NTT 适用于整数序列的变换,且精度与数值稳定性较 FFT 更优。
代码:
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 998244353;
const int N = 270000;
const int G = 3, INVG = 332748118;
typedef long long ll;
int n, m;
[N], b[N];
ll a
(ll x, ll y) {
ll qpow= 1;
ll ret for (; y; y >>= 1) {
if (y & 1) ret = ret * x % MOD;
= x * x % MOD;
x }
return ret;
}
void ntt(int len, ll *x, int dir = 1) {
static ll X[N];
for (int i = 0, k = 0; i < len; i++) {
[i] = x[k];
Xint y = len >> 1;
while (y & k) k ^= y, y >>= 1;
|= y;
k }
for (int s = 2; s <= len; s <<= 1) {
w_ = qpow(G, (MOD - 1) / s);
ll if (dir == -1) w_ = qpow(w_, MOD - 2);
for (int l = 0; l < len; l += s) {
int mid = l + (s >> 1);
= 1;
ll w for (int i = 0; i < s >> 1; i++) {
= w * X[mid + i] % MOD;
ll t [mid + i] = (X[l + i] - t) % MOD;
X[l + i] = (X[l + i] + t) % MOD;
X= w * w_ % MOD;
w }
}
}
(X, X + len, x);
copy}
int main() {
("%d%d", &n, &m);
scanffor (int i = 0; i <= n; i++) scanf("%lld", &a[i]);
for (int i = 0; i <= m; i++) scanf("%lld", &b[i]);
int sz = 1; while (sz < n + m + 1) sz <<= 1;
(sz, a); ntt(sz, b);
nttfor (int i = 0; i < sz; i++) printf("%lld %lld\n", a[i], b[i]);
for (int i = 0; i < sz; i++) a[i] = a[i] * b[i] % MOD;
(sz, a, -1); ll inv = qpow(sz, MOD - 2);
nttfor (int i = 0; i < n + m + 1; i++)
("%lld ", (a[i] * inv % MOD + MOD) % MOD);
printf("");
putsreturn 0;
}
常用的 NTT 素数
NTT 所用的素数应该满足什么条件呢?观察 \(g_n\) 的定义显然我们期望对于一定范围内的 \(k\)(比如说 \(2^k \le 2\times 10^5\))都最好有 \(2^k \mid p-1\)。即 \(p = a\times 2^k +1\) 的形式。
常见的 NTT 素数为:
- \(998244353 = 7\times 17\times 2^{23} + 1\)
- \(1004535809 = 479 \times 2^{21} + 1\)
- \(469762049 = 7 \times 2^{26} + 1\)
- \(2281701377 = 17 \times 2^{27} + 1\)
- \(167772161 = 5 \times 2^{25} + 1\)
这些素数的原根都包含 \(3\)。
任意模数 NTT
有些时候题目要求任意模数的 NTT 变换,这个时候我们可以选取三个上述列表中模数分别进行 NTT,然后通过 CRT 进行合并。具体上来说,对于一个数 \(x\),若有: \[ \begin{aligned} x &\equiv b_1 &\pmod{p_1} \\ x &\equiv b_2 &\pmod{p_2} \\ x &\equiv b_3 &\pmod{p_3} \end{aligned} \] 我们考虑两两合并,先求解 \(b_1 + k_1p_1 \equiv b_2 \pmod {p_2}\),不难得到 \(k_1 = p_1^{-1}(b_2 - b_1)\)(逆元在模 \(p_2\) 意义下)。
令 \(x' = b_1 + k_1p_1\),再求解 \(x = x' + k_2p_1p_2 \equiv b_3 \pmod {p_3}\),此时有 \(k_2 = (p_1p_2)^{-1}(b_3 - x')\)(逆元在模 \(p_3\) 意义下)。
在计算的最后再对于指定的模数取模即可。
代码:
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll G = 3;
const ll MOD1 = 998244353, MOD2 = 1004535809, MOD3 = 469762049;
const int N = 270000;
int n, m, P;
int qpow(int x, int y, int mod) {
int ret = 1;
for (; y; y >>= 1) {
if (y & 1) ret = (ll)ret * x % mod;
= (ll)x * x % mod;
x }
return ret;
}
struct mint {
int m1, m2, m3;
() {}
mint(int x) : m1(x % MOD1), m2(x % MOD2), m3(x % MOD3) {}
mint(int m1, int m2, int m3) : m1(m1), m2(m2), m3(m3) {}
mint() {
mint reducereturn mint(m1 + (m1 >> 31 & MOD1), // m1 >> 31 是符号位
+ (m2 >> 31 & MOD2),
m2 + (m3 >> 31 & MOD3));
m3 }
operator +(const mint &x) const {
mint return mint(m1 + x.m1 - MOD1,
+ x.m2 - MOD2,
m2 + x.m3 - MOD3).reduce();
m3 }
operator -(const mint &x) const {
mint return mint(m1 - x.m1, m2 - x.m2, m3 - x.m3).reduce();
}
operator *(const mint &x) const {
mint return mint((ll)m1 * x.m1 % MOD1,
(ll)m2 * x.m2 % MOD2,
(ll)m3 * x.m3 % MOD3);
}
() {
mint invreturn mint(qpow(m1, MOD1 - 2, MOD1),
(m2, MOD2 - 2, MOD2),
qpow(m3, MOD3 - 2, MOD3));
qpow}
int get() {
const int INV_1 = qpow(MOD1, MOD2 - 2, MOD2);
const int INV_2 = qpow((ll)MOD1 * MOD2 % MOD3, MOD3 - 2, MOD3);
= (ll)(m2 - m1 + MOD2) % MOD2 * INV_1 % MOD2 * MOD1 + m1;
ll x return ((ll)(m3 - x % MOD3 + MOD3) % MOD3 * INV_2 % MOD3
* ((ll)MOD1 * MOD2 % P) % P + x) % P;
}
static mint unit(int s) {
return mint(qpow(G, (MOD1 - 1) / s, MOD1),
(G, (MOD2 - 1) / s, MOD2),
qpow(G, (MOD3 - 1) / s, MOD3));
qpow}
} a[N], b[N];
void ntt(int len, mint *x, int dir = 1) {
static mint X[N];
for (int i = 0, k = 0; i < len; i++) {
[i] = x[k];
Xint y = len >> 1;
while (y & k) k ^= y, y >>= 1;
|= y;
k }
for (int s = 2; s <= len; s <<= 1) {
w_ = mint::unit(s);
mint if (dir == -1) w_ = w_.inv();
for (int l = 0; l < len; l += s) {
int mid = l + (s >> 1);
= 1;
mint w for (int i = 0; i < s >> 1; i++) {
= w * X[mid + i];
mint t [mid + i] = X[l + i] - t;
X[l + i] = X[l + i] + t;
X= w * w_;
w }
}
}
if (dir == -1) {
= mint(len, len, len).inv();
mint inv for (int i = 0; i < len; i++) X[i] = X[i] * inv;
}
(X, X + len, x);
copy}
int main() {
("%d%d%d", &n, &m, &P);
scanffor (int i = 0; i <= n; i++) { int tmp; scanf("%d", &tmp); a[i] = tmp; }
for (int i = 0; i <= m; i++) { int tmp; scanf("%d", &tmp); b[i] = tmp; }
int sz = 1; while (sz < n + m + 1) sz <<= 1;
(sz, a); ntt(sz, b);
nttfor (int i = 0; i < sz; i++) a[i] = a[i] * b[i];
(sz, a, -1);
nttfor (int i = 0; i < n + m + 1; i++) printf("%d ", a[i].get());
("");
putsreturn 0;
}