diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py index 1a6073f8d..240096caa 100644 --- a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -3,19 +3,13 @@ from strassen_matrix_multiplication import split_matrix class TestSplitMatrix(unittest.TestCase): - def test_4x4_matrix(self): - matrix = [ - [4, 3, 2, 4], - [2, 3, 1, 1], - [6, 5, 4, 3], - [8, 4, 1, 6] - ] + matrix = [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] expected = ( [[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], - [[4, 3], [1, 6]] + [[4, 3], [1, 6]], ) self.assertEqual(split_matrix(matrix), expected) @@ -28,31 +22,23 @@ class TestSplitMatrix(unittest.TestCase): [4, 3, 2, 4, 4, 3, 2, 4], [2, 3, 1, 1, 2, 3, 1, 1], [6, 5, 4, 3, 6, 5, 4, 3], - [8, 4, 1, 6, 8, 4, 1, 6] + [8, 4, 1, 6, 8, 4, 1, 6], ] expected = ( [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], ) self.assertEqual(split_matrix(matrix), expected) def test_invalid_odd_matrix(self): - matrix = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ] + matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with self.assertRaises(Exception): split_matrix(matrix) def test_invalid_non_square_matrix(self): - matrix = [ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12] - ] + matrix = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] with self.assertRaises(Exception): split_matrix(matrix)