Shortcuts

Source code for tomopt.benchmarks.ladle_furnace.loss

from collections import defaultdict
from typing import Dict, List, Optional, Union

import torch
from torch import Tensor

from ...optimisation.callbacks import Callback
from ...optimisation.loss import VolumeIntClassLoss
from ...volume import Volume

__all__ = ["LadleFurnaceIntClassLoss", "SpreadRangeLoss"]


class LadleFurnaceIntClassLoss(VolumeIntClassLoss):
    r"""
    Integer-class volume loss for the ladle-furnace benchmark.

    This is a thin specialisation of ``VolumeIntClassLoss``: every keyword argument is
    forwarded unchanged to the parent, and the only thing added here is the
    target-to-integer mapping (``_targ2int``). That mapping turns a fill-height target
    into a layer ID inside the volume's passive region.

    Research tested only: no unit tests.

    Arguments:
        pred_int_start: forwarded to ``VolumeIntClassLoss``
        use_mse: forwarded to ``VolumeIntClassLoss`` (required, keyword-only)
        target_budget: forwarded to ``VolumeIntClassLoss`` (required, keyword-only)
        budget_smoothing: forwarded to ``VolumeIntClassLoss``
        cost_coef: forwarded to ``VolumeIntClassLoss``
        steep_budget: forwarded to ``VolumeIntClassLoss``
        debug: forwarded to ``VolumeIntClassLoss``

    See ``VolumeIntClassLoss`` for the meaning of each forwarded argument.
    """

    def __init__(
        self,
        *,
        pred_int_start: int = 0,
        use_mse: bool,
        target_budget: float,
        budget_smoothing: float = 10,
        cost_coef: Optional[Union[Tensor, float]] = None,
        steep_budget: bool = True,
        debug: bool = False,
    ):
        # _targ2int is a staticmethod, so it can safely be looked up before the
        # parent constructor has run.
        super().__init__(
            targ2int=self._targ2int,
            use_mse=use_mse,
            target_budget=target_budget,
            pred_int_start=pred_int_start,
            budget_smoothing=budget_smoothing,
            cost_coef=cost_coef,
            steep_budget=steep_budget,
            debug=debug,
        )

    @staticmethod
    def _targ2int(targs: Tensor, volume: Volume) -> Tensor:
        r"""
        Map fill-height targets onto integer layer IDs.

        Arguments:
            targs: target values, measured on the same z axis as
                ``volume.get_passive_z_range()``
            volume: volume providing the passive z-range and the passive layer size

        Returns:
            ``floor((targs - z_low) / volume.passive_size) - 1``, where ``z_low`` is the
            first element of ``volume.get_passive_z_range()``.
        """
        z_low = volume.get_passive_z_range()[0]
        height_above_base = targs - z_low
        n_layers_filled = torch.div(height_above_base, volume.passive_size, rounding_mode="floor")
        # Subtract one to convert the number of filled layers into a layer ID.
        return n_layers_filled - 1
class SpreadRangeLoss(Callback):
    r"""
    Callback that replaces the wrapper's ``fit_params.mean_loss`` with a spread/range ratio.

    During a volume batch, every prediction is recorded under the target value of the
    volume it was made for. At the end of the batch:

    - *spread* is the mean, over target groups, of the standard deviation of the
      predictions within each group;
    - *range* is the standard deviation of the per-group mean predictions;
    - the loss is ``spread / range``.

    Minimising this loss therefore favours tight predictions for each target that are
    well separated between different targets. Groups with fewer than two recorded
    predictions are skipped. The division yields ``inf`` if all group means coincide.

    Research tested only: no unit tests.
    """

    def on_volume_batch_begin(self) -> None:
        r"""
        Reset the per-target prediction store at the start of each volume batch.
        """

        # Keys are the target value as a Python float; values are the predictions
        # recorded for volumes with that target.
        self._preds: Dict[float, List[Tensor]] = defaultdict(list)

    def on_x0_pred_end(self) -> None:
        r"""
        Record the current prediction under the current volume's target value.
        """

        target = self.wrapper.volume.target.cpu().item()
        # The prediction is appended as-is (not detached), so the final loss remains
        # differentiable with respect to it.
        self._preds[target].append(self.wrapper.fit_params.pred)

    def on_volume_batch_end(self) -> None:
        r"""
        Compute ``spread / range`` and store it as ``wrapper.fit_params.mean_loss``.

        Raises:
            RuntimeError: if fewer than two target groups have at least two recorded
                predictions. In that case the range (a standard deviation over group
                means) is undefined: previously this produced either an opaque
                ``torch.stack`` error or a silent NaN loss.
        """

        stds: List[Tensor] = []
        means: List[Tensor] = []
        for preds in self._preds.values():
            if len(preds) < 2:
                continue  # a within-group std needs at least two predictions
            p = torch.cat(preds)
            stds.append(torch.std(p))
            means.append(torch.mean(p))

        if len(means) < 2:
            raise RuntimeError(
                "SpreadRangeLoss needs at least two distinct target values, each with at least two "
                f"predictions, per volume batch; only {len(means)} usable target group(s) were found."
            )

        spread = torch.mean(torch.stack(stds))
        range_ = torch.std(torch.stack(means))
        loss = spread / range_
        self.wrapper.fit_params.mean_loss = loss[None]  # add a leading dimension

Docs

Access comprehensive developer and user documentation for TomOpt

View Docs

Tutorials

Get tutorials, for beginner and advanced researchers, that demonstrate many of the features of TomOpt

View Tutorials