mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-03-30 18:36:43 +00:00
fix:Added condition array for filtering 0 values
This commit is contained in:
parent
6e24935f88
commit
a6ddea0f8e
@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
|
||||
- y_true: True class probabilities
|
||||
- y_pred: Predicted class probabilities
|
||||
|
||||
>>> true_labels = np.array([0, 0.4, 0.6])
|
||||
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
||||
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
|
||||
0.35835189384561095
|
||||
|
||||
>>> true_labels = np.array([0.2, 0.3, 0.5])
|
||||
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
||||
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
|
||||
@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
|
||||
if len(y_true) != len(y_pred):
|
||||
raise ValueError("Input arrays must have the same length.")
|
||||
|
||||
filter_array = y_true != 0
|
||||
y_true = y_true[filter_array]
|
||||
y_pred = y_pred[filter_array]
|
||||
kl_loss = y_true * np.log(y_true / y_pred)
|
||||
return np.sum(kl_loss)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user