NTT 的思想与 FFT 类似,但是避免了浮点数运算带来的精度问题。

原根

定义:\(m > 1\)\(\gcd(a, m) = 1\),使得 \(a^d \equiv 1 \pmod m\) 的最小 \(d\) 称为 \(a\)\(m\),记作 \(d = \delta_m(a)\)

定理:显然,对于所有 \(a^d \equiv 1 \pmod m\),都有 \(\delta_m(a) \mid d\)

同时根据欧拉定理我们有 \[ a^{\varphi(m)} \equiv 1 \pmod m \] 因此有 \(\delta_m(a) \mid \varphi(m)\)

定义:\(\delta_m(g) = \varphi(m)\) 则称 \(g\) 为模 \(m\) 意义下的原根

\(m​\) 存在原根当且仅当 \(m = 2,4,p^n, 2p^n​\),其中 \(p​\) 为奇素数,\(n \in \mathbb{Z}​\)

原根的意义:如果 \(m\) 存在原根 \(g\),那 \(g^0, g^1, g^2, \cdots, g^{\varphi(m) - 1}\) 恰能表示所有 \(\varphi(m)\) 个与 \(m\) 互素的数,且构成 \(m\) 的简化剩余系。

原根的计算方法:原根只有 \(\varphi(\varphi(m))\) 个,因此只要枚举即可。

NTT

NTT 中使用单位原根代替单位复数根进行计算,单位原根定义为: \[ g_n = g^{\varphi(p) / n} = g^{(p - 1)/n} \] 其中 \(g\) 为模素数 \(p\) 的原根。

显然 \(g_n\) 有与单位负数根类似的性质: \[ g_n^n = g^{p - 1} \equiv 1 \pmod p \] 又因为 \(\left(g_n^{n / 2}\right)^2 \equiv 1 \pmod p\) 以及一些二次剩余的知识,\(g_n^{n / 2} \bmod p\) 只可能是 \(\pm 1\),又因为原根的性质 \[ g_n^{n / 2} \equiv -1 \pmod p \] 同时,观察 \(g_n\) 的定义,易证 \(g_{2n}^{2k} = g_n^k\)(折半引理),原根具备所有 FFT 中利用的单位复数根的性质,因此也可以使用单位原根进行 FFT,得到的算法就是 NTT(Number Theoretic Transform),NTT 适用于整数序列的变换,且精度与数值稳定性较 FFT 更优。

代码:

#include <cstdio>
#include <algorithm>

using namespace std;
const int MOD = 998244353;
const int N = 270000;
const int G = 3, INVG = 332748118;
typedef long long ll;
int n, m; 
ll a[N], b[N];

ll qpow(ll x, ll y) {
    ll ret = 1;
    for (; y; y >>= 1) {
        if (y & 1) ret = ret * x % MOD;
        x = x * x % MOD;
    }
    return ret;
}

void ntt(int len, ll *x, int dir = 1) {
    static ll X[N];
    for (int i = 0, k = 0; i < len; i++) {
        X[i] = x[k];
        int y = len >> 1;
        while (y & k) k ^= y, y >>= 1;
        k |= y;
    }
    for (int s = 2; s <= len; s <<= 1) {
        ll w_ = qpow(G, (MOD - 1) / s);
        if (dir == -1) w_ = qpow(w_, MOD - 2);
        for (int l = 0; l < len; l += s) {
            int mid = l + (s >> 1);
            ll w = 1;
            for (int i = 0; i < s >> 1; i++) {
                ll t = w * X[mid + i] % MOD;
                X[mid + i] = (X[l + i] - t) % MOD;
                X[l + i] = (X[l + i] + t) % MOD;
                w = w * w_ % MOD;
            }
        }
    }
    copy(X, X + len, x);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i++) scanf("%lld", &a[i]);
    for (int i = 0; i <= m; i++) scanf("%lld", &b[i]);
    int sz = 1; while (sz < n + m + 1) sz <<= 1;
    ntt(sz, a); ntt(sz, b);
    for (int i = 0; i < sz; i++) printf("%lld %lld\n", a[i], b[i]);
    for (int i = 0; i < sz; i++) a[i] = a[i] * b[i] % MOD;
    ntt(sz, a, -1); ll inv = qpow(sz, MOD - 2);
    for (int i = 0; i < n + m + 1; i++)
        printf("%lld ", (a[i] * inv % MOD + MOD) % MOD);
    puts("");
    return 0;
} 

常用的 NTT 素数

NTT 所用的素数应该满足什么条件呢?观察 \(g_n\) 的定义显然我们期望对于一定范围内的 \(k\)(比如说 \(2^k \le 2\times 10^5\))都最好有 \(2^k \mid p-1\)。即 \(p = a\times 2^k +1\) 的形式。

常见的 NTT 素数为:

  1. \(998244353 = 7\times 17\times 2^{23} + 1​\)
  2. \(1004535809 = 479 \times 2^{21} + 1\)
  3. \(469762049 = 7 \times 2^{26} + 1\)
  4. \(2281701377 = 17 \times 2^{27} + 1\)
  5. \(167772161 = 5 \times 2^{25} + 1\)

这些素数的原根都包含 \(3\)

任意模数 NTT

