mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-01-18 16:27:02 +00:00
bc8df6de31
* [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2) - [github.com/pre-commit/mirrors-mypy: v1.8.0 → v1.9.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.8.0...v1.9.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
182 lines
5.0 KiB
Python
182 lines
5.0 KiB
Python
# @Author : ojas-wani
|
|
# @File : matrix_multiplication_recursion.py
|
|
# @Date : 10/06/2023
|
|
|
|
|
|
"""
|
|
Perform matrix multiplication using a recursive algorithm.
|
|
https://en.wikipedia.org/wiki/Matrix_multiplication
|
|
"""
|
|
|
|
# type Matrix = list[list[int]] # psf/black currenttly fails on this line
|
|
Matrix = list[list[int]]
|
|
|
|
matrix_1_to_4 = [
|
|
[1, 2],
|
|
[3, 4],
|
|
]
|
|
|
|
matrix_5_to_8 = [
|
|
[5, 6],
|
|
[7, 8],
|
|
]
|
|
|
|
matrix_5_to_9_high = [
|
|
[5, 6],
|
|
[7, 8],
|
|
[9],
|
|
]
|
|
|
|
matrix_5_to_9_wide = [
|
|
[5, 6],
|
|
[7, 8, 9],
|
|
]
|
|
|
|
matrix_count_up = [
|
|
[1, 2, 3, 4],
|
|
[5, 6, 7, 8],
|
|
[9, 10, 11, 12],
|
|
[13, 14, 15, 16],
|
|
]
|
|
|
|
matrix_unordered = [
|
|
[5, 8, 1, 2],
|
|
[6, 7, 3, 0],
|
|
[4, 5, 9, 1],
|
|
[2, 6, 10, 14],
|
|
]
|
|
matrices = (
|
|
matrix_1_to_4,
|
|
matrix_5_to_8,
|
|
matrix_5_to_9_high,
|
|
matrix_5_to_9_wide,
|
|
matrix_count_up,
|
|
matrix_unordered,
|
|
)
|
|
|
|
|
|
def is_square(matrix: Matrix) -> bool:
|
|
"""
|
|
>>> is_square([])
|
|
True
|
|
>>> is_square(matrix_1_to_4)
|
|
True
|
|
>>> is_square(matrix_5_to_9_high)
|
|
False
|
|
"""
|
|
len_matrix = len(matrix)
|
|
return all(len(row) == len_matrix for row in matrix)
|
|
|
|
|
|
def matrix_multiply(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
|
|
"""
|
|
>>> matrix_multiply(matrix_1_to_4, matrix_5_to_8)
|
|
[[19, 22], [43, 50]]
|
|
"""
|
|
return [
|
|
[sum(a * b for a, b in zip(row, col)) for col in zip(*matrix_b)]
|
|
for row in matrix_a
|
|
]
|
|
|
|
|
|
def matrix_multiply_recursive(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
|
|
"""
|
|
:param matrix_a: A square Matrix.
|
|
:param matrix_b: Another square Matrix with the same dimensions as matrix_a.
|
|
:return: Result of matrix_a * matrix_b.
|
|
:raises ValueError: If the matrices cannot be multiplied.
|
|
|
|
>>> matrix_multiply_recursive([], [])
|
|
[]
|
|
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_8)
|
|
[[19, 22], [43, 50]]
|
|
>>> matrix_multiply_recursive(matrix_count_up, matrix_unordered)
|
|
[[37, 61, 74, 61], [105, 165, 166, 129], [173, 269, 258, 197], [241, 373, 350, 265]]
|
|
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_wide)
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: Invalid matrix dimensions
|
|
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_high)
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: Invalid matrix dimensions
|
|
>>> matrix_multiply_recursive(matrix_1_to_4, matrix_count_up)
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: Invalid matrix dimensions
|
|
"""
|
|
if not matrix_a or not matrix_b:
|
|
return []
|
|
if not all(
|
|
(len(matrix_a) == len(matrix_b), is_square(matrix_a), is_square(matrix_b))
|
|
):
|
|
raise ValueError("Invalid matrix dimensions")
|
|
|
|
# Initialize the result matrix with zeros
|
|
result = [[0] * len(matrix_b[0]) for _ in range(len(matrix_a))]
|
|
|
|
# Recursive multiplication of matrices
|
|
def multiply(
|
|
i_loop: int,
|
|
j_loop: int,
|
|
k_loop: int,
|
|
matrix_a: Matrix,
|
|
matrix_b: Matrix,
|
|
result: Matrix,
|
|
) -> None:
|
|
"""
|
|
:param matrix_a: A square Matrix.
|
|
:param matrix_b: Another square Matrix with the same dimensions as matrix_a.
|
|
:param result: Result matrix
|
|
:param i: Index used for iteration during multiplication.
|
|
:param j: Index used for iteration during multiplication.
|
|
:param k: Index used for iteration during multiplication.
|
|
>>> 0 > 1 # Doctests in inner functions are never run
|
|
True
|
|
"""
|
|
if i_loop >= len(matrix_a):
|
|
return
|
|
if j_loop >= len(matrix_b[0]):
|
|
return multiply(i_loop + 1, 0, 0, matrix_a, matrix_b, result)
|
|
if k_loop >= len(matrix_b):
|
|
return multiply(i_loop, j_loop + 1, 0, matrix_a, matrix_b, result)
|
|
result[i_loop][j_loop] += matrix_a[i_loop][k_loop] * matrix_b[k_loop][j_loop]
|
|
return multiply(i_loop, j_loop, k_loop + 1, matrix_a, matrix_b, result)
|
|
|
|
# Perform the recursive matrix multiplication
|
|
multiply(0, 0, 0, matrix_a, matrix_b, result)
|
|
return result
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from doctest import testmod
|
|
|
|
failure_count, test_count = testmod()
|
|
if not failure_count:
|
|
matrix_a = matrices[0]
|
|
for matrix_b in matrices[1:]:
|
|
print("Multiplying:")
|
|
for row in matrix_a:
|
|
print(row)
|
|
print("By:")
|
|
for row in matrix_b:
|
|
print(row)
|
|
print("Result:")
|
|
try:
|
|
result = matrix_multiply_recursive(matrix_a, matrix_b)
|
|
for row in result:
|
|
print(row)
|
|
assert result == matrix_multiply(matrix_a, matrix_b)
|
|
except ValueError as e:
|
|
print(f"{e!r}")
|
|
print()
|
|
matrix_a = matrix_b
|
|
|
|
print("Benchmark:")
|
|
from functools import partial
|
|
from timeit import timeit
|
|
|
|
mytimeit = partial(timeit, globals=globals(), number=100_000)
|
|
for func in ("matrix_multiply", "matrix_multiply_recursive"):
|
|
print(f"{func:>25}(): {mytimeit(f'{func}(matrix_count_up, matrix_unordered)')}")
|