Add sparse categorical cross entropy loss algorithm

AtomicVar 2024-01-14 17:33:18 +08:00
parent 3952ba703a
commit 866cc8b95a


@@ -148,6 +148,58 @@ def categorical_cross_entropy(
    return -np.sum(y_true * np.log(y_pred))


def sparse_categorical_cross_entropy(
    y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-15
) -> float:
    """
    Calculate sparse categorical cross entropy (SCCE) loss between true class
    labels and predicted class probabilities. SCCE is used when the true class
    labels are given as integer class indices rather than one-hot encoded
    vectors.

    SCCE = -Σ(ln(y_pred[i, y_true[i]]))

    Reference: https://en.wikipedia.org/wiki/Cross_entropy

    Parameters:
    - y_true: True class labels as integer class indices.
    - y_pred: Predicted class probabilities, one row per sample.
    - epsilon: Small constant to avoid numerical instability.

    >>> true_labels = np.array([0, 1, 2])
    >>> pred_probs = np.array([[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.0, 0.1, 0.9]])
    >>> sparse_categorical_cross_entropy(true_labels, pred_probs)
    0.567395975254385

    >>> true_labels = np.array([1, 2])
    >>> pred_probs = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
    >>> sparse_categorical_cross_entropy(true_labels, pred_probs)
    2.353878387381596

    >>> true_labels = np.array([1, 5])
    >>> pred_probs = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
    >>> sparse_categorical_cross_entropy(true_labels, pred_probs)
    Traceback (most recent call last):
        ...
    ValueError: Class labels in y_true are out of range.

    >>> true_labels = np.array([1, 2])
    >>> pred_probs = np.array([[0.05, 0.95, 0.1], [0.1, 0.8, 0.1]])
    >>> sparse_categorical_cross_entropy(true_labels, pred_probs)
    Traceback (most recent call last):
        ...
    ValueError: Predicted probabilities must sum to approximately 1.
    """
    # Every label must be a valid column index into y_pred.
    if np.any(y_true >= y_pred.shape[1]) or np.any(y_true < 0):
        raise ValueError("Class labels in y_true are out of range.")

    # Each row of y_pred must be a probability distribution.
    if not np.all(np.isclose(np.sum(y_pred, axis=1), 1, rtol=epsilon, atol=epsilon)):
        raise ValueError("Predicted probabilities must sum to approximately 1.")

    y_pred = np.clip(y_pred, epsilon, 1)  # Clip predictions to avoid log(0)

    # Select each sample's predicted probability for its true class.
    log_preds = np.log(y_pred[np.arange(len(y_pred)), y_true])
    return -np.sum(log_preds)
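
A quick sanity check, not part of this commit: converting the integer labels to one-hot vectors and feeding them to the categorical_cross_entropy defined earlier in this file should give the same loss. The snippet below is a minimal sketch assuming both functions are importable from this module and that categorical_cross_entropy accepts one-hot labels:

import numpy as np

# Hypothetical check (not in this commit): SCCE over integer labels should
# match CCE over the equivalent one-hot labels.
true_labels = np.array([0, 1, 2])
pred_probs = np.array([[0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]])

# Convert integer class indices to one-hot rows.
one_hot = np.eye(pred_probs.shape[1])[true_labels]

scce = sparse_categorical_cross_entropy(true_labels, pred_probs)
cce = categorical_cross_entropy(one_hot, pred_probs)
assert np.isclose(scce, cce)  # both equal -3 * ln(0.7), roughly 1.0700
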
def categorical_focal_cross_entropy(
    y_true: np.ndarray,
    y_pred: np.ndarray,