From 307ce1cd7f8cc999d76140574c6798df3d874b4f Mon Sep 17 00:00:00 2001 From: 99991 <99991@users.noreply.github.com> Date: Wed, 9 Oct 2024 08:32:28 +0200 Subject: [PATCH] Simplify equations, rename variables --- maths/cholesky_decomposition.py | 58 ++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/maths/cholesky_decomposition.py b/maths/cholesky_decomposition.py index fe0820f9a..363436b40 100644 --- a/maths/cholesky_decomposition.py +++ b/maths/cholesky_decomposition.py @@ -1,7 +1,8 @@ import numpy as np -def cholesky_decomposition(a: np.ndarray) -> np.ndarray: +# ruff: noqa: N803,N806 +def cholesky_decomposition(A: np.ndarray) -> np.ndarray: """Return a Cholesky decomposition of the matrix A. The Cholesky decomposition decomposes the square, positive definite matrix A @@ -41,25 +42,28 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray: >>> np.allclose(X, X_true) True """ - assert a.shape[0] == a.shape[1] - n = a.shape[0] - lo = np.tril(a) + + assert A.shape[0] == A.shape[1], f"A is not square, {A.shape=}" + + n = A.shape[0] + L = np.tril(A) for i in range(n): - for j in range(i): - lo[i, j] = (lo[i, j] - np.sum(lo[i, :j] * lo[j, :j])) / lo[j, j] + for j in range(i + 1): + L[i, j] -= np.sum(L[i, :j] * L[j, :j]) - s = lo[i, i] - np.sum(lo[i, :i] * lo[i, :i]) + if i == j: + if L[i, i] <= 0: + raise ValueError("Matrix A is not positive definite") - if s <= 0: - raise ValueError("Matrix A is not positive definite") + L[i, i] = np.sqrt(L[i, i]) + else: + L[i, j] /= L[j, j] - lo[i, i] = np.sqrt(s) - - return lo + return L -def solve_cholesky(lo: np.ndarray, y: np.ndarray) -> np.ndarray: +def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray: """Given a Cholesky decomposition L L^T = A of a matrix A, solve the system of equations A X = Y where B is either a matrix or a vector. @@ -70,30 +74,32 @@ def solve_cholesky(lo: np.ndarray, y: np.ndarray) -> np.ndarray: True """ + assert L.shape[0] == L.shape[1], f"L is not square, {L.shape=}" + assert np.allclose(np.tril(L), L), "L is not lower triangular" + # Handle vector case by reshaping to matrix and then flattening again - if len(y.shape) == 1: - return solve_cholesky(lo, y.reshape(-1, 1)).ravel() + if len(Y.shape) == 1: + return solve_cholesky(L, Y.reshape(-1, 1)).ravel() - n, m = y.shape + n = Y.shape[0] - # Backsubstitute L X = B - x = y.copy() + # Solve L W = B for W + W = Y.copy() for i in range(n): for j in range(i): - x[i, :] -= lo[i, j] * x[j, :] + W[i] -= L[i, j] * W[j] - for k in range(m): - x[i, k] /= lo[i, i] + W[i] /= L[i, i] - # Backsubstitute L^T + # Solve L^T X = W for X + X = W for i in reversed(range(n)): for j in range(i + 1, n): - x[i, :] -= lo[j, i] * x[j, :] + X[i] -= L[j, i] * X[j] - for k in range(m): - x[i, k] /= lo[i, i] + X[i] /= L[i, i] - return x + return X if __name__ == "__main__":