fix:Added condition array for filtering 0 values

This commit is contained in:
evan.zhang5 2024-10-24 18:08:23 +08:00
parent 6e24935f88
commit a6ddea0f8e

View File

@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
- y_true: True class probabilities
- y_pred: Predicted class probabilities
>>> true_labels = np.array([0, 0.4, 0.6])
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
0.35835189384561095
>>> true_labels = np.array([0.2, 0.3, 0.5])
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
if len(y_true) != len(y_pred):
raise ValueError("Input arrays must have the same length.")
filter_array = y_true != 0
y_true = y_true[filter_array]
y_pred = y_pred[filter_array]
kl_loss = y_true * np.log(y_true / y_pred)
return np.sum(kl_loss)