
Source code for tomopt.optimisation.loss.sub_losses

from typing import Optional

import torch
from torch import Tensor

r"""
Provides functions to compute sub-loss components
"""

__all__ = ["integer_class_loss"]


def integer_class_loss(
    int_probs: Tensor,
    target_int: Tensor,
    pred_start_int: int,
    use_mse: bool,
    weight: Optional[Tensor] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""
    Loss for classifying integers, when regression is not applicable.
    It is assumed that the integers really are quantifiably comparable, and not categorical codes of classes.
    Like multiclass classification, the predictions are probabilities for each possible integer,
    but the ICL aims to penalise close predictions less than far-off ones:

    For a target of 3, a close prediction of `softmax([1,3,10,5,5,3,1])`, and a far-off prediction of `softmax([10,3,1,5,5,3,1])`,
    the categorical cross-entropy produces the same loss for both predictions (5.0154),
    despite the close prediction having a higher probability near the target.

    ICL instead computes the absolute error, or squared error, between each of the possible integers and the true target.
    These errors are then normalised, weighted by the predicted probabilities, and summed.
    I.e. integers close to the target have a lower error, and these are given greater weight in the sum if they have a higher probability.
    For the example, the ICL produces a loss of 1.0007 for the close prediction, and 8.8773 for the far-off one.

    Arguments:
        int_probs: (*,integers) tensor of predicted probabilities
        target_int: (*) tensor of target integers
        pred_start_int: the integer that the zeroth probability in predictions corresponds to
        use_mse: if True, compute squared errors; otherwise compute absolute errors
        weight: Optional (*) tensor of multiplicative weights for the unreduced ICLs
        reduction: 'mean' returns the average ICL, 'sum' sums the ICLs, 'none' returns the individual ICLs
    """

    # Integers that each predicted probability corresponds to
    ints = torch.arange(pred_start_int, pred_start_int + int_probs.size(-1))

    # Error between the target and every candidate integer
    diffs = target_int - ints
    if use_mse:
        diffs = diffs**2
    else:
        diffs = diffs.abs()

    # Normalise the errors, weight them by the predicted probabilities, and sum
    diffs = diffs / diffs.sum(-1, keepdim=True)
    loss = diffs * int_probs
    loss = loss.sum(-1, keepdim=True)

    if weight is not None:
        loss = loss * weight

    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError(f"Unknown reduction {reduction}. Please use ['mean', 'sum', 'none'].")
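
The worked example in the docstring can be checked by hand. The following is a minimal pure-Python sketch, not part of TomOpt: the helpers `softmax`, `cross_entropy`, and `expected_error` are hypothetical names introduced here for illustration, and the sketch assumes that the zeroth probability corresponds to the integer 0. Under those assumptions, the quoted ICL figures (1.0007 and 8.8773) appear to correspond to squared errors without the normalisation step; with the normalisation shown in the listing, both values would be divided by the sum of the squared errors (28 for this example), which leaves their ratio unchanged.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a plain list of floats
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, target_idx):
    # Categorical cross-entropy for a single sample
    return -math.log(probs[target_idx])


def expected_error(probs, target, start_int=0, squared=True, normalise=False):
    # Probability-weighted sum of the errors between each candidate integer and the target.
    # Assumption: probs[0] corresponds to the integer `start_int`.
    ints = range(start_int, start_int + len(probs))
    diffs = [(target - i) ** 2 if squared else abs(target - i) for i in ints]
    if normalise:
        total = sum(diffs)
        diffs = [d / total for d in diffs]
    return sum(d * p for d, p in zip(diffs, probs))


target = 3
close = softmax([1, 3, 10, 5, 5, 3, 1])
far = softmax([10, 3, 1, 5, 5, 3, 1])

# Cross-entropy only looks at the probability assigned to the target itself
ce_close = cross_entropy(close, target)
ce_far = cross_entropy(far, target)

# Squared error without normalisation
icl_close = expected_error(close, target)
icl_far = expected_error(far, target)

# Squared error with normalisation, as in the listing above
icl_close_norm = expected_error(close, target, normalise=True)
icl_far_norm = expected_error(far, target, normalise=True)

print(f"cross-entropy: close={ce_close:.4f}, far={ce_far:.4f}")
print(f"unnormalised:  close={icl_close:.4f}, far={icl_far:.4f}")
print(f"normalised:    close={icl_close_norm:.4f}, far={icl_far_norm:.4f}")
```

In both variants the close prediction receives the smaller loss, which is the property the ICL is designed to provide and the cross-entropy lacks.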
