Python/maths/radix2_fft.py
Christian Clauss 24d3cf8244
The black formatter is no longer beta (#5960)
* The black formatter is no longer beta

* pre-commit autoupdate

* pre-commit autoupdate

* Remove project_euler/problem_145 which is killing our CI tests

* updating DIRECTORY.md

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
2022-01-30 20:29:54 +01:00

181 lines
6.0 KiB
Python

"""
Fast Polynomial Multiplication using radix-2 fast Fourier Transform.
"""
import mpmath # for roots of unity
import numpy as np
class FFT:
    """
    Fast Polynomial Multiplication using radix-2 fast Fourier Transform.

    Reference:
    https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#The_radix-2_DIT_case

    For polynomials of degree m and n the algorithm has complexity
    O(n*logn + m*logm)

    The main part of the algorithm is split in two parts:
        1) __DFT: We compute the discrete fourier transform (DFT) of A and B using a
        bottom-up dynamic approach.
        2) __multiply: Once we obtain the DFT of A*B (the pointwise product of the
        two DFTs), we can similarly invert it to obtain A*B.

    The class FFT takes two polynomials A and B with complex coefficients as arguments;
    The two polynomials should be represented as a sequence of coefficients starting
    from the free term. Thus, for instance x + 2*x^3 could be represented as
    [0,1,0,2] or (0,1,0,2). The constructor adds some zeros at the end so that the
    polynomials have the same length which is a power of 2 at least the length of
    their product. Omitted or empty polynomials are treated as the zero polynomial.

    The coefficients of the product are complex numbers rounded to 8 decimal places.

    Example:

    Create two polynomials as sequences
    >>> A = [0, 1, 0, 2]  # x+2x^3
    >>> B = (2, 3, 4, 0)  # 2+3x+4x^2

    Create an FFT object with them
    >>> x = FFT(A, B)

    Print product
    >>> print(x.product)  # 2x + 3x^2 + 8x^3 + 6x^4 + 8x^5
    [0j, (2+0j), (3+0j), (8+0j), (6+0j), (8+0j)]

    __str__ test
    >>> print(x)
    A = 0*x^0 + 1*x^1 + 0*x^2 + 2*x^3
    B = 2*x^0 + 3*x^1 + 4*x^2
    A*B = 0j*x^0 + (2+0j)*x^1 + (3+0j)*x^2 + (8+0j)*x^3 + (6+0j)*x^4 + (8+0j)*x^5

    The zero polynomial is accepted
    >>> FFT().product
    [0j]
    """

    def __init__(self, polyA=None, polyB=None):
        # Input as (copied) lists, without trailing zero coefficients
        self.polyA = self._trim(polyA)
        self.len_A = len(self.polyA)
        self.polyB = self._trim(polyB)
        self.len_B = len(self.polyB)

        # Pad with zeros up to a common power-of-2 length that can hold all
        # len_A + len_B - 1 coefficients of the product
        self.C_max_length = int(
            2 ** np.ceil(np.log2(len(self.polyA) + len(self.polyB) - 1))
        )
        self.polyA.extend([0] * (self.C_max_length - len(self.polyA)))
        self.polyB.extend([0] * (self.C_max_length - len(self.polyB)))

        # Complex root of unity used for the fourier transform:
        # exp(2*pi*i / C_max_length), computed with numpy
        self.root = complex(np.exp(2j * np.pi / self.C_max_length))

        # The product
        self.product = self.__multiply()

    @staticmethod
    def _trim(poly):
        """
        Return the coefficients of ``poly`` as a new list with trailing
        (highest-degree) zeros removed.

        At least one coefficient is always kept, so ``None``, empty sequences and
        all-zero sequences all become ``[0]`` (the zero polynomial).
        """
        # Avoid `poly or [0]`: the truth value of a numpy array is ambiguous
        coefficients = [] if poly is None else list(poly)
        while len(coefficients) > 1 and coefficients[-1] == 0:
            coefficients.pop()
        return coefficients or [0]

    def __DFT(self, which):
        """
        Return the discrete fourier transform of polyA (which == "A") or of polyB
        (any other value) as a list of C_max_length values. Entry k is the
        polynomial evaluated at self.root**k.

        Bottom-up: start with one single-element column per coefficient. At each
        level, halve the number of columns and double their length by combining
        pairs of columns with butterfly sums/differences.
        """
        dft = [[x] for x in (self.polyA if which == "A" else self.polyB)]
        # When C_max_length == 1 the loop is skipped and dft[0] is the input itself
        next_ncol = self.C_max_length // 2
        while next_ncol > 0:
            new_dft = [[] for _ in range(next_ncol)]
            root = self.root**next_ncol
            # First half of next step
            current_root = 1
            for j in range(self.C_max_length // (next_ncol * 2)):
                for i in range(next_ncol):
                    new_dft[i].append(dft[i][j] + current_root * dft[i + next_ncol][j])
                current_root *= root
            # Second half of next step
            current_root = 1
            for j in range(self.C_max_length // (next_ncol * 2)):
                for i in range(next_ncol):
                    new_dft[i].append(dft[i][j] - current_root * dft[i + next_ncol][j])
                current_root *= root
            # Update
            dft = new_dft
            next_ncol //= 2
        return dft[0]

    def __multiply(self):
        """
        Multiply the DFTs of A and B pointwise, invert the result and return the
        coefficients of A*B starting from the free term.

        Each coefficient is a complex number whose real and imaginary parts are
        rounded to 8 decimal places. Trailing zero coefficients are removed, but
        at least one coefficient is always returned.
        """
        # The pointwise product of the two DFTs is the DFT of A*B
        inverseC = [[a * b for a, b in zip(self.__DFT("A"), self.__DFT("B"))]]

        # Inverse DFT: undo the butterflies of __DFT one level at a time.
        # When C_max_length == 1 the loop is skipped.
        next_ncol = 2
        while next_ncol <= self.C_max_length:
            new_inverseC = [[] for _ in range(next_ncol)]
            root = self.root ** (next_ncol // 2)
            current_root = 1
            offset = self.C_max_length // next_ncol
            for j in range(offset):
                for i in range(next_ncol // 2):
                    upper = inverseC[i][j]
                    lower = inverseC[i][j + offset]
                    # Even positions
                    new_inverseC[i].append((upper + lower) / 2)
                    # Odd positions
                    new_inverseC[i + next_ncol // 2].append(
                        (upper - lower) / (2 * current_root)
                    )
                current_root *= root
            # Update
            inverseC = new_inverseC
            next_ncol *= 2

        # Unpack and round away floating-point noise; "+ 0.0" turns a rounded
        # -0.0 into 0.0 so results do not print as "-0"
        product = [
            complex(round(x[0].real, 8) + 0.0, round(x[0].imag, 8) + 0.0)
            for x in inverseC
        ]
        # Remove trailing zero coefficients, keeping at least one
        while len(product) > 1 and product[-1] == 0:
            product.pop()
        return product

    def __str__(self):
        """
        Show A, B and A*B (one per line) as sums of ``coef*x^power`` terms.
        Used by print().
        """
        a = "A = " + " + ".join(
            f"{coef}*x^{i}" for i, coef in enumerate(self.polyA[: self.len_A])
        )
        b = "B = " + " + ".join(
            f"{coef}*x^{i}" for i, coef in enumerate(self.polyB[: self.len_B])
        )
        c = "A*B = " + " + ".join(
            f"{coef}*x^{i}" for i, coef in enumerate(self.product)
        )
        return "\n".join((a, b, c))
# Script entry point: execute the doctest examples embedded in the docstrings
if __name__ == "__main__":
    from doctest import testmod

    testmod()