diff --git a/maths/radix2_fft.py b/maths/radix2_fft.py new file mode 100644 index 000000000..c7ffe9652 --- /dev/null +++ b/maths/radix2_fft.py @@ -0,0 +1,222 @@ +""" +Fast Polynomial Multiplication using radix-2 fast Fourier Transform. +""" + +import mpmath # for roots of unity +import numpy as np + + +class FFT: + """ + Fast Polynomial Multiplication using radix-2 fast Fourier Transform. + + Reference: + https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#The_radix-2_DIT_case + + For polynomials of degree m and n the algorithms has complexity + O(n*logn + m*logm) + + The main part of the algorithm is split in two parts: + 1) __DFT: We compute the discrete fourier transform (DFT) of A and B using a + bottom-up dynamic approach - + 2) __multiply: Once we obtain the DFT of A*B, we can similarly + invert it to obtain A*B + + The class FFT takes two polynomials A and B with complex coefficients as arguments; + The two polynomials should be represented as a sequence of coefficients starting + from the free term. Thus, for instance x + 2*x^3 could be represented as + [0,1,0,2] or (0,1,0,2). The constructor adds some zeros at the end so that the + polynomials have the same length which is a power of 2 at least the length of + their product. + + Example: + + Create two polynomials as sequences + >>> A = [0, 1, 0, 2] # x+2x^3 + >>> B = (2, 3, 4, 0) # 2+3x+4x^2 + + Create an FFT object with them + >>> x = FFT(A, B) + + Print product + >>> print(x.product) # 2x + 3x^2 + 8x^3 + 4x^4 + 6x^5 + [(-0+0j), (2+0j), (3+0j), (8+0j), (6+0j), (8+0j)] + + __str__ test + >>> print(x) + A = 0*x^0 + 1*x^1 + 2*x^0 + 3*x^2 + B = 0*x^2 + 1*x^3 + 2*x^4 + A*B = 0*x^(-0+0j) + 1*x^(2+0j) + 2*x^(3+0j) + 3*x^(8+0j) + 4*x^(6+0j) + 5*x^(8+0j) + """ + + def __init__(self, polyA=[0], polyB=[0]): + # Input as list + self.polyA = list(polyA)[:] + self.polyB = list(polyB)[:] + + # Remove leading zero coefficients + while self.polyA[-1] == 0: + self.polyA.pop() + self.len_A = len(self.polyA) + + while self.polyB[-1] == 0: + self.polyB.pop() + self.len_B = len(self.polyB) + + # Add 0 to make lengths equal a power of 2 + self.C_max_length = int( + 2 + ** np.ceil( + np.log2( + len(self.polyA) + len(self.polyB) - 1 + ) + ) + ) + + while len(self.polyA) < self.C_max_length: + self.polyA.append(0) + while len(self.polyB) < self.C_max_length: + self.polyB.append(0) + # A complex root used for the fourier transform + self.root = complex( + mpmath.root(x=1, n=self.C_max_length, k=1) + ) + + # The product + self.product = self.__multiply() + + # Discrete fourier transform of A and B + def __DFT(self, which): + if which == "A": + dft = [[x] for x in self.polyA] + else: + dft = [[x] for x in self.polyB] + # Corner case + if len(dft) <= 1: + return dft[0] + # + next_ncol = self.C_max_length // 2 + while next_ncol > 0: + new_dft = [[] for i in range(next_ncol)] + root = self.root ** next_ncol + + # First half of next step + current_root = 1 + for j in range( + self.C_max_length // (next_ncol * 2) + ): + for i in range(next_ncol): + new_dft[i].append( + dft[i][j] + + current_root + * dft[i + next_ncol][j] + ) + current_root *= root + # Second half of next step + current_root = 1 + for j in range( + self.C_max_length // (next_ncol * 2) + ): + for i in range(next_ncol): + new_dft[i].append( + dft[i][j] + - current_root + * dft[i + next_ncol][j] + ) + current_root *= root + # Update + dft = new_dft + next_ncol = next_ncol // 2 + return dft[0] + + # multiply the DFTs of A and B and find A*B + def __multiply(self): + dftA = self.__DFT("A") + dftB = self.__DFT("B") + inverseC = [ + [ + dftA[i] * dftB[i] + for i in range(self.C_max_length) + ] + ] + del dftA + del dftB + + # Corner Case + if len(inverseC[0]) <= 1: + return inverseC[0] + # Inverse DFT + next_ncol = 2 + while next_ncol <= self.C_max_length: + new_inverseC = [[] for i in range(next_ncol)] + root = self.root ** (next_ncol // 2) + current_root = 1 + # First half of next step + for j in range(self.C_max_length // next_ncol): + for i in range(next_ncol // 2): + # Even positions + new_inverseC[i].append( + ( + inverseC[i][j] + + inverseC[i][ + j + + self.C_max_length + // next_ncol + ] + ) + / 2 + ) + # Odd positions + new_inverseC[i + next_ncol // 2].append( + ( + inverseC[i][j] + - inverseC[i][ + j + + self.C_max_length + // next_ncol + ] + ) + / (2 * current_root) + ) + current_root *= root + # Update + inverseC = new_inverseC + next_ncol *= 2 + # Unpack + inverseC = [ + round(x[0].real, 8) + round(x[0].imag, 8) * 1j + for x in inverseC + ] + + # Remove leading 0's + while inverseC[-1] == 0: + inverseC.pop() + return inverseC + + # Overwrite __str__ for print(); Shows A, B and A*B + def __str__(self): + A = "A = " + " + ".join( + f"{coef}*x^{i}" + for coef, i in enumerate( + self.polyA[: self.len_A] + ) + ) + B = "B = " + " + ".join( + f"{coef}*x^{i}" + for coef, i in enumerate( + self.polyB[: self.len_B] + ) + ) + C = "A*B = " + " + ".join( + f"{coef}*x^{i}" + for coef, i in enumerate(self.product) + ) + + return "\n".join((A, B, C)) + + +# Unit tests +if __name__ == "__main__": + import doctest + + doctest.testmod()