diff --git a/maths/cholesky_decomposition.py b/maths/cholesky_decomposition.py index d7fa5b45c..db3b47511 100644 --- a/maths/cholesky_decomposition.py +++ b/maths/cholesky_decomposition.py @@ -1,8 +1,7 @@ import numpy as np -# ruff: noqa: N803,N806 -def cholesky_decomposition(A: np.ndarray) -> np.ndarray: +def cholesky_decomposition(a: np.ndarray) -> np.ndarray: """Return a Cholesky decomposition of the matrix A. The Cholesky decomposition decomposes the square, positive definite matrix A @@ -26,7 +25,7 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray: >>> np.allclose(np.tril(L), L) True - The Cholesky decomposition can be used to solve the system of equations A x = y. + The Cholesky decomposition can be used to solve the linear system A x = y. >>> x_true = np.array([1, 2, 3], dtype=float) >>> y = A @ x_true @@ -43,28 +42,30 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray: True """ - assert A.shape[0] == A.shape[1], f"Matrix A is not square, {A.shape=}" - assert np.allclose(A, A.T), "Matrix A must be symmetric" + assert a.shape[0] == a.shape[1], f"Matrix A is not square, {a.shape=}" + assert np.allclose(a, a.T), "Matrix A must be symmetric" - n = A.shape[0] - L = np.tril(A) + n = a.shape[0] + lower_triangle = np.tril(a) for i in range(n): for j in range(i + 1): - L[i, j] -= np.sum(L[i, :j] * L[j, :j]) + lower_triangle[i, j] -= np.sum( + lower_triangle[i, :j] * lower_triangle[j, :j] + ) if i == j: - if L[i, i] <= 0: + if lower_triangle[i, i] <= 0: raise ValueError("Matrix A is not positive definite") - L[i, i] = np.sqrt(L[i, i]) + lower_triangle[i, i] = np.sqrt(lower_triangle[i, i]) else: - L[i, j] /= L[j, j] + lower_triangle[i, j] /= lower_triangle[j, j] - return L + return lower_triangle -def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray: +def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray: """Given a Cholesky decomposition L L^T = A of a matrix A, solve the system of equations A X = Y where Y is either a matrix or a vector. @@ -75,32 +76,36 @@ def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray: True """ - assert L.shape[0] == L.shape[1], f"Matrix L is not square, {L.shape=}" - assert np.allclose(np.tril(L), L), "Matrix L is not lower triangular" + assert ( + lower_triangle.shape[0] == lower_triangle.shape[1] + ), f"Matrix L is not square, {lower_triangle.shape=}" + assert np.allclose( + np.tril(lower_triangle), lower_triangle + ), "Matrix L is not lower triangular" # Handle vector case by reshaping to matrix and then flattening again - if len(Y.shape) == 1: - return solve_cholesky(L, Y.reshape(-1, 1)).ravel() + if len(y.shape) == 1: + return solve_cholesky(lower_triangle, y.reshape(-1, 1)).ravel() - n = Y.shape[0] + n = y.shape[0] # Solve L W = B for W - W = Y.copy() + w = y.copy() for i in range(n): for j in range(i): - W[i] -= L[i, j] * W[j] + w[i] -= lower_triangle[i, j] * w[j] - W[i] /= L[i, i] + w[i] /= lower_triangle[i, i] # Solve L^T X = W for X - X = W + x = w for i in reversed(range(n)): for j in range(i + 1, n): - X[i] -= L[j, i] * X[j] + x[i] -= lower_triangle[j, i] * x[j] - X[i] /= L[i, i] + x[i] /= lower_triangle[i, i] - return X + return x if __name__ == "__main__":