From 792ee57123db193c6ba4de182a13df25da6f3afe Mon Sep 17 00:00:00 2001 From: ivanz-thinkpad Date: Fri, 31 Jan 2025 23:05:22 +0300 Subject: [PATCH] optimize split_matrix function by removing duplicate code to the extract_submatrix function, add tests --- .../strassen_matrix_multiplication.py | 20 +++--- divide_and_conquer/tests/__init__.py | 0 .../test_strassen_matrix_multiplication.py | 61 +++++++++++++++++++ 3 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 divide_and_conquer/tests/__init__.py create mode 100644 divide_and_conquer/tests/test_strassen_matrix_multiplication.py diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d..78c2e56fa 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]: if len(a) % 2 != 0 or len(a[0]) % 2 != 0: raise Exception("Odd matrices are not supported!") - matrix_length = len(a) - mid = matrix_length // 2 + def extract_submatrix(rows, cols): + return [[a[i][j] for j in cols] for i in rows] - top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)] - bot_right = [ - [a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length) - ] + mid = len(a) // 2 - top_left = [[a[i][j] for j in range(mid)] for i in range(mid)] - bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)] + rows_top, rows_bot = range(mid), range(mid, len(a)) + cols_left, cols_right = range(mid), range(mid, len(a)) - return top_left, top_right, bot_left, bot_right + return ( + extract_submatrix(rows_top, cols_left), # Top-left + extract_submatrix(rows_top, cols_right), # Top-right + extract_submatrix(rows_bot, cols_left), # Bottom-left + extract_submatrix(rows_bot, cols_right), # Bottom-right + ) def matrix_dimensions(matrix: list) -> tuple[int, int]: diff --git a/divide_and_conquer/tests/__init__.py b/divide_and_conquer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py new file mode 100644 index 000000000..1a6073f8d --- /dev/null +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -0,0 +1,61 @@ +import unittest +from strassen_matrix_multiplication import split_matrix + + +class TestSplitMatrix(unittest.TestCase): + + def test_4x4_matrix(self): + matrix = [ + [4, 3, 2, 4], + [2, 3, 1, 1], + [6, 5, 4, 3], + [8, 4, 1, 6] + ] + expected = ( + [[4, 3], [2, 3]], + [[2, 4], [1, 1]], + [[6, 5], [8, 4]], + [[4, 3], [1, 6]] + ) + self.assertEqual(split_matrix(matrix), expected) + + def test_8x8_matrix(self): + matrix = [ + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6], + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6] + ] + expected = ( + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + ) + self.assertEqual(split_matrix(matrix), expected) + + def test_invalid_odd_matrix(self): + matrix = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ] + with self.assertRaises(Exception): + split_matrix(matrix) + + def test_invalid_non_square_matrix(self): + matrix = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12] + ] + with self.assertRaises(Exception): + split_matrix(matrix) + + +if __name__ == "__main__": + unittest.main()