Python/matrix/matrix_multiplication_recursion.py
pre-commit-ci[bot] bc8df6de31
[pre-commit.ci] pre-commit autoupdate (#11322)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2)
- [github.com/pre-commit/mirrors-mypy: v1.8.0 → v1.9.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.8.0...v1.9.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-03-13 07:52:41 +01:00

182 lines
5.0 KiB
Python

# @Author : ojas-wani
# @File : matrix_multiplication_recursion.py
# @Date : 10/06/2023
"""
Perform matrix multiplication using a recursive algorithm.
https://en.wikipedia.org/wiki/Matrix_multiplication
"""
# type Matrix = list[list[int]]  # psf/black currently fails on this line
# A matrix is represented as a list of rows, each row being a list of ints.
Matrix = list[list[int]]

# 2x2 square fixtures.
matrix_1_to_4 = [[1, 2], [3, 4]]
matrix_5_to_8 = [[5, 6], [7, 8]]

# Ragged (non-rectangular) fixtures; the doctests use them to trigger the
# "Invalid matrix dimensions" error.
matrix_5_to_9_high = [[5, 6], [7, 8], [9]]
matrix_5_to_9_wide = [[5, 6], [7, 8, 9]]

# 4x4 square fixtures: the numbers 1..16 in row-major order, and an arbitrary mix.
matrix_count_up = [list(range(start, start + 4)) for start in range(1, 17, 4)]
matrix_unordered = [
    [5, 8, 1, 2],
    [6, 7, 3, 0],
    [4, 5, 9, 1],
    [2, 6, 10, 14],
]

# Every fixture, in the order the demo at the bottom of the file walks them.
matrices = (
    matrix_1_to_4,
    matrix_5_to_8,
    matrix_5_to_9_high,
    matrix_5_to_9_wide,
    matrix_count_up,
    matrix_unordered,
)
def is_square(matrix: Matrix) -> bool:
    """
    Report whether every row of ``matrix`` is as long as the matrix has rows.

    An empty matrix has no rows to violate the condition and is therefore
    reported as square.

    :param matrix: Matrix given as a list of rows.
    :return: True if each row's length equals the number of rows, else False.

    >>> is_square([])
    True
    >>> is_square(matrix_1_to_4)
    True
    >>> is_square(matrix_5_to_9_high)
    False
    """
    size = len(matrix)
    return all(len(row) == size for row in matrix)
def matrix_multiply(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Multiply two matrices iteratively (rows of ``matrix_a`` by columns of
    ``matrix_b``).

    No dimension checks are made: ``zip`` stops at the shorter operand, so
    mismatched shapes are silently truncated instead of being rejected.

    :param matrix_a: Left operand, as a list of rows.
    :param matrix_b: Right operand, as a list of rows.
    :return: A new matrix holding matrix_a * matrix_b.

    >>> matrix_multiply(matrix_1_to_4, matrix_5_to_8)
    [[19, 22], [43, 50]]
    """
    # Transpose matrix_b once so its columns are not rebuilt for every row.
    columns = list(zip(*matrix_b))
    return [
        [sum(a * b for a, b in zip(row, column)) for column in columns]
        for row in matrix_a
    ]
def matrix_multiply_recursive(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Multiply two square matrices of equal size using recursion instead of loops.

    The work is split across three small recursive helpers (over rows, over
    columns, and over the summation index), so the call stack grows linearly
    with the matrix size (roughly three frames per row) rather than with the
    number of scalar multiplications.

    If either argument is empty, an empty list is returned without any
    dimension check.

    :param matrix_a: A square Matrix.
    :param matrix_b: Another square Matrix with the same dimensions as matrix_a.
    :return: Result of matrix_a * matrix_b.
    :raises ValueError: If the matrices cannot be multiplied.

    >>> matrix_multiply_recursive([], [])
    []
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_8)
    [[19, 22], [43, 50]]
    >>> matrix_multiply_recursive(matrix_count_up, matrix_unordered)
    [[37, 61, 74, 61], [105, 165, 166, 129], [173, 269, 258, 197], [241, 373, 350, 265]]
    >>> identity = [[int(row == col) for col in range(12)] for row in range(12)]
    >>> matrix_multiply_recursive(identity, identity) == identity
    True
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_wide)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_high)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_count_up)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    """
    if not matrix_a or not matrix_b:
        return []
    if not (
        len(matrix_a) == len(matrix_b) and is_square(matrix_a) and is_square(matrix_b)
    ):
        raise ValueError("Invalid matrix dimensions")

    # Both operands are size x size from here on, and so is the product.
    size = len(matrix_a)
    result = [[0] * size for _ in range(size)]

    def accumulate(row: int, col: int, k: int) -> None:
        """Add matrix_a[row][k] * matrix_b[k][col] for k, k + 1, ... to the cell."""
        if k < size:
            result[row][col] += matrix_a[row][k] * matrix_b[k][col]
            accumulate(row, col, k + 1)

    def fill_cells(row: int, col: int) -> None:
        """Compute the cells of ``row`` from column ``col`` to the last column."""
        if col < size:
            # The summation recursion has fully unwound before the next column
            # starts, so its frames never pile up across cells.
            accumulate(row, col, 0)
            fill_cells(row, col + 1)

    def fill_rows(row: int) -> None:
        """Compute every row of the result from ``row`` to the last row."""
        if row < size:
            fill_cells(row, 0)
            fill_rows(row + 1)

    fill_rows(0)
    return result
if __name__ == "__main__":
    from doctest import testmod

    # Run the doctests first; the demo below is only shown when they all pass.
    failure_count, test_count = testmod()
    if not failure_count:
        # Multiply each fixture by the one that follows it in ``matrices``.
        for matrix_a, matrix_b in zip(matrices, matrices[1:]):
            print("Multiplying:")
            for row in matrix_a:
                print(row)
            print("By:")
            for row in matrix_b:
                print(row)
            print("Result:")
            try:
                result = matrix_multiply_recursive(matrix_a, matrix_b)
            except ValueError as e:
                # Incompatible fixtures are part of the demo; show the error.
                print(f"{e!r}")
            else:
                for row in result:
                    print(row)
                # Cross-check the recursive result against the iterative one.
                assert result == matrix_multiply(matrix_a, matrix_b)
            print()

    print("Benchmark:")
    from functools import partial
    from timeit import timeit

    # Time both implementations on the same pair of 4x4 fixtures.
    mytimeit = partial(timeit, globals=globals(), number=100_000)
    for func in ("matrix_multiply", "matrix_multiply_recursive"):
        statement = f"{func}(matrix_count_up, matrix_unordered)"
        print(f"{func:>25}(): {mytimeit(statement)}")