mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-01 11:26:43 +00:00
fix:Added condition array for filtering 0 values
This commit is contained in:
parent
6e24935f88
commit
a6ddea0f8e
@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
|
|||||||
- y_true: True class probabilities
|
- y_true: True class probabilities
|
||||||
- y_pred: Predicted class probabilities
|
- y_pred: Predicted class probabilities
|
||||||
|
|
||||||
|
>>> true_labels = np.array([0, 0.4, 0.6])
|
||||||
|
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
||||||
|
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
|
||||||
|
0.35835189384561095
|
||||||
|
|
||||||
>>> true_labels = np.array([0.2, 0.3, 0.5])
|
>>> true_labels = np.array([0.2, 0.3, 0.5])
|
||||||
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
||||||
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
|
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
|
||||||
@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
|
|||||||
if len(y_true) != len(y_pred):
|
if len(y_true) != len(y_pred):
|
||||||
raise ValueError("Input arrays must have the same length.")
|
raise ValueError("Input arrays must have the same length.")
|
||||||
|
|
||||||
|
filter_array = y_true != 0
|
||||||
|
y_true = y_true[filter_array]
|
||||||
|
y_pred = y_pred[filter_array]
|
||||||
kl_loss = y_true * np.log(y_true / y_pred)
|
kl_loss = y_true * np.log(y_true / y_pred)
|
||||||
return np.sum(kl_loss)
|
return np.sum(kl_loss)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user