有些时候题目要求任意模数的 NTT 变换,这个时候我们可以选取三个上述列表中模数分别进行 NTT,然后通过 CRT 进行合并。具体上来说,对于一个数 \(x\),若有: \[ \begin{aligned} x &\equiv b_1 &\pmod{p_1} \\ x &\equiv b_2 &\pmod{p_2} \\ x &\equiv b_3 &\pmod{p_3} \end{aligned} \] 我们考虑两两合并,先求解 \(b_1 + k_1p_1 \equiv b_2 \pmod {p_2}​\),不难得到 \(k_1 = p_1^{-1}(b_2 - b_1)​\)(逆元在模 \(p_2​\) 意义下)。

\(x' = b_1 + k_1p_1\),再求解 \(x = x' + k_2p_1p_2 \equiv b_3 \pmod {p_3}\),此时有 \(k_2 = (p_1p_2)^{-1}(b_3 - x')\)(逆元在模 \(p_3\) 意义下)。

在计算的最后再对于指定的模数取模即可。

代码:

#include <cstdio>
#include <algorithm>

using namespace std;
typedef long long ll;
const ll G = 3;
const ll MOD1 = 998244353, MOD2 = 1004535809, MOD3 = 469762049;
const int N = 270000;
int n, m, P;

int qpow(int x, int y, int mod) {
    int ret = 1;
    for (; y; y >>= 1) {
        if (y & 1) ret = (ll)ret * x % mod;
        x = (ll)x * x % mod;
    }
    return ret;
}

struct mint {
    int m1, m2, m3;
    mint() {}
    mint(int x) : m1(x % MOD1), m2(x % MOD2), m3(x % MOD3) {}
    mint(int m1, int m2, int m3) : m1(m1), m2(m2), m3(m3) {}
    mint reduce() { 
        return mint(m1 + (m1 >> 31 & MOD1), // m1 >> 31 是符号位
                    m2 + (m2 >> 31 & MOD2), 
                    m3 + (m3 >> 31 & MOD3));
    }
    mint operator +(const mint &x) const {
        return mint(m1 + x.m1 - MOD1, 
                    m2 + x.m2 - MOD2, 
                    m3 + x.m3 - MOD3).reduce();
    }
    mint operator -(const mint &x) const {
        return mint(m1 - x.m1, m2 - x.m2, m3 - x.m3).reduce();
    }
    mint operator *(const mint &x) const {
        return mint((ll)m1 * x.m1 % MOD1, 
                    (ll)m2 * x.m2 % MOD2, 
                    (ll)m3 * x.m3 % MOD3);
    }
    mint inv() {
        return mint(qpow(m1, MOD1 - 2, MOD1), 
                    qpow(m2, MOD2 - 2, MOD2), 
                    qpow(m3, MOD3 - 2, MOD3));
    }
    int get() {
        const int INV_1 = qpow(MOD1, MOD2 - 2, MOD2);
        const int INV_2 = qpow((ll)MOD1 * MOD2 % MOD3, MOD3 - 2, MOD3);
        ll x = (ll)(m2 - m1 + MOD2) % MOD2 * INV_1 % MOD2 * MOD1 + m1;
        return ((ll)(m3 - x % MOD3 + MOD3) % MOD3 * INV_2 % MOD3 
                * ((ll)MOD1 * MOD2 % P) % P + x) % P;
    }
    static mint unit(int s) {
        return mint(qpow(G, (MOD1 - 1) / s, MOD1), 
                    qpow(G, (MOD2 - 1) / s, MOD2), 
                    qpow(G, (MOD3 - 1) / s, MOD3));
    }
} a[N], b[N];

void ntt(int len, mint *x, int dir = 1) {
    static mint X[N];
    for (int i = 0, k = 0; i < len; i++) {
        X[i] = x[k];
        int y = len >> 1;
        while (y & k) k ^= y, y >>= 1;
        k |= y;
    }

    for (int s = 2; s <= len; s <<= 1) {
        mint w_ = mint::unit(s);
        if (dir == -1) w_ = w_.inv();
        for (int l = 0; l < len; l += s) {
            int mid = l + (s >> 1);
            mint w = 1;
            for (int i = 0; i < s >> 1; i++) {
                mint t = w * X[mid + i];
                X[mid + i] = X[l + i] - t;
                X[l + i] = X[l + i] + t;
                w = w * w_;
            }
        }
    }
    if (dir == -1) {
        mint inv = mint(len, len, len).inv();
        for (int i = 0; i < len; i++) X[i] = X[i] * inv;
    }
    copy(X, X + len, x);
}

int main() {
    scanf("%d%d%d", &n, &m, &P);
    for (int i = 0; i <= n; i++) { int tmp; scanf("%d", &tmp); a[i] = tmp; }
    for (int i = 0; i <= m; i++) { int tmp; scanf("%d", &tmp); b[i] = tmp; }
    int sz = 1; while (sz < n + m + 1) sz <<= 1;
    ntt(sz, a); ntt(sz, b);
    for (int i = 0; i < sz; i++) a[i] = a[i] * b[i];
    ntt(sz, a, -1);
    for (int i = 0; i < n + m + 1; i++) printf("%d ", a[i].get());
    puts("");
    return 0;
}