diff --git a/maths/cholesky_decomposition.py b/maths/cholesky_decomposition.py index 363436b40..a17a5c2a7 100644 --- a/maths/cholesky_decomposition.py +++ b/maths/cholesky_decomposition.py @@ -43,7 +43,8 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray: True """ - assert A.shape[0] == A.shape[1], f"A is not square, {A.shape=}" + assert A.shape[0] == A.shape[1], f"Matrix A is not square, {A.shape=}" + assert np.allclose(A, A.T), "Matrix A must be symmetric" n = A.shape[0] L = np.tril(A) @@ -74,8 +75,8 @@ def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray: True """ - assert L.shape[0] == L.shape[1], f"L is not square, {L.shape=}" - assert np.allclose(np.tril(L), L), "L is not lower triangular" + assert L.shape[0] == L.shape[1], f"Matrix L is not square, {L.shape=}" + assert np.allclose(np.tril(L), L), "Matrix L is not lower triangular" # Handle vector case by reshaping to matrix and then flattening again if len(Y.shape) == 1: