# Source code for nnbma.learning.loss_functions

import math
from abc import ABC, abstractmethod

import torch
from torch import nn

# Natural logarithm of 10 (ln 10).
LOG10 = math.log(10)

# Names exported by ``from ... import *``.
__all__ = [
    "MaskedLossFunction", "MaskOverlay", "MaskedMSELoss", "CauchyLoss", "SmoothL1Loss",
]

## Masked loss functions


class MaskedLossFunction(nn.Module, ABC):
    r"""Abstract base class for masked loss functions.

    A masked loss function is called as ``loss_fun(y_hat, y, mask)``.
    Subclasses must implement :meth:`forward`.
    """

    def __init__(self):
        r"""Initialize the underlying :class:`torch.nn.Module`."""
        super().__init__()

    @abstractmethod
    def forward(
        self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        r"""Evaluates the loss between a prediction ``y_hat`` and a reference ``y``,
        with some masked values indicated in ``mask``.

        Parameters
        ----------
        y_hat : torch.Tensor
            network prediction.
        y : torch.Tensor
            true values. Must have the same shape as ``y_hat``.
        mask : torch.Tensor
            binary mask flagging the values to disregard in the loss. Must have
            the same shape as ``y_hat``. The exact convention is defined by each
            subclass; for instance, :class:`MaskedMSELoss` treats entries with
            :math:`m_{i}=0` as masked.

        Returns
        -------
        torch.Tensor
            evaluated loss, expected to be a scalar tensor.
        """
        pass
class MaskOverlay(MaskedLossFunction):
    r"""Wraps a standard loss function so it can be used as a :class:`MaskedLossFunction`.

    The wrapped loss is called as ``loss(y_hat, y)``. The mask given to
    :meth:`forward` is accepted only for interface compatibility and is ignored.
    """

    def __init__(self, loss):
        r"""
        Parameters
        ----------
        loss : Callable
            standard loss function, called as ``loss(y_hat, y)``.
        """
        super().__init__()
        self.loss = loss

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask=None) -> torch.Tensor:
        r"""Evaluates the wrapped loss function, ignoring ``mask``.

        Parameters
        ----------
        y_hat : torch.Tensor
            network prediction.
        y : torch.Tensor
            true values. Must have the same shape as ``y_hat``.
        mask : torch.Tensor, optional
            ignored. It is named ``mask`` (like in :class:`MaskedLossFunction`) so
            that keyword calls ``loss_fun(y_hat, y, mask=m)`` also work, and it
            defaults to ``None`` so the overlay can be called with two arguments.

        Returns
        -------
        torch.Tensor
            whatever the wrapped loss returns for ``(y_hat, y)``.
        """
        return self.loss(y_hat, y)
