From 7808f21e1cda505088712c0dd8cf50195a593a22 Mon Sep 17 00:00:00 2001 From: 99991 <99991@users.noreply.github.com> Date: Thu, 10 Oct 2024 08:46:05 +0200 Subject: [PATCH] Rename variables --- maths/cholesky_decomposition.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/maths/cholesky_decomposition.py b/maths/cholesky_decomposition.py index db3b47511..591a3e9df 100644 --- a/maths/cholesky_decomposition.py +++ b/maths/cholesky_decomposition.py @@ -1,7 +1,7 @@ import numpy as np -def cholesky_decomposition(a: np.ndarray) -> np.ndarray: +def cholesky_decomposition(matrix: np.ndarray) -> np.ndarray: """Return a Cholesky decomposition of the matrix A. The Cholesky decomposition decomposes the square, positive definite matrix A @@ -42,11 +42,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray: True """ - assert a.shape[0] == a.shape[1], f"Matrix A is not square, {a.shape=}" - assert np.allclose(a, a.T), "Matrix A must be symmetric" + assert ( + matrix.shape[0] == matrix.shape[1] + ), f"Input matrix is not square, {matrix.shape=}" + assert np.allclose(matrix, matrix.T), "Input matrix must be symmetric" - n = a.shape[0] - lower_triangle = np.tril(a) + n = matrix.shape[0] + lower_triangle = np.tril(matrix) for i in range(n): for j in range(i + 1): @@ -65,9 +67,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray: return lower_triangle -def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray: +def solve_cholesky( + lower_triangle: np.ndarray, + right_hand_side: np.ndarray, +) -> np.ndarray: """Given a Cholesky decomposition L L^T = A of a matrix A, solve the - system of equations A X = Y where Y is either a matrix or a vector. + system of equations A X = Y where the right-hand side Y is either + a matrix or a vector. >>> L = np.array([[2, 0], [3, 4]], dtype=float) >>> Y = np.array([[22, 54], [81, 193]], dtype=float) @@ -84,13 +90,13 @@ def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray: ), "Matrix L is not lower triangular" # Handle vector case by reshaping to matrix and then flattening again - if len(y.shape) == 1: - return solve_cholesky(lower_triangle, y.reshape(-1, 1)).ravel() + if len(right_hand_side.shape) == 1: + return solve_cholesky(lower_triangle, right_hand_side.reshape(-1, 1)).ravel() - n = y.shape[0] + n = right_hand_side.shape[0] - # Solve L W = B for W - w = y.copy() + # Solve L W = Y for W + w = right_hand_side.copy() for i in range(n): for j in range(i): w[i] -= lower_triangle[i, j] * w[j]