NTT笔记

NTT的思想与FFT类似,但是避免了浮点数运算带来的精度问题。

原根

定义:m > 1\gcd(a, m) = 1,使得a^d \equiv 1 \pmod m的最小d称为am,记作d = \delta_m(a)

定理:显然,对于所有a^d \equiv 1 \pmod m,都有\delta_m(a) \mid d

同时根据欧拉定理我们有

a^{\varphi(m)} \equiv 1 \pmod m
因此有\delta_m(a) \mid \varphi(m)

定义:\delta_m(g) = \varphi(m)则称g为模m意义下的原根

m​存在原根当且仅当m = 2,4,p^n, 2p^n​,其中p​为奇素数,n \in \mathbb{Z}​

原根的意义:如果m存在原根g,那g^0, g^1, g^2, \cdots, g^{\varphi(m) - 1}恰能表示所有\varphi(m)个与m互素的数,且构成m的简化剩余系。

原根的计算方法:原根只有\varphi(\varphi(m))个,因此只要枚举即可。

NTT

NTT中使用单位原根代替单位复数根进行计算,单位原根定义为:

g_n = g^{\varphi(p) / n} = g^{(p - 1)/n}
其中g为模素数p的原根。

显然g_n有与单位负数根类似的性质:

g_n^n = g^{p - 1} \equiv 1 \pmod p
又因为\left(g_n^{n / 2}\right)^2 \equiv 1 \pmod p以及一些二次剩余的知识,g_n^{n / 2} \bmod p只可能是\pm 1,又因为原根的性质
g_n^{n / 2} \equiv -1 \pmod p
同时,观察g_n的定义,易证g_{2n}^{2k} = g_n^k(折半引理),原根具备所有FFT中利用的单位复数根的性质,因此也可以使用单位原根进行FFT,得到的算法就是NTT(Number Theoretic Transform),NTT适用于整数序列的变换,且精度与数值稳定性较FFT更优。

代码:

#include <cstdio>
#include <algorithm>

using namespace std;
const int MOD = 998244353;
const int N = 270000;
const int G = 3, INVG = 332748118;
typedef long long ll;
int n, m; 
ll a[N], b[N];

ll qpow(ll x, ll y) {
    ll ret = 1;
    for (; y; y >>= 1) {
        if (y & 1) ret = ret * x % MOD;
        x = x * x % MOD;
    }
    return ret;
}

void ntt(int len, ll *x, int dir = 1) {
    static ll X[N];
    for (int i = 0, k = 0; i < len; i++) {
        X[i] = x[k];
        int y = len >> 1;
        while (y & k) k ^= y, y >>= 1;
        k |= y;
    }
    for (int s = 2; s <= len; s <<= 1) {
        ll w_ = qpow(G, (MOD - 1) / s);
        if (dir == -1) w_ = qpow(w_, MOD - 2);
        for (int l = 0; l < len; l += s) {
            int mid = l + (s >> 1);
            ll w = 1;
            for (int i = 0; i < s >> 1; i++) {
                ll t = w * X[mid + i] % MOD;
                X[mid + i] = (X[l + i] - t) % MOD;
                X[l + i] = (X[l + i] + t) % MOD;
                w = w * w_ % MOD;
            }
        }
    }
    copy(X, X + len, x);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i++) scanf("%lld", &a[i]);
    for (int i = 0; i <= m; i++) scanf("%lld", &b[i]);
    int sz = 1; while (sz < n + m + 1) sz <<= 1;
    ntt(sz, a); ntt(sz, b);
    for (int i = 0; i < sz; i++) printf("%lld %lld\n", a[i], b[i]);
    for (int i = 0; i < sz; i++) a[i] = a[i] * b[i] % MOD;
    ntt(sz, a, -1); ll inv = qpow(sz, MOD - 2);
    for (int i = 0; i < n + m + 1; i++)
        printf("%lld ", (a[i] * inv % MOD + MOD) % MOD);
    puts("");
    return 0;
} 

常用的NTT素数

NTT所用的素数应该满足什么条件呢?观察g_n的定义显然我们期望对于一定范围内的k(比如说2^k \le 2\times 10^5)都最好有2^k \mid p-1。即p = a\times 2^k +1的形式。

常见的NTT素数为:

  1. 998244353 = 7\times 17\times 2^{23} + 1​
  2. 1004535809 = 479 \times 2^{21} + 1
  3. 469762049 = 7 \times 2^{26} + 1
  4. 2281701377 = 17 \times 2^{27} + 1
  5. 167772161 = 5 \times 2^{25} + 1

这些素数的原根都包含3

任意模数NTT

有些时候题目要求任意模数的NTT变换,这个时候我们可以选取三个上述列表中模数分别进行NTT,然后通过CRT进行合并。具体上来说,对于一个数x,若有:

\begin{aligned}
x &\equiv b_1 &\pmod{p_1} \\
x &\equiv b_2 &\pmod{p_2} \\
x &\equiv b_3 &\pmod{p_3} 
\end{aligned}
我们考虑两两合并,先求解b_1 + k_1p_1 \equiv b_2 \pmod {p_2}​,不难得到k_1 = p_1^{-1}(b_2 - b_1)​(逆元在模p_2​意义下)。

x' = b_1 + k_1p_1,再求解x = x' + k_2p_1p_2 \equiv b_3 \pmod {p_3},此时有k_2 = (p_1p_2)^{-1}(b_3 - x')(逆元在模p_3意义下)。

在计算的最后再对于指定的模数取模即可。

代码:

#include <cstdio>
#include <algorithm>

using namespace std;
typedef long long ll;
const ll G = 3;
const ll MOD1 = 998244353, MOD2 = 1004535809, MOD3 = 469762049;
const int N = 270000;
int n, m, P;

int qpow(int x, int y, int mod) {
    int ret = 1;
    for (; y; y >>= 1) {
        if (y & 1) ret = (ll)ret * x % mod;
        x = (ll)x * x % mod;
    }
    return ret;
}

struct mint {
    int m1, m2, m3;
    mint() {}
    mint(int x) : m1(x % MOD1), m2(x % MOD2), m3(x % MOD3) {}
    mint(int m1, int m2, int m3) : m1(m1), m2(m2), m3(m3) {}
    mint reduce() { 
        return mint(m1 + (m1 >> 31 & MOD1), // m1 >> 31 是符号位
                    m2 + (m2 >> 31 & MOD2), 
                    m3 + (m3 >> 31 & MOD3));
    }
    mint operator +(const mint &x) const {
        return mint(m1 + x.m1 - MOD1, 
                    m2 + x.m2 - MOD2, 
                    m3 + x.m3 - MOD3).reduce();
    }
    mint operator -(const mint &x) const {
        return mint(m1 - x.m1, m2 - x.m2, m3 - x.m3).reduce();
    }
    mint operator *(const mint &x) const {
        return mint((ll)m1 * x.m1 % MOD1, 
                    (ll)m2 * x.m2 % MOD2, 
                    (ll)m3 * x.m3 % MOD3);
    }
    mint inv() {
        return mint(qpow(m1, MOD1 - 2, MOD1), 
                    qpow(m2, MOD2 - 2, MOD2), 
                    qpow(m3, MOD3 - 2, MOD3));
    }
    int get() {
        const int INV_1 = qpow(MOD1, MOD2 - 2, MOD2);
        const int INV_2 = qpow((ll)MOD1 * MOD2 % MOD3, MOD3 - 2, MOD3);
        ll x = (ll)(m2 - m1 + MOD2) % MOD2 * INV_1 % MOD2 * MOD1 + m1;
        return ((ll)(m3 - x % MOD3 + MOD3) % MOD3 * INV_2 % MOD3 
                * ((ll)MOD1 * MOD2 % P) % P + x) % P;
    }
    static mint unit(int s) {
        return mint(qpow(G, (MOD1 - 1) / s, MOD1), 
                    qpow(G, (MOD2 - 1) / s, MOD2), 
                    qpow(G, (MOD3 - 1) / s, MOD3));
    }
} a[N], b[N];

void ntt(int len, mint *x, int dir = 1) {
    static mint X[N];
    for (int i = 0, k = 0; i < len; i++) {
        X[i] = x[k];
        int y = len >> 1;
        while (y & k) k ^= y, y >>= 1;
        k |= y;
    }

    for (int s = 2; s <= len; s <<= 1) {
        mint w_ = mint::unit(s);
        if (dir == -1) w_ = w_.inv();
        for (int l = 0; l < len; l += s) {
            int mid = l + (s >> 1);
            mint w = 1;
            for (int i = 0; i < s >> 1; i++) {
                mint t = w * X[mid + i];
                X[mid + i] = X[l + i] - t;
                X[l + i] = X[l + i] + t;
                w = w * w_;
            }
        }
    }
    if (dir == -1) {
        mint inv = mint(len, len, len).inv();
        for (int i = 0; i < len; i++) X[i] = X[i] * inv;
    }
    copy(X, X + len, x);
}

int main() {
    scanf("%d%d%d", &n, &m, &P);
    for (int i = 0; i <= n; i++) { int tmp; scanf("%d", &tmp); a[i] = tmp; }
    for (int i = 0; i <= m; i++) { int tmp; scanf("%d", &tmp); b[i] = tmp; }
    int sz = 1; while (sz < n + m + 1) sz <<= 1;
    ntt(sz, a); ntt(sz, b);
    for (int i = 0; i < sz; i++) a[i] = a[i] * b[i];
    ntt(sz, a, -1);
    for (int i = 0; i < n + m + 1; i++) printf("%d ", a[i].get());
    puts("");
    return 0;
}