mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-23 21:11:08 +00:00
Add KL divergence loss algorithm (#11238)
* Add KL divergence loss algorithm * Apply suggestions from code review --------- Co-authored-by: Tianyi Zheng <tianyizheng02@gmail.com>
This commit is contained in:
parent
ffaa976f6c
commit
c919579869
|
@ -629,6 +629,40 @@ def smooth_l1_loss(y_true: np.ndarray, y_pred: np.ndarray, beta: float = 1.0) ->
|
||||||
return np.mean(loss)
|
return np.mean(loss)
|
||||||
|
|
||||||
|
|
||||||
|
def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
Calculate the Kullback-Leibler divergence (KL divergence) loss between true labels
|
||||||
|
and predicted probabilities.
|
||||||
|
|
||||||
|
KL divergence loss quantifies dissimilarity between true labels and predicted
|
||||||
|
probabilities. It's often used in training generative models.
|
||||||
|
|
||||||
|
KL = Σ(y_true * ln(y_true / y_pred))
|
||||||
|
|
||||||
|
Reference: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- y_true: True class probabilities
|
||||||
|
- y_pred: Predicted class probabilities
|
||||||
|
|
||||||
|
>>> true_labels = np.array([0.2, 0.3, 0.5])
|
||||||
|
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
|
||||||
|
>>> kullback_leibler_divergence(true_labels, predicted_probs)
|
||||||
|
0.030478754035472025
|
||||||
|
>>> true_labels = np.array([0.2, 0.3, 0.5])
|
||||||
|
>>> predicted_probs = np.array([0.3, 0.3, 0.4, 0.5])
|
||||||
|
>>> kullback_leibler_divergence(true_labels, predicted_probs)
|
||||||
|
Traceback (most recent call last):
|
||||||
|
...
|
||||||
|
ValueError: Input arrays must have the same length.
|
||||||
|
"""
|
||||||
|
if len(y_true) != len(y_pred):
|
||||||
|
raise ValueError("Input arrays must have the same length.")
|
||||||
|
|
||||||
|
kl_loss = y_true * np.log(y_true / y_pred)
|
||||||
|
return np.sum(kl_loss)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import doctest
|
import doctest
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user