Enforce symmetry on A

This commit is contained in:
99991 2024-10-09 08:42:07 +02:00
parent 307ce1cd7f
commit 907b783668

View File

@ -43,7 +43,8 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
True
"""
assert A.shape[0] == A.shape[1], f"A is not square, {A.shape=}"
assert A.shape[0] == A.shape[1], f"Matrix A is not square, {A.shape=}"
assert np.allclose(A, A.T), "Matrix A must be symmetric"
n = A.shape[0]
L = np.tril(A)
@ -74,8 +75,8 @@ def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray:
True
"""
assert L.shape[0] == L.shape[1], f"L is not square, {L.shape=}"
assert np.allclose(np.tril(L), L), "L is not lower triangular"
assert L.shape[0] == L.shape[1], f"Matrix L is not square, {L.shape=}"
assert np.allclose(np.tril(L), L), "Matrix L is not lower triangular"
# Handle vector case by reshaping to matrix and then flattening again
if len(Y.shape) == 1: