Mirror of https://github.com/TheAlgorithms/Python.git, synced 2025-02-25 10:28:39 +00:00
Add perplexity loss algorithm (#11028)
parent 34eb9c529a
commit e4eda14583
@@ -379,6 +379,98 @@ def mean_absolute_percentage_error(
    return np.mean(absolute_percentage_diff)


def perplexity_loss(
    y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-7
) -> float:
    """
    Calculate the perplexity for the given y_true and y_pred.

    Compute the perplexity, which is useful for evaluating language model
    accuracy in Natural Language Processing (NLP).
    Perplexity is a measure of how certain the model is in its predictions.

    Perplexity Loss = exp(-1/N * Σ ln(p(x)))

    Reference:
    https://en.wikipedia.org/wiki/Perplexity

    Args:
        y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
        y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)
        epsilon: Small floating point number to avoid getting inf for log(0)

    Returns:
        Perplexity loss between y_true and y_pred.

    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    5.0247347775367945
    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Sentence length of y_true and y_pred must be equal.
    >>> y_true = np.array([[1, 4], [2, 11]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Label value must not be greater than vocabulary size.
    >>> y_true = np.array([[1, 4]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Batch size of y_true and y_pred must be equal.
    """
    vocab_size = y_pred.shape[2]

    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError("Batch size of y_true and y_pred must be equal.")
    if y_true.shape[1] != y_pred.shape[1]:
        raise ValueError("Sentence length of y_true and y_pred must be equal.")
    if np.max(y_true) > vocab_size:
        raise ValueError("Label value must not be greater than vocabulary size.")

    # Matrix to select prediction value only for true class
    filter_matrix = np.array(
        [[list(np.eye(vocab_size)[word]) for word in sentence] for sentence in y_true]
    )

    # Getting the matrix containing prediction for only true class
    true_class_pred = np.sum(y_pred * filter_matrix, axis=2).clip(epsilon, 1)

    # Calculating perplexity for each sentence
    perp_losses = np.exp(np.negative(np.mean(np.log(true_class_pred), axis=1)))

    return np.mean(perp_losses)

if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import doctest
|
import doctest
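
A quick sanity check of the first doctest value, worked by hand (a minimal sketch; it assumes only NumPy and the true-class probabilities visible in the diff above):

import numpy as np

# True-class probabilities selected per time step by the one-hot filter:
# sentence 1: labels (1, 4) -> probs (0.19, 0.27)
# sentence 2: labels (2, 3) -> probs (0.21, 0.15)
sent1 = np.exp(-np.mean(np.log([0.19, 0.27])))  # = 1 / sqrt(0.19 * 0.27) ~ 4.4151
sent2 = np.exp(-np.mean(np.log([0.21, 0.15])))  # = 1 / sqrt(0.21 * 0.15) ~ 5.6344
print((sent1 + sent2) / 2)  # ~ 5.0247347775367945, matching the doctest output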
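A note on the filter-matrix design choice: building a dense one-hot tensor of shape (batch_size, sentence_length, vocab_size) is easy to read but costly for large vocabularies. A sketch of an equivalent selection using np.take_along_axis (an editorial suggestion, not part of this commit):

import numpy as np

def true_class_probs(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    # Picks y_pred[b, t, y_true[b, t]] for every batch item b and step t;
    # the trailing axis on the indices broadcasts along the vocab axis.
    return np.take_along_axis(y_pred, y_true[..., None], axis=2).squeeze(-1)

The remaining steps (clip, log, mean, exp) would be unchanged; this simply avoids materializing the vocab-wide one-hot tensor.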
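On the epsilon parameter: the .clip(epsilon, 1) is what keeps np.log finite when the model assigns zero probability to a true token. A small illustration with assumed inputs, not taken from the commit:

import numpy as np

probs = np.array([0.0, 0.5])        # zero probability for the first true token
print(np.log(probs.clip(1e-7, 1)))  # [-16.118..., -0.693...]: finite instead of -inf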