mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-02-23 17:38:39 +00:00
optimize split_matrix function by removing duplicate code to the extract_submatrix function, add tests
This commit is contained in:
parent
6c92c5a539
commit
792ee57123
@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:
|
|||||||
if len(a) % 2 != 0 or len(a[0]) % 2 != 0:
|
if len(a) % 2 != 0 or len(a[0]) % 2 != 0:
|
||||||
raise Exception("Odd matrices are not supported!")
|
raise Exception("Odd matrices are not supported!")
|
||||||
|
|
||||||
matrix_length = len(a)
|
def extract_submatrix(rows, cols):
|
||||||
mid = matrix_length // 2
|
return [[a[i][j] for j in cols] for i in rows]
|
||||||
|
|
||||||
top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)]
|
mid = len(a) // 2
|
||||||
bot_right = [
|
|
||||||
[a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length)
|
|
||||||
]
|
|
||||||
|
|
||||||
top_left = [[a[i][j] for j in range(mid)] for i in range(mid)]
|
rows_top, rows_bot = range(mid), range(mid, len(a))
|
||||||
bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)]
|
cols_left, cols_right = range(mid), range(mid, len(a))
|
||||||
|
|
||||||
return top_left, top_right, bot_left, bot_right
|
return (
|
||||||
|
extract_submatrix(rows_top, cols_left), # Top-left
|
||||||
|
extract_submatrix(rows_top, cols_right), # Top-right
|
||||||
|
extract_submatrix(rows_bot, cols_left), # Bottom-left
|
||||||
|
extract_submatrix(rows_bot, cols_right), # Bottom-right
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def matrix_dimensions(matrix: list) -> tuple[int, int]:
|
def matrix_dimensions(matrix: list) -> tuple[int, int]:
|
||||||
|
0
divide_and_conquer/tests/__init__.py
Normal file
0
divide_and_conquer/tests/__init__.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import unittest
|
||||||
|
from strassen_matrix_multiplication import split_matrix
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitMatrix(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_4x4_matrix(self):
|
||||||
|
matrix = [
|
||||||
|
[4, 3, 2, 4],
|
||||||
|
[2, 3, 1, 1],
|
||||||
|
[6, 5, 4, 3],
|
||||||
|
[8, 4, 1, 6]
|
||||||
|
]
|
||||||
|
expected = (
|
||||||
|
[[4, 3], [2, 3]],
|
||||||
|
[[2, 4], [1, 1]],
|
||||||
|
[[6, 5], [8, 4]],
|
||||||
|
[[4, 3], [1, 6]]
|
||||||
|
)
|
||||||
|
self.assertEqual(split_matrix(matrix), expected)
|
||||||
|
|
||||||
|
def test_8x8_matrix(self):
|
||||||
|
matrix = [
|
||||||
|
[4, 3, 2, 4, 4, 3, 2, 4],
|
||||||
|
[2, 3, 1, 1, 2, 3, 1, 1],
|
||||||
|
[6, 5, 4, 3, 6, 5, 4, 3],
|
||||||
|
[8, 4, 1, 6, 8, 4, 1, 6],
|
||||||
|
[4, 3, 2, 4, 4, 3, 2, 4],
|
||||||
|
[2, 3, 1, 1, 2, 3, 1, 1],
|
||||||
|
[6, 5, 4, 3, 6, 5, 4, 3],
|
||||||
|
[8, 4, 1, 6, 8, 4, 1, 6]
|
||||||
|
]
|
||||||
|
expected = (
|
||||||
|
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
|
||||||
|
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
|
||||||
|
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
|
||||||
|
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]
|
||||||
|
)
|
||||||
|
self.assertEqual(split_matrix(matrix), expected)
|
||||||
|
|
||||||
|
def test_invalid_odd_matrix(self):
|
||||||
|
matrix = [
|
||||||
|
[1, 2, 3],
|
||||||
|
[4, 5, 6],
|
||||||
|
[7, 8, 9]
|
||||||
|
]
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
split_matrix(matrix)
|
||||||
|
|
||||||
|
def test_invalid_non_square_matrix(self):
|
||||||
|
matrix = [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[9, 10, 11, 12]
|
||||||
|
]
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
split_matrix(matrix)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user