Rename variables

This commit is contained in:
99991 2024-10-10 08:00:18 +02:00
parent 818448b05d
commit 4522258980

View File

@ -1,8 +1,7 @@
import numpy as np import numpy as np
# ruff: noqa: N803,N806 def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
"""Return a Cholesky decomposition of the matrix A. """Return a Cholesky decomposition of the matrix A.
The Cholesky decomposition decomposes the square, positive definite matrix A The Cholesky decomposition decomposes the square, positive definite matrix A
@ -26,7 +25,7 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
>>> np.allclose(np.tril(L), L) >>> np.allclose(np.tril(L), L)
True True
The Cholesky decomposition can be used to solve the system of equations A x = y. The Cholesky decomposition can be used to solve the linear system A x = y.
>>> x_true = np.array([1, 2, 3], dtype=float) >>> x_true = np.array([1, 2, 3], dtype=float)
>>> y = A @ x_true >>> y = A @ x_true
@ -43,28 +42,30 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
True True
""" """
assert A.shape[0] == A.shape[1], f"Matrix A is not square, {A.shape=}" assert a.shape[0] == a.shape[1], f"Matrix A is not square, {a.shape=}"
assert np.allclose(A, A.T), "Matrix A must be symmetric" assert np.allclose(a, a.T), "Matrix A must be symmetric"
n = A.shape[0] n = a.shape[0]
L = np.tril(A) lower_triangle = np.tril(a)
for i in range(n): for i in range(n):
for j in range(i + 1): for j in range(i + 1):
L[i, j] -= np.sum(L[i, :j] * L[j, :j]) lower_triangle[i, j] -= np.sum(
lower_triangle[i, :j] * lower_triangle[j, :j]
)
if i == j: if i == j:
if L[i, i] <= 0: if lower_triangle[i, i] <= 0:
raise ValueError("Matrix A is not positive definite") raise ValueError("Matrix A is not positive definite")
L[i, i] = np.sqrt(L[i, i]) lower_triangle[i, i] = np.sqrt(lower_triangle[i, i])
else: else:
L[i, j] /= L[j, j] lower_triangle[i, j] /= lower_triangle[j, j]
return L return lower_triangle
def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray: def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the """Given a Cholesky decomposition L L^T = A of a matrix A, solve the
system of equations A X = Y where Y is either a matrix or a vector. system of equations A X = Y where Y is either a matrix or a vector.
@ -75,32 +76,36 @@ def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray:
True True
""" """
assert L.shape[0] == L.shape[1], f"Matrix L is not square, {L.shape=}" assert (
assert np.allclose(np.tril(L), L), "Matrix L is not lower triangular" lower_triangle.shape[0] == lower_triangle.shape[1]
), f"Matrix L is not square, {lower_triangle.shape=}"
assert np.allclose(
np.tril(lower_triangle), lower_triangle
), "Matrix L is not lower triangular"
# Handle vector case by reshaping to matrix and then flattening again # Handle vector case by reshaping to matrix and then flattening again
if len(Y.shape) == 1: if len(y.shape) == 1:
return solve_cholesky(L, Y.reshape(-1, 1)).ravel() return solve_cholesky(lower_triangle, y.reshape(-1, 1)).ravel()
n = Y.shape[0] n = y.shape[0]
# Solve L W = B for W # Solve L W = B for W
W = Y.copy() w = y.copy()
for i in range(n): for i in range(n):
for j in range(i): for j in range(i):
W[i] -= L[i, j] * W[j] w[i] -= lower_triangle[i, j] * w[j]
W[i] /= L[i, i] w[i] /= lower_triangle[i, i]
# Solve L^T X = W for X # Solve L^T X = W for X
X = W x = w
for i in reversed(range(n)): for i in reversed(range(n)):
for j in range(i + 1, n): for j in range(i + 1, n):
X[i] -= L[j, i] * X[j] x[i] -= lower_triangle[j, i] * x[j]
X[i] /= L[i, i] x[i] /= lower_triangle[i, i]
return X return x
if __name__ == "__main__": if __name__ == "__main__":