mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-01-18 16:27:02 +00:00
Add radix2 FFT (#1166)
* Add radix2 FFT Created a dynamic implementation of the radix - 2 Fast Fourier Transform for fast polynomial multiplication. Reference: https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#The_radix-2_DIT_case * Rename radix2_FFT.py to radix2_fft.py * Update radix2_fft printing Improved the printing method with f.prefix and String.join() * __str__ method update * Turned the tests into doctests
This commit is contained in:
parent
ab25079e16
commit
a41a14f9d8
222
maths/radix2_fft.py
Normal file
222
maths/radix2_fft.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
"""
|
||||
Fast Polynomial Multiplication using radix-2 fast Fourier Transform.
|
||||
"""
|
||||
|
||||
import mpmath # for roots of unity
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FFT:
|
||||
"""
|
||||
Fast Polynomial Multiplication using radix-2 fast Fourier Transform.
|
||||
|
||||
Reference:
|
||||
https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#The_radix-2_DIT_case
|
||||
|
||||
For polynomials of degree m and n the algorithms has complexity
|
||||
O(n*logn + m*logm)
|
||||
|
||||
The main part of the algorithm is split in two parts:
|
||||
1) __DFT: We compute the discrete fourier transform (DFT) of A and B using a
|
||||
bottom-up dynamic approach -
|
||||
2) __multiply: Once we obtain the DFT of A*B, we can similarly
|
||||
invert it to obtain A*B
|
||||
|
||||
The class FFT takes two polynomials A and B with complex coefficients as arguments;
|
||||
The two polynomials should be represented as a sequence of coefficients starting
|
||||
from the free term. Thus, for instance x + 2*x^3 could be represented as
|
||||
[0,1,0,2] or (0,1,0,2). The constructor adds some zeros at the end so that the
|
||||
polynomials have the same length which is a power of 2 at least the length of
|
||||
their product.
|
||||
|
||||
Example:
|
||||
|
||||
Create two polynomials as sequences
|
||||
>>> A = [0, 1, 0, 2] # x+2x^3
|
||||
>>> B = (2, 3, 4, 0) # 2+3x+4x^2
|
||||
|
||||
Create an FFT object with them
|
||||
>>> x = FFT(A, B)
|
||||
|
||||
Print product
|
||||
>>> print(x.product) # 2x + 3x^2 + 8x^3 + 4x^4 + 6x^5
|
||||
[(-0+0j), (2+0j), (3+0j), (8+0j), (6+0j), (8+0j)]
|
||||
|
||||
__str__ test
|
||||
>>> print(x)
|
||||
A = 0*x^0 + 1*x^1 + 2*x^0 + 3*x^2
|
||||
B = 0*x^2 + 1*x^3 + 2*x^4
|
||||
A*B = 0*x^(-0+0j) + 1*x^(2+0j) + 2*x^(3+0j) + 3*x^(8+0j) + 4*x^(6+0j) + 5*x^(8+0j)
|
||||
"""
|
||||
|
||||
def __init__(self, polyA=[0], polyB=[0]):
|
||||
# Input as list
|
||||
self.polyA = list(polyA)[:]
|
||||
self.polyB = list(polyB)[:]
|
||||
|
||||
# Remove leading zero coefficients
|
||||
while self.polyA[-1] == 0:
|
||||
self.polyA.pop()
|
||||
self.len_A = len(self.polyA)
|
||||
|
||||
while self.polyB[-1] == 0:
|
||||
self.polyB.pop()
|
||||
self.len_B = len(self.polyB)
|
||||
|
||||
# Add 0 to make lengths equal a power of 2
|
||||
self.C_max_length = int(
|
||||
2
|
||||
** np.ceil(
|
||||
np.log2(
|
||||
len(self.polyA) + len(self.polyB) - 1
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
while len(self.polyA) < self.C_max_length:
|
||||
self.polyA.append(0)
|
||||
while len(self.polyB) < self.C_max_length:
|
||||
self.polyB.append(0)
|
||||
# A complex root used for the fourier transform
|
||||
self.root = complex(
|
||||
mpmath.root(x=1, n=self.C_max_length, k=1)
|
||||
)
|
||||
|
||||
# The product
|
||||
self.product = self.__multiply()
|
||||
|
||||
# Discrete fourier transform of A and B
|
||||
def __DFT(self, which):
|
||||
if which == "A":
|
||||
dft = [[x] for x in self.polyA]
|
||||
else:
|
||||
dft = [[x] for x in self.polyB]
|
||||
# Corner case
|
||||
if len(dft) <= 1:
|
||||
return dft[0]
|
||||
#
|
||||
next_ncol = self.C_max_length // 2
|
||||
while next_ncol > 0:
|
||||
new_dft = [[] for i in range(next_ncol)]
|
||||
root = self.root ** next_ncol
|
||||
|
||||
# First half of next step
|
||||
current_root = 1
|
||||
for j in range(
|
||||
self.C_max_length // (next_ncol * 2)
|
||||
):
|
||||
for i in range(next_ncol):
|
||||
new_dft[i].append(
|
||||
dft[i][j]
|
||||
+ current_root
|
||||
* dft[i + next_ncol][j]
|
||||
)
|
||||
current_root *= root
|
||||
# Second half of next step
|
||||
current_root = 1
|
||||
for j in range(
|
||||
self.C_max_length // (next_ncol * 2)
|
||||
):
|
||||
for i in range(next_ncol):
|
||||
new_dft[i].append(
|
||||
dft[i][j]
|
||||
- current_root
|
||||
* dft[i + next_ncol][j]
|
||||
)
|
||||
current_root *= root
|
||||
# Update
|
||||
dft = new_dft
|
||||
next_ncol = next_ncol // 2
|
||||
return dft[0]
|
||||
|
||||
# multiply the DFTs of A and B and find A*B
|
||||
def __multiply(self):
|
||||
dftA = self.__DFT("A")
|
||||
dftB = self.__DFT("B")
|
||||
inverseC = [
|
||||
[
|
||||
dftA[i] * dftB[i]
|
||||
for i in range(self.C_max_length)
|
||||
]
|
||||
]
|
||||
del dftA
|
||||
del dftB
|
||||
|
||||
# Corner Case
|
||||
if len(inverseC[0]) <= 1:
|
||||
return inverseC[0]
|
||||
# Inverse DFT
|
||||
next_ncol = 2
|
||||
while next_ncol <= self.C_max_length:
|
||||
new_inverseC = [[] for i in range(next_ncol)]
|
||||
root = self.root ** (next_ncol // 2)
|
||||
current_root = 1
|
||||
# First half of next step
|
||||
for j in range(self.C_max_length // next_ncol):
|
||||
for i in range(next_ncol // 2):
|
||||
# Even positions
|
||||
new_inverseC[i].append(
|
||||
(
|
||||
inverseC[i][j]
|
||||
+ inverseC[i][
|
||||
j
|
||||
+ self.C_max_length
|
||||
// next_ncol
|
||||
]
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
# Odd positions
|
||||
new_inverseC[i + next_ncol // 2].append(
|
||||
(
|
||||
inverseC[i][j]
|
||||
- inverseC[i][
|
||||
j
|
||||
+ self.C_max_length
|
||||
// next_ncol
|
||||
]
|
||||
)
|
||||
/ (2 * current_root)
|
||||
)
|
||||
current_root *= root
|
||||
# Update
|
||||
inverseC = new_inverseC
|
||||
next_ncol *= 2
|
||||
# Unpack
|
||||
inverseC = [
|
||||
round(x[0].real, 8) + round(x[0].imag, 8) * 1j
|
||||
for x in inverseC
|
||||
]
|
||||
|
||||
# Remove leading 0's
|
||||
while inverseC[-1] == 0:
|
||||
inverseC.pop()
|
||||
return inverseC
|
||||
|
||||
# Overwrite __str__ for print(); Shows A, B and A*B
|
||||
def __str__(self):
|
||||
A = "A = " + " + ".join(
|
||||
f"{coef}*x^{i}"
|
||||
for coef, i in enumerate(
|
||||
self.polyA[: self.len_A]
|
||||
)
|
||||
)
|
||||
B = "B = " + " + ".join(
|
||||
f"{coef}*x^{i}"
|
||||
for coef, i in enumerate(
|
||||
self.polyB[: self.len_B]
|
||||
)
|
||||
)
|
||||
C = "A*B = " + " + ".join(
|
||||
f"{coef}*x^{i}"
|
||||
for coef, i in enumerate(self.product)
|
||||
)
|
||||
|
||||
return "\n".join((A, B, C))
|
||||
|
||||
|
||||
# Unit tests
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
|
||||
doctest.testmod()
|
Loading…
Reference in New Issue
Block a user