Fix ruff check in loss_functions.py

jbsch 2024-10-24 16:31:38 +05:30
parent 0ea341a18b
commit 1ff79750a8
2 changed files with 14 additions and 5 deletions

View File

@@ -240,7 +240,7 @@ def ascend_tree(leaf_node: TreeNode, prefix_path: list[str]) -> None:
         ascend_tree(leaf_node.parent, prefix_path)


-def find_prefix_path(base_pat: frozenset, tree_node: TreeNode | None) -> dict:  # noqa: ARG001
+def find_prefix_path(base_pat: frozenset, tree_node: TreeNode | None) -> dict:
     """
     Find the conditional pattern base for a given base pattern.
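
For context, ARG001 is the ruff rule code for an unused function argument (the flake8-unused-arguments family), and a trailing noqa comment like the one on the removed line suppresses that rule for a single line. A hypothetical, repository-independent illustration, assuming the ARG rules are enabled:

    # hypothetical example: `unused` is never read inside the function,
    # so ruff reports ARG001 unless the trailing noqa comment suppresses it
    def lookup(key: str, unused: int) -> str:  # noqa: ARG001
        return key.upper()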

View File

@@ -629,13 +629,15 @@ def smooth_l1_loss(y_true: np.ndarray, y_pred: np.ndarray, beta: float = 1.0) -> float:
     return np.mean(loss)


-def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+def kullback_leibler_divergence(
+    y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-10
+) -> float:
     """
     Calculate the Kullback-Leibler divergence (KL divergence) loss between true labels
     and predicted probabilities.

-    KL divergence loss quantifies dissimilarity between true labels and predicted
-    probabilities. It's often used in training generative models.
+    KL divergence loss quantifies the dissimilarity between true labels and predicted
+    probabilities. It is often used in training generative models.

     KL = Σ(y_true * ln(y_true / y_pred))
@@ -649,6 +651,7 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     >>> predicted_probs = np.array([0.3, 0.3, 0.4])
     >>> float(kullback_leibler_divergence(true_labels, predicted_probs))
     0.030478754035472025
     >>> true_labels = np.array([0.2, 0.3, 0.5])
     >>> predicted_probs = np.array([0.3, 0.3, 0.4, 0.5])
     >>> kullback_leibler_divergence(true_labels, predicted_probs)
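
As a quick sanity check on the doctest value above (not part of the commit), the formula from the docstring gives the same number:

    # KL = sum(y_true * ln(y_true / y_pred))
    #    = 0.2*ln(0.2/0.3) + 0.3*ln(0.3/0.3) + 0.5*ln(0.5/0.4)
    #    ≈ -0.081093 + 0.0 + 0.111572 ≈ 0.030479
    import numpy as np
    true_labels = np.array([0.2, 0.3, 0.5])
    predicted_probs = np.array([0.3, 0.3, 0.4])
    print(float(np.sum(true_labels * np.log(true_labels / predicted_probs))))
    # 0.030478754035472025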
@@ -659,7 +662,13 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     if len(y_true) != len(y_pred):
         raise ValueError("Input arrays must have the same length.")

-    kl_loss = y_true * np.log(y_true / y_pred)
+    # negligible epsilon to avoid issues with log(0) or division by zero
+    epsilon = 1e-10
+    y_pred = np.clip(y_pred, epsilon, None)
+
+    # calculate KL divergence only where y_true is not zero
+    kl_loss = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0.0)

     return np.sum(kl_loss)
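
Put together, the patched function can be exercised on its own. The sketch below reproduces the new logic (clip y_pred away from zero, skip terms where y_true is zero); as one deliberate deviation from the diff, the clip here uses the epsilon parameter from the new signature rather than the local constant:

    import numpy as np

    def kullback_leibler_divergence(
        y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-10
    ) -> float:
        # minimal sketch of the patched logic, not the exact repository code
        if len(y_true) != len(y_pred):
            raise ValueError("Input arrays must have the same length.")

        # keep predictions strictly positive so log() and the division stay finite
        y_pred = np.clip(y_pred, epsilon, None)

        # count only terms where y_true is non-zero; 0 * ln(0 / p) is treated as 0
        kl_loss = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0.0)
        return float(np.sum(kl_loss))

    # a zero in y_true no longer turns the sum into nan
    # (NumPy may still emit a runtime warning while evaluating the masked branch)
    print(kullback_leibler_divergence(np.array([0.2, 0.0, 0.8]), np.array([0.3, 0.2, 0.5])))
    # ≈ 0.2949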