Shortcuts

Source code for tomopt.optimisation.callbacks.opt_callbacks

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple

import numpy as np

from .callback import Callback
from .warmup_callbacks import PostWarmupCallback

if TYPE_CHECKING:
    from ...optimisation.wrapper import AbsVolumeWrapper

__all__ = ["OneCycle", "EpochSave"]


class AbsOptSchedule(PostWarmupCallback, metaclass=ABCMeta):
    def __init__(
        self,
        opt_name: str,
        init_lr: Optional[float] = None,
        init_mom: Optional[float] = None,
    ) -> None:
        self.init_lr = init_lr
        self.init_mom = init_mom
        self.opt_name = opt_name

    def set_wrapper(self, wrapper: AbsVolumeWrapper) -> None:
        super().set_wrapper(wrapper)
        if self.opt_name not in self.wrapper.opts:
            raise ValueError(f"{self.opt_name} not found in VolumeWrapper")

    def _activate(self) -> None:
        super()._activate()
        self.n_iters_per_epoch = len(self.wrapper.fit_params.trn_passives) // self.wrapper.fit_params.passive_bs
        self.n_epochs_expected = self.wrapper.fit_params.n_epochs - self.wrapper.fit_params.epoch + 1

    @abstractmethod
    def schedule(self) -> Tuple[Optional[float], Optional[float]]:
        r"""
        Compute LR and momentum as a function of iter_cnt, according to defined ranges.
        """

        pass

    def on_train_begin(self) -> None:
        super().on_train_begin()
        if self.init_lr is not None:
            self.wrapper.set_opt_lr(self.init_lr, self.opt_name)
        if self.init_mom is not None:
            self.wrapper.set_opt_mom(self.init_mom, self.opt_name)
        self.iter_cnt = 0

    def on_step_end(self) -> None:
        if self.active:
            self.iter_cnt += 1
            lr, mom = self.schedule()
            if lr is not None:
                self.wrapper.set_opt_lr(lr, self.opt_name)
            if mom is not None:
                self.wrapper.set_opt_mom(mom, self.opt_name)


[docs]class OneCycle(AbsOptSchedule): r""" Callback implementing Smith 1-cycle evolution for lr and momentum (beta_1) https://arxiv.org/abs/1803.09820 In the warmup phase: Learning rate is increased from `init_lr` to `mid_lr`, Momentum is decreased from `init_mom` to `mid_mom`, to stabilise the use of high LRs In the convergence phase: Learning rate is decreased from `mid_lr` to `final_lr`, Momentum is increased from `mid_mom` to `final_mom` Setting the learning rate or momentum here will override the values specified when instantiating the `VolumeWrapper`. learning rate or momentum arguments can be `None` to avoid annealing or overriding their values. Arguments: opt_name: name of optimiser that should be affected by the scheduler warmup_length: number of epochs to use for the warmup phase init_lr: initial learning rate (low) init_mom: initial momentum (high) mid_lr: nominal learning rate (high), mid_mom: nominal momentum (moderate), final_lr: final learning rate (low), final_mom: final momentum (high) """ def __init__( self, opt_name: str, warmup_length: int, init_lr: Optional[float] = None, init_mom: Optional[float] = None, mid_lr: Optional[float] = None, mid_mom: Optional[float] = None, final_lr: Optional[float] = None, final_mom: Optional[float] = None, ) -> None: super().__init__(opt_name=opt_name, init_lr=init_lr, init_mom=init_mom) self.warmup_length = warmup_length self.mid_lr = mid_lr self.mid_mom = mid_mom self.final_lr = final_lr if final_lr is not None else init_lr self.final_mom = final_mom if final_mom is not None else init_mom
[docs] def on_epoch_end(self) -> None: if self.wrapper.fit_params.epoch - 1 == self.warmup_length: self.iter_cnt = 0 self.warming_up = False self.length = self.n_epochs_expected - self.wrapper.fit_params.epoch + 1 self.scale = self.length * self.n_iters_per_epoch
def _activate(self) -> None: super()._activate() self.warming_up = True self.length = self.warmup_length self.scale = self.length * self.n_iters_per_epoch
[docs] def schedule(self) -> Tuple[Optional[float], Optional[float]]: x = np.cos(np.pi * (self.iter_cnt) / self.scale) + 1 def get_param(init: float, mid: float, final: float) -> float: if self.warming_up: params = (init, mid) else: params = (final, mid) dx = (params[1] - params[0]) * x / 2 return params[1] - dx if self.init_lr is None: lr = None else: lr = get_param(self.init_lr, self.mid_lr, self.final_lr) if self.init_mom is None: mom = None else: mom = get_param(self.init_mom, self.mid_mom, self.final_mom) return lr, mom
[docs]class EpochSave(Callback): r""" Saves the state of the volume at the end of each training epoch to a unique file. This can be used to load a specifc state to either be used, or to resume training. """
[docs] def on_epoch_end(self) -> None: if self.wrapper.fit_params.state == "train": self.wrapper.save(self.wrapper.fit_params.cb_savepath / f"volume_state_{self.wrapper.fit_params.epoch}.pt")

Docs

Access comprehensive developer and user documentation for TomOpt

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of TomOpt

View Tutorials