Source code for kliff.trainer.utils.losses

from typing import Union

import numpy as np
import torch


def MSE_loss(
    predictions: Union[np.ndarray, torch.Tensor],
    targets: Union[np.ndarray, torch.Tensor],
    weights: Union[np.ndarray, torch.Tensor] = 1.0,
) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Compute the mean squared error (MSE) of the residuals, with the option to
    weight the residuals.

    The weights multiply the residuals *before* squaring:

    .. math::

        \mathrm{MSE} = \frac{1}{N} \sum_i \left( w_i \, (p_i - t_i) \right)^2

    Here :math:`N` is the number of elements of the broadcast product. The same
    formula is applied to NumPy and PyTorch inputs, so the loss value does not
    depend on which array library the caller uses.

    Args:
        predictions: The predicted values.
        targets: The target values.
        weights: The weights to apply to the residuals. Must broadcast against
            ``predictions - targets``. Default is 1.0.

    Returns:
        The MSE of the weighted residuals. This is a 0-dim ``torch.Tensor``
        when ``predictions - targets`` is a tensor, and the result of
        ``np.mean`` otherwise.
    """
    residuals = predictions - targets

    # Dispatch on torch explicitly. Every non-tensor residual (ndarray,
    # Python int/float, or any NumPy scalar such as np.float32) goes through
    # NumPy. The previous whitelist of NumPy/float types sent e.g. ints or
    # np.float32 scalars down the torch path, which then failed on `.device`.
    if isinstance(residuals, torch.Tensor):
        # Match the residuals' dtype/device so scalar, ndarray or tensor
        # weights can be combined with the residual tensor.
        weights = torch.asarray(
            weights, dtype=residuals.dtype, device=residuals.device
        )
        return torch.mean((residuals * weights) ** 2)

    return np.mean((residuals * weights) ** 2)
def MAE_loss(
    predictions: Union[np.ndarray, torch.Tensor],
    targets: Union[np.ndarray, torch.Tensor],
    weights: Union[np.ndarray, torch.Tensor] = 1.0,
) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Compute the mean absolute error (MAE) of the residuals, with the option to
    weight the residuals.

    The weights multiply the residuals before the absolute value is taken:

    .. math::

        \mathrm{MAE} = \frac{1}{N} \sum_i \left| w_i \, (p_i - t_i) \right|

    Here :math:`N` is the number of elements of the broadcast product. The same
    formula is applied to NumPy and PyTorch inputs, so the loss is never
    negative and does not depend on which array library the caller uses.

    Args:
        predictions: The predicted values.
        targets: The target values.
        weights: The weights to apply to the residuals. Must broadcast against
            ``predictions - targets``. Default is 1.0.

    Returns:
        The MAE of the weighted residuals. This is a 0-dim ``torch.Tensor``
        when ``predictions - targets`` is a tensor, and the result of
        ``np.mean`` otherwise.
    """
    residuals = predictions - targets

    # Dispatch on torch explicitly. Every non-tensor residual (ndarray,
    # Python int/float, or any NumPy scalar such as np.float32) goes through
    # NumPy instead of crashing on `.device` in the torch branch.
    if isinstance(residuals, torch.Tensor):
        # Match the residuals' dtype/device so scalar, ndarray or tensor
        # weights can be combined with the residual tensor.
        weights = torch.asarray(
            weights, dtype=residuals.dtype, device=residuals.device
        )
        # abs() is applied after weighting, as in the NumPy branch; weighting
        # abs(r) instead would give a negative loss for negative weights.
        return torch.mean(torch.abs(residuals * weights))

    return np.mean(np.abs(residuals * weights))