class MaskedMSELoss(MaskedLossFunction):
    r"""Implements a masked MSE loss function.

    For a binary mask :math:`m` (where :math:`m_{i}=0` marks a masked value), a
    prediction :math:`\widehat{y}` and a true value :math:`y`:

    1. The squared errors are averaged over the unmasked entries along the first
       dimension (``dim=0``).
    2. The resulting values are averaged over all remaining positions.

    For 2-D inputs of shape :math:`(N, D)` this reads

    .. math::

        \mathrm{MaskedMSE}(\widehat{y}, y, m) = \frac{1}{D} \sum_{j=1}^{D}
        \frac{\sum_{i=1}^{N} m_{ij} \left( \widehat{y}_{ij} - y_{ij} \right)^2}
        {\max\left(1, \sum_{i=1}^{N} m_{ij}\right)}

    For 1-D inputs it reduces to the masked mean of the squared errors. A column
    whose entries are all masked contributes :math:`0` but is still counted in
    the outer average.
    """

    def __init__(self):
        r"""Initialize the loss; it has no parameters."""
        super().__init__()

    def forward(
        self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        r"""Evaluates the masked MSE loss between a prediction ``y_hat`` and a
        reference ``y``, with a binary mask ``m`` -- where :math:`m_{i}=0`
        corresponds to a masked value.

        Parameters
        ----------
        y_hat : torch.Tensor
            network prediction.
        y : torch.Tensor
            true values. Must have the same shape as ``y_hat``.
        mask : torch.Tensor
            binary mask. Must have the same shape as ``y_hat``.

        Returns
        -------
        torch.Tensor
            evaluated loss, as a scalar tensor.
        """
        # Normalize each dim-0 slice of the mask by its number of unmasked entries.
        # clip(min=1) avoids a division by zero when a column is fully masked.
        w = mask / mask.sum(dim=0).clip(min=1)
        # Weighted sum along dim 0 gives the per-column masked means; then average them.
        return (w * (y_hat - y).square()).sum(dim=0).mean()
# class MaskedWeightedMSELoss(MaskedLossFunction): # def __init__(self, w: List): # super().__init__() # self.w: torch.Tensor = torch.tensor(w).flatten() # def forward( # self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor # ) -> torch.Tensor: # w = mask / mask.sum(dim=0).clip(min=1) # return (self.w * w * (y_hat - y).square()).sum(dim=0).mean() # class MaskedMAELoss(MaskedLossFunction): # def __init__(self): # super().__init__() # def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor): # w = mask / mask.sum(dim=0).clip(min=1) # return (w * (y_hat - y).abs()).sum(dim=0).mean() # class MaskedPowerLoss(MaskedLossFunction): # degree: int # def __init__(self, degree: int): # super().__init__() # self.degree = degree # def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor): # w = mask / mask.sum(dim=0).clip(min=1) # return (w * (y_hat - y).abs() ** self.degree).sum(dim=0).mean() # class MaskedSeriesLoss(MaskedLossFunction): # def __init__(self, order: int): # super().__init__() # self.order = order # def forward( # self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor # ) -> torch.Tensor: # w = mask / mask.sum(dim=0).clip(min=1) # diff = (y_hat - y).abs() # err = 0.0 * diff # logk = 1.0 # diffk = 0.0 * diff + 1.0 # denomk = 1 # for k in range(1, self.order + 1): # logk = logk * LOG10 # diffk = diffk * diff # denomk = denomk * k # err = err + logk * diffk / denomk # return (w * err).sum(dim=0).mean() # class MaskedRelErrorLoss(MaskedLossFunction): # def __init__(self): # super().__init__() # def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor): # w = mask / mask.sum(dim=0).clip(min=1) # relerr = (LOG10 * (y_hat - y).abs()).exp() - 1 # return (w * relerr).sum(dim=0).mean() # Custom classic loss functions
class CauchyLoss(nn.Module):
    r"""Implements the Cauchy loss function, averaged over all elements, i.e.,

    .. math::

        \mathrm{CL}(\widehat{y}, y) = \frac{1}{n} \sum_{i=1}^{n}
        \log \left( 1 + \left( \widehat{y}_{i} - y_{i} \right)^2 \right)

    where :math:`n` is the number of elements of ``y_hat - y``.
    """

    def __init__(self):
        r"""Initialize the loss; it has no parameters."""
        super().__init__()

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Evaluates the Cauchy loss between a prediction ``y_hat`` and a reference ``y``.

        Parameters
        ----------
        y_hat : torch.Tensor
            network prediction.
        y : torch.Tensor
            true values. Must have the same shape as ``y_hat``.

        Returns
        -------
        torch.Tensor
            evaluated loss, as a scalar tensor.
        """
        # log1p(x) is accurate for small x, whereas log(1 + x) rounds 1 + x to 1
        # (e.g. x = 1e-10 in float32) and returns exactly 0.
        return torch.log1p((y_hat - y).square()).mean()
class SmoothL1Loss(nn.Module):
    r"""Implements the smooth L1 loss function, averaged over all elements, i.e.,

    .. math::

        \mathrm{SmoothL1}(\widehat{y}, y) = \frac{1}{n} \sum_{i=1}^{n}
        \begin{cases}
            \frac{1}{2\beta}\left( \widehat{y}_{i} - y_{i} \right)^2
                & \text{if } \vert \widehat{y}_{i} - y_{i} \vert < \beta \\
            \vert \widehat{y}_{i} - y_{i} \vert - 0.5 \beta
                & \text{otherwise}
        \end{cases}

    where :math:`n` is the number of elements of ``y_hat - y``. When
    :math:`\beta = 0` the loss reduces to the mean absolute error.
    """

    def __init__(self, beta: float):
        r"""
        Parameters
        ----------
        beta : float
            :math:`\beta` parameter of the loss function, i.e. the threshold
            between the quadratic and the linear regimes.
        """
        super().__init__()
        self.beta = beta  # float threshold of the quadratic regime

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Evaluates the smooth L1 loss between a prediction ``y_hat`` and a reference ``y``.

        Parameters
        ----------
        y_hat : torch.Tensor
            network prediction.
        y : torch.Tensor
            true values. Must have the same shape as ``y_hat``.

        Returns
        -------
        torch.Tensor
            evaluated loss, as a scalar tensor.
        """
        abs_diffs = torch.abs(y_hat - y)
        if self.beta == 0:
            # The quadratic branch would divide by zero. torch.where discards it in
            # the forward pass, but its NaN gradient would still leak into backward.
            return torch.mean(abs_diffs)
        return torch.mean(
            torch.where(
                abs_diffs < self.beta,
                0.5 * torch.square(abs_diffs) / self.beta,
                abs_diffs - 0.5 * self.beta,
            )
        )