mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-23 21:11:08 +00:00
Add perplexity loss algorithm (#11028)
This commit is contained in:
parent
34eb9c529a
commit
e4eda14583
|
@ -379,6 +379,98 @@ def mean_absolute_percentage_error(
|
|||
return np.mean(absolute_percentage_diff)
|
||||
|
||||
|
||||
def perplexity_loss(
|
||||
y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-7
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the perplexity for the y_true and y_pred.
|
||||
|
||||
Compute the Perplexity which useful in predicting language model
|
||||
accuracy in Natural Language Processing (NLP.)
|
||||
Perplexity is measure of how certain the model in its predictions.
|
||||
|
||||
Perplexity Loss = exp(-1/N (Σ ln(p(x)))
|
||||
|
||||
Reference:
|
||||
https://en.wikipedia.org/wiki/Perplexity
|
||||
|
||||
Args:
|
||||
y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
|
||||
y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)
|
||||
epsilon: Small floating point number to avoid getting inf for log(0)
|
||||
|
||||
Returns:
|
||||
Perplexity loss between y_true and y_pred.
|
||||
|
||||
>>> y_true = np.array([[1, 4], [2, 3]])
|
||||
>>> y_pred = np.array(
|
||||
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
|
||||
... [0.24, 0.19, 0.09, 0.18, 0.27]],
|
||||
... [[0.03, 0.26, 0.21, 0.18, 0.30],
|
||||
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
|
||||
... )
|
||||
>>> perplexity_loss(y_true, y_pred)
|
||||
5.0247347775367945
|
||||
>>> y_true = np.array([[1, 4], [2, 3]])
|
||||
>>> y_pred = np.array(
|
||||
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
|
||||
... [0.24, 0.19, 0.09, 0.18, 0.27],
|
||||
... [0.30, 0.10, 0.20, 0.15, 0.25]],
|
||||
... [[0.03, 0.26, 0.21, 0.18, 0.30],
|
||||
... [0.28, 0.10, 0.33, 0.15, 0.12],
|
||||
... [0.30, 0.10, 0.20, 0.15, 0.25]],]
|
||||
... )
|
||||
>>> perplexity_loss(y_true, y_pred)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Sentence length of y_true and y_pred must be equal.
|
||||
>>> y_true = np.array([[1, 4], [2, 11]])
|
||||
>>> y_pred = np.array(
|
||||
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
|
||||
... [0.24, 0.19, 0.09, 0.18, 0.27]],
|
||||
... [[0.03, 0.26, 0.21, 0.18, 0.30],
|
||||
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
|
||||
... )
|
||||
>>> perplexity_loss(y_true, y_pred)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Label value must not be greater than vocabulary size.
|
||||
>>> y_true = np.array([[1, 4]])
|
||||
>>> y_pred = np.array(
|
||||
... [[[0.28, 0.19, 0.21 , 0.15, 0.15],
|
||||
... [0.24, 0.19, 0.09, 0.18, 0.27]],
|
||||
... [[0.03, 0.26, 0.21, 0.18, 0.30],
|
||||
... [0.28, 0.10, 0.33, 0.15, 0.12]]]
|
||||
... )
|
||||
>>> perplexity_loss(y_true, y_pred)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Batch size of y_true and y_pred must be equal.
|
||||
"""
|
||||
|
||||
vocab_size = y_pred.shape[2]
|
||||
|
||||
if y_true.shape[0] != y_pred.shape[0]:
|
||||
raise ValueError("Batch size of y_true and y_pred must be equal.")
|
||||
if y_true.shape[1] != y_pred.shape[1]:
|
||||
raise ValueError("Sentence length of y_true and y_pred must be equal.")
|
||||
if np.max(y_true) > vocab_size:
|
||||
raise ValueError("Label value must not be greater than vocabulary size.")
|
||||
|
||||
# Matrix to select prediction value only for true class
|
||||
filter_matrix = np.array(
|
||||
[[list(np.eye(vocab_size)[word]) for word in sentence] for sentence in y_true]
|
||||
)
|
||||
|
||||
# Getting the matrix containing prediction for only true class
|
||||
true_class_pred = np.sum(y_pred * filter_matrix, axis=2).clip(epsilon, 1)
|
||||
|
||||
# Calculating perplexity for each sentence
|
||||
perp_losses = np.exp(np.negative(np.mean(np.log(true_class_pred), axis=1)))
|
||||
|
||||
return np.mean(perp_losses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user