Mirror of https://github.com/TheAlgorithms/Python.git (synced 2025-02-21 08:42:03 +00:00)
Resolved log(0) error in KL divergence #12233
This commit is contained in:
parent: 6c92c5a539
commit: 27fac2c113
@@ -659,7 +659,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
     if len(y_true) != len(y_pred):
         raise ValueError("Input arrays must have the same length.")

-    kl_loss = y_true * np.log(y_true / y_pred)
+    kl_loss = np.concatenate((y_true[None, :], y_pred[None, :]))
+    kl_loss = kl_loss[:, ~np.any(kl_loss == 0, axis=0)]
+    kl_loss = kl_loss[0] * np.log(kl_loss[0] / kl_loss[1])
     return np.sum(kl_loss)
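For context, the change avoids evaluating np.log at zero: the two distributions are stacked, every bin where either one is zero is dropped, and only then is the elementwise p * log(p / q) term summed. Below is a minimal, self-contained sketch of the function as it reads after this diff; the imports, comments, and the __main__ demo are illustrative additions and are not part of the commit.

import numpy as np


def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # Distributions must have the same number of bins.
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")

    # Stack both distributions row-wise so zero entries can be filtered
    # from the same positions in one step.
    kl_loss = np.concatenate((y_true[None, :], y_pred[None, :]))
    # Drop every column where either distribution is zero, so np.log is
    # never evaluated at 0 and no division by zero occurs.
    kl_loss = kl_loss[:, ~np.any(kl_loss == 0, axis=0)]
    # Elementwise contribution p * log(p / q) over the remaining bins.
    kl_loss = kl_loss[0] * np.log(kl_loss[0] / kl_loss[1])
    return np.sum(kl_loss)


if __name__ == "__main__":
    # Before the fix, the zero in y_true produced nan via 0 * log(0 / 0.1);
    # the filtered version simply skips that bin.
    y_true = np.array([0.0, 0.5, 0.5])
    y_pred = np.array([0.1, 0.4, 0.5])
    print(kullback_leibler_divergence(y_true, y_pred))  # approx. 0.1116

With the hypothetical inputs above, only the last two bins contribute, giving 0.5 * log(0.5 / 0.4) + 0.5 * log(0.5 / 0.5), roughly 0.1116, instead of nan.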