Add perplexity loss algorithm (#11028)

This commit is contained in:
Poojan Smart 2023-10-27 20:14:33 +05:30 committed by GitHub
parent 34eb9c529a
commit e4eda14583
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -379,6 +379,98 @@ def mean_absolute_percentage_error(
return np.mean(absolute_percentage_diff)
def perplexity_loss(
y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-7
) -> float:
"""
Calculate the perplexity for the y_true and y_pred.
Compute the Perplexity which useful in predicting language model
accuracy in Natural Language Processing (NLP.)
Perplexity is measure of how certain the model in its predictions.
Perplexity Loss = exp(-1/N (Σ ln(p(x)))
Reference:
https://en.wikipedia.org/wiki/Perplexity
Args:
y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)
epsilon: Small floating point number to avoid getting inf for log(0)
Returns:
Perplexity loss between y_true and y_pred.
>>> y_true = np.array([[1, 4], [2, 3]])
>>> y_pred = np.array(
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
... [0.24, 0.19, 0.09, 0.18, 0.27]],
... [[0.03, 0.26, 0.21, 0.18, 0.30],
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
... )
>>> perplexity_loss(y_true, y_pred)
5.0247347775367945
>>> y_true = np.array([[1, 4], [2, 3]])
>>> y_pred = np.array(
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
... [0.24, 0.19, 0.09, 0.18, 0.27],
... [0.30, 0.10, 0.20, 0.15, 0.25]],
... [[0.03, 0.26, 0.21, 0.18, 0.30],
... [0.28, 0.10, 0.33, 0.15, 0.12],
... [0.30, 0.10, 0.20, 0.15, 0.25]],]
... )
>>> perplexity_loss(y_true, y_pred)
Traceback (most recent call last):
...
ValueError: Sentence length of y_true and y_pred must be equal.
>>> y_true = np.array([[1, 4], [2, 11]])
>>> y_pred = np.array(
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
... [0.24, 0.19, 0.09, 0.18, 0.27]],
... [[0.03, 0.26, 0.21, 0.18, 0.30],
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
... )
>>> perplexity_loss(y_true, y_pred)
Traceback (most recent call last):
...
ValueError: Label value must not be greater than vocabulary size.
>>> y_true = np.array([[1, 4]])
>>> y_pred = np.array(
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
... [0.24, 0.19, 0.09, 0.18, 0.27]],
... [[0.03, 0.26, 0.21, 0.18, 0.30],
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
... )
>>> perplexity_loss(y_true, y_pred)
Traceback (most recent call last):
...
ValueError: Batch size of y_true and y_pred must be equal.
"""
vocab_size = y_pred.shape[2]
if y_true.shape[0] != y_pred.shape[0]:
raise ValueError("Batch size of y_true and y_pred must be equal.")
if y_true.shape[1] != y_pred.shape[1]:
raise ValueError("Sentence length of y_true and y_pred must be equal.")
if np.max(y_true) > vocab_size:
raise ValueError("Label value must not be greater than vocabulary size.")
# Matrix to select prediction value only for true class
filter_matrix = np.array(
[[list(np.eye(vocab_size)[word]) for word in sentence] for sentence in y_true]
)
# Getting the matrix containing prediction for only true class
true_class_pred = np.sum(y_pred * filter_matrix, axis=2).clip(epsilon, 1)
# Calculating perplexity for each sentence
perp_losses = np.exp(np.negative(np.mean(np.log(true_class_pred), axis=1)))
return np.mean(perp_losses)
if __name__ == "__main__":
import doctest