
フーリエ変換を用いて多項式の掛け算を行う
株式会社アプリボットでサーバーサイドエンジニアをやっている斎藤です。
この記事は Applibot Advent Calendar 2020 15日目の記事です。
概要
多項式の掛け算の応用範囲は広く、形式的べき級数を用いた数え上げや非常に大きな数同士の掛け算などに利用されています。
ところが \(N\) 次多項式同士の掛け算の時間計算量は、愚直に行うとランダウの記法を用いて \(O(N^2)\) となります。しかし多項式の掛け算は畳み込み演算によって表現できることが知られており、高速フーリエ変換(FFT) を用いて計算をおこなうことで時間計算量を \(O(N \log N)\) とすることができます。本稿ではその手法について紹介します。
畳み込みと多項式の掛け算
畳み込みを \(f*g\) と書くこととすると、離散値に対して行う畳み込みは\[(f*g)(x)=\sum_{n=0}^{N-1}f(n)g(x-n)\]で定義されます。
次に\(m-1\) 次多項式同士の掛け算を考えます。
\[\begin{aligned}a(x)&=a_0+a_1x\cdots+a_{m-1}x^{m-1} \\ b(x)&=b_0+b_1x\cdots+b_{m-1}x^{m-1}\end{aligned}\]ここで、\(a(x)b(x)\) の \(x^k\) の係数を \(c_k\) とすると\[c_k=\sum_{n=0}^{2m-2}a_nb_{k-n}\]で表されることから、\(f(x)\) を \(f(n)=a_n\)、\(g(x)\) を \(g(n)=b_n\) となるように取り、\((f*g)(k)\)の値を求めることで畳み込みを用いて係数の計算ができることがわかります。(\(n>m-1\) となるような \(a_n,b_n\) は \(0\) とし、\(N=2m-1\) とすると良い)
ここまでで \(N\) 点に対する畳み込み結果を \(N\) 点分求める必要があるので、時間計算量としては \(O(N^2)\) となることがわかりました。これをなんとかして早く求めることを考えます。
離散フーリエ変換(DFT)
離散フーリエ変換(DFT) の定義は以下の形式で与えられます。\[F(t)=\sum_{x=0}^{N-1} f(x)e^{-i\frac{2\pi tx}{N}}\]また逆変換は
\[f(x)=\frac1N\sum_{t=0}^{N-1} F(t)e^{i\frac{2\pi tx}{N}}=\frac1N\overline{\sum_{t=0}^{N-1}\overline{F(t)}e^{-i\frac{2\pi tx}{N}}}\]です。正変換が行えれば逆変換は複素共役をうまく取ることにより同様の式から求められることがわかります。
\(f\) の \(N\) 点の離散フーリエ変換を \(\mathcal{F}_N(f)\) として表すこととすると、畳み込みは\[\mathcal{F}_N(f*g)=\mathcal{F}_N(f)\mathcal{F}_N(g)\]となります。
証明\[\begin{aligned} \mathcal{F}_N(f*g) &= \sum_{x=0}^{N-1}\sum_{n=0}^{N-1}f(n)g(x-n)e^{-i\frac{2\pi tx}{N}} \\ &= \sum_{x=0}^{N-1}\sum_{n=0}^{N-1}f(n)e^{-i\frac{2\pi tn}{N}}g(x-n)e^{-i\frac{2\pi t(x-n)}{N}} \\ &= \sum_{n=0}^{N-1}f(n)e^{-i\frac{2\pi tn}{N}}\sum_{x=0}^{N-1}g(x-n)e^{-i\frac{2\pi t(x-n)}{N}} \\ &= \mathcal{F}_N(f)\mathcal{F}_N(g) \end{aligned}\]
以上より、フーリエ変換を行い \(\mathcal{F}_N(f)\) と \(\mathcal{F}_N(g)\) を求めてから逆フーリエ変換を行うことにより \(f*g\) を求めるようにすると良いことがわかりますが
- フーリエ変換 \(\mathcal{F}_N(f)\) と \(\mathcal{F}_N(g)\) を求める \(O(N^2)\)
- \(\mathcal{F}_N(f)\mathcal{F}_N(g)\) を求める \(O(N)\)
- 逆フーリエ変換\(\mathcal{F}_N^{-1}(\mathcal{F}_N(f)\mathcal{F}_N(g))=f*g\) を求める \(O(N^2)\)
以上の操作は時間計算量 \(O(N^2)\) です。しかし \(N=2^k\) となるような \(N\) に対しては \(O( N \log N)\) で求まることが知られています。以下 \(N=2^k\) で表せることを仮定します。
高速フーリエ変換(FFT)
さて、ここで \(\omega_N=e^{-i\frac{2\pi}{N}}\)としフーリエ変換の式を偶数番目の和と奇数番目の和に分割することを考えます。 \[\begin{aligned} F(t) &= \sum_{x=0}^{N-1} f(x)\omega_N^{xt} \\ &= \sum_{x=0}^{\frac{N}2-1} f(2x)\omega_N^{2xt} + \sum_{x=0}^{\frac{N}2-1} f(2x+1)\omega_N^{(2x+1)t} \\ &= \sum_{x=0}^{\frac{N}2-1} f(2x)\omega_N^{2xt} + \omega_N^t\sum_{x=0}^{\frac{N}2-1} f(2x+1)\omega_N^{2xt} \\ \end{aligned}\] \(f_\text{odd}=f(2x), f_\text{even}=f(2x+1)\) とし \(\omega_N^{2xt}=\omega_{N/2}^{xt}\) であることに注意すると \[\begin{aligned} \mathcal{F}_N(f)=\mathcal{F}_\frac{N}{2}(f_\text{even}) + \omega_N^t\mathcal{F}_\frac{N}{2}(f_\text{odd}) \end{aligned}\] となり、分割数 \(N/2\) のDFTを2つ用いて分割数 \(N\) のDFTを表すことが出来ました。以上の議論により計算量は \[ \begin{eqnarray} T(N)=\left\{ \begin{array}{l} 2T(N/2) + O(N) &(N > 2) \\ O(1) &(N = 2) \end{array} \right. \end{eqnarray} \]で、結局 \(O(N \log N)\) となります。
バタフライ演算

高速フーリエ変換のそれぞれのステップを可視化すると上の図のようになりますが、一つのペアの計算に着目すると次のような関係を表しています。\(x_\alpha\) を黒線で \(x_\beta\) を赤線で表しています。 \[\begin{aligned} X_\alpha &= x_\alpha + x_\beta \omega_M^{m} \\ X_\beta &= x_\alpha + x_\beta \omega_M^{m+\frac{M}2} \end{aligned}\]この演算を図示したときに1つのペアの計算に着目すると蝶のように見えることから、バタフライ演算と呼ばれています。
バタフライ演算の1ステップに着目すると、インプレースに計算を実行することができることが分かります。入力列の順序についてもビットを逆順にする演算(bit-reverse)によるインデックスで並べることによりインプレースに並び替えることが出来るので適切に前処理を行えば良いです。
よって実際のアルゴリズムは
- 入力列をbit-reverseに並び替える \(O(N)\)
- バタフライ演算を実行しフーリエ変換 \(\mathcal{F}_N(f)\) と \(\mathcal{F}_N(g)\) を求める \(O(N \log N)\)
- \(\mathcal{F}_N(f)\mathcal{F}_N(g)\) を求める \(O(N)\)
- 逆フーリエ変換を用い\(\mathcal{F}_N^{-1}(\mathcal{F}_N(f)\mathcal{F}_N(g))=f*g\) を求める \(O(N \log N)\)
となり、合計して \(O(N \log N)\) で求められることが分かります。
数論変換(NTT)
さて、これまでDFTの性質について論じましたがこれらの変換の性質は \(\omega_N=e^{-i\frac{2\pi}{N}}\) が \(N\) 乗して初めて1になるという性質によっています。このような数を原始根(primitive root)と呼び \(\omega_N\) は \(\mathbb{C}\) における1の原始 \(N\) 乗根と呼んだりします。
他にこのような性質を持つ代数的構造として、素数 \(p\) を法とした剰余体 \(\mathbb{Z}/p\mathbb{Z}\) は原始 \(p-1\) 乗根を持つことが知られています(証明略)。\(N=2^k\) となるような \(N\) について \(p-1=\xi N\) となればある原始 \(p-1\) 乗根 \(\omega_{p-1}\) を用いて \( \omega_{p-1}^\xi \equiv \omega_N \in \mathbb{Z}/p\mathbb{Z}\)。したがって、 \( \omega_N\) が存在するような \(p\) を定めれば、FFTを実行できます。この場合すべての演算が整数型で保持されるため、浮動小数点演算による誤差を考える必要がなくなります。
以下に実装を示します
実装
/// 剰余を考慮する pow fn modpow(mut base: i64, mut exp: i64, modulus: i64) -> i64 { base %= modulus; let mut result = 1; while exp > 0 { if exp & 1 > 0 { result = (result * base) % modulus; } base = (base * base) % modulus; exp >>= 1; } return result; } /// bit reverse を求める fn bit_reverse(n: usize, width: usize) -> usize { let mut result = 0; for i in 0..width { result |= ((n >> i) & 1) << (width - 1 - i); } return result; } /// bit reverse 順になるように swap を行う fn bit_swap(a: &mut[i64], k: usize) { for i in 0..a.len() { let j = bit_reverse(i, k); if i < j { a[i] ^= a[j]; a[j] ^= a[i]; a[i] ^= a[j]; } } } /// FFTを行う fn fft(a: &mut[i64], base: i64, modulus: i64, k: usize) { bit_swap(a, k); let n = a.len(); let sqrt = modpow(base, (n/2) as i64, modulus); let mut m = n; while m > 1 { m >>= 1; let omega = modpow(base, m as i64, modulus); let mut omega_k = 1; let dist = n/m/2; for i in 0..dist { for s in (i..n).step_by(dist*2) { let f_eve = a[s] % modulus; let f_odd = omega_k*a[s+dist] % modulus; a[s] = (f_eve + f_odd) % modulus; a[s+dist] = (f_eve + sqrt*f_odd) % modulus; } omega_k *= omega; omega_k %= modulus; } } } /// 逆FFTを行う fn ifft(a: &mut[i64], base: i64, modulus: i64, k: usize) { // 逆数を用いて fft(a, modpow(base, modulus-2, modulus), modulus, k); // Nで割る let n_inv = modpow(a.len() as i64, modulus-2, modulus); for i in 0..a.len() { a[i] *= n_inv; a[i] %= modulus; } } /// 畳み込みをインプレースに求める fn convolute(a: &mut[i64], b: &mut[i64], base: i64, modulus: i64, k: usize) { fft(a, base, modulus, k); fft(b, base, modulus, k); for i in 0..a.len() { a[i] *= b[i]; a[i] %= modulus; } ifft(a, base, modulus, k); } fn main() { let k = 18; // n = 2^k となるように n を取る let n = 2i64.pow(k); // 原始 n 乗根 ω let omega = 103; // p = 5880*(2^18) + 1 let p = 1541406721; let mut f = vec![0; n as usize]; let mut g = vec![0; n as usize]; // f(x) = 1 + 2x + 3x^2 // g(x) = 5 + 3x + x^2 // f(x)g(x) = 5 + 13x + 22x^2 + 11x^3 + 3x^4 f[0] = 1; f[1] = 2; f[2] = 3; g[0] = 5; g[1] = 3; g[2] = 1; convolute(&mut f, &mut g, omega, p, k as usize); for i in 0..5 { println!("{}", f[i]); // -> 5 13 22 11 3 } }
\(p, N, \omega_N\) などの決め方は上述したものに加えて、\(p\) を出力上問題ない範囲の大きさに定め、 \(\omega_N\) の冪乗を列挙することにより1でない最も小さいものを選ぶことが出来ます。入力については十分に大きい \(N\) を取り適当に0埋めを行うと良いです。
まとめ
以上で多項式の演算を \(O(N \log N)\) で行えるようになりました。今後とも組み合わせの計算や周波数成分の分析などに生かして行こうと思います☺️。最後までお読みいただきありがとうございました。
この記事へのコメントはありません。