diff --git a/machine_learning/loss_functions.py b/machine_learning/loss_functions.py
index 0bd9aa8b5..980f2ef31 100644
--- a/machine_learning/loss_functions.py
+++ b/machine_learning/loss_functions.py
@@ -659,7 +659,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
     if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")
 
-    kl_loss = y_true * np.log(y_true / y_pred)
+    kl_loss = np.concatenate((y_true[None, :], y_pred[None, :]))
+    kl_loss = kl_loss[:, ~np.any(kl_loss == 0, axis=0)]
+    kl_loss = kl_loss[0] * np.log(kl_loss[0] / kl_loss[1])
 
     return np.sum(kl_loss)
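
The patch replaces the direct elementwise KL term with a masked version: the two distributions are stacked, columns where either probability is zero are dropped, and the sum D_KL(P || Q) = Σ p_i * log(p_i / q_i) is then taken only over the remaining columns, avoiding log(0) and division by zero. A minimal standalone sketch of this masking technique is shown below; the function name `kl_divergence_masked` and the example arrays are illustrative, not part of the repository.

```python
import numpy as np


def kl_divergence_masked(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Sketch of the zero-masked KL divergence used in the patch above."""
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")
    # Stack both distributions row-wise, then drop every column where
    # either value is zero so log(0) and division by zero never occur.
    stacked = np.concatenate((y_true[None, :], y_pred[None, :]))
    stacked = stacked[:, ~np.any(stacked == 0, axis=0)]
    # Sum p * log(p / q) over the surviving columns only.
    return float(np.sum(stacked[0] * np.log(stacked[0] / stacked[1])))


# Hypothetical usage: a zero entry no longer yields nan/inf.
p = np.array([0.5, 0.5, 0.0])
q = np.array([0.25, 0.5, 0.25])
print(kl_divergence_masked(p, q))  # ~0.3466, only the nonzero columns contribute
```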