Shortcuts

Source code for tomopt.optimisation.callbacks.heatmap_gif

import os
from typing import List

import imageio

from ...volume import PanelDetectorLayer
from .callback import Callback

r"""
Skeletal script to create a gif of the heatmap during training through callbacks.
"""

__all__ = ["HeatMapGif"]


class HeatMapGif(Callback):
    r"""
    Records a gif of the first heatmap in the first detector layer during training.

    A frame is saved at the start of every training epoch and once more when training ends.
    The frames are then combined into a single gif, and the temporary frame images are deleted.

    Arguments:
        gif_filename: savename for the gif (will be appended to the callback savepath)
    """

    def __init__(self, gif_filename: str = "heatmap.gif") -> None:
        r"""
        Initialises the callback.

        Arguments:
            gif_filename: savename for the gif (will be appended to the callback savepath)
        """

        self.gif_filename = gif_filename
        self._reset()

    def on_train_begin(self) -> None:
        r"""
        Prepares to record a new gif
        """

        super().on_train_begin()
        self._reset()

    def on_epoch_begin(self) -> None:
        r"""
        When a new training epoch begins, saves an image of the current layout of the first heatmap in the first detector layer
        """

        if self.wrapper.fit_params.state == "train":  # Avoid doubling the length of the GIF
            self._plot_current()

    def on_train_end(self) -> None:
        r"""
        When training ends, saves an image of the current layout of the first heatmap in the first detector layer
        and then combines all images into a gif
        """

        self._plot_current()
        self._create_gif()

    def _plot_current(self) -> None:
        r"""
        Saves an image of the current layout of the first heatmap in the first detector layer.

        The image path is only recorded as a gif frame once the plot call has returned,
        so a failed or skipped plot never leaves a dangling entry in the frame buffer.

        Raises:
            NotImplementedError: if the first detector layer is not a `PanelDetectorLayer` whose `type_label` is "heatmap"
        """

        filename = self.wrapper.fit_params.cb_savepath / f"temp_heatmap_{len(self._buffer_files)}.png"

        # Only the first detector layer is ever considered
        layer = next(iter(self.wrapper.volume.get_detectors()), None)
        if layer is None:
            return
        if not (isinstance(layer, PanelDetectorLayer) and layer.type_label == "heatmap"):
            raise NotImplementedError(f"HeatMapGif does not yet support {type(layer) , layer.type_label}")

        # Only the first panel of that layer is plotted
        for panel in layer.panels:
            panel.plot_map(bsavefig=True, filename=filename)
            self._buffer_files.append(filename)
            break

    def _reset(self) -> None:
        r"""
        Prepares to record a new gif
        """

        # NOTE(review): entries are built as `cb_savepath / "<name>"`, so they are presumably path objects
        # rather than plain strings; confirm the type of `cb_savepath` and tighten this annotation.
        self._buffer_files: List[str] = []

    def _create_gif(self) -> None:
        r"""
        Combines recorded images into a gif and removes the temporary frame images.

        If no frames were recorded, a warning is emitted and no gif is written.
        The temporary images are removed even if assembling the gif fails.
        """

        import warnings

        frames = list(self._buffer_files)
        if not frames:
            warnings.warn("HeatMapGif recorded no frames; no gif will be written")
            return

        try:
            with imageio.get_writer(self.wrapper.fit_params.cb_savepath / self.gif_filename, mode="I") as writer:
                for frame in frames:
                    writer.append_data(imageio.imread(frame))
        finally:
            # Clean up in all cases, so a failure while writing the gif does not leave temporary frames behind
            for frame in set(frames):
                try:
                    os.remove(frame)
                except FileNotFoundError:
                    # The frame is already gone; nothing left to clean up for it
                    pass

Docs

Access comprehensive developer and user documentation for TomOpt

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of TomOpt

View Tutorials