Shortcuts

Source code for tomopt.volume.heatmap

from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ..core import DEVICE
from ..muon import MuonBatch

__all__ = ["DetectorHeatMap"]


class DetectorHeatMap(nn.Module):
    """Detector layer whose xy response is described by a learnable Gaussian mixture.

    The resolution and efficiency returned for a given xy position are the
    configured ``res`` / ``eff`` values scaled by the output of an internal
    :class:`GMM`.  The GMM parameters (``mu``, ``sig``, ``norm``) and the
    z position of the layer are ``nn.Parameter`` objects; the initial xy centre
    and span are stored as fixed buffers.

    Arguments:
        res: positive scale for the resolution; hit uncertainties are ``1 / resolution``.
        eff: positive scale for the efficiency.
        init_xyz: initial (x, y, z) of the layer. x, y seed the GMM and are stored
            as the fixed buffer ``xy_fix``; z becomes a learnable parameter.
        init_xy_span: pair of values whose difference sets the spread of the GMM
            initialisation; the pair is reordered to ascending order if needed.
        m2_cost: multiplier used by :meth:`get_cost`.
        budget: must be ``None``; budgets are not supported yet.
        realistic_validation: if True, evaluation mode uses a hard in/out mask
            instead of the GMM (currently raises in :meth:`get_xy_mask`).
        device: device on which buffers and parameters are created.
        n_cluster: number of Gaussian components in the mixture.

    Raises:
        ValueError: if ``res`` or ``eff`` is not strictly positive.
        NotImplementedError: if ``budget`` is not ``None``.
    """

    def __init__(
        self,
        *,
        res: float,
        eff: float,
        init_xyz: Tuple[float, float, float],
        init_xy_span: Tuple[float, float],
        m2_cost: float = 1,
        budget: Optional[Tensor] = None,
        realistic_validation: bool = False,
        device: torch.device = DEVICE,
        n_cluster: int = 30,
    ):
        if res <= 0:
            raise ValueError("Resolution must be positive")
        if eff <= 0:
            raise ValueError("Efficiency must be positive")

        super().__init__()
        self.realistic_validation = realistic_validation
        self.device = device
        self.register_buffer("m2_cost", torch.tensor(float(m2_cost), device=self.device))
        self.register_buffer("resolution", torch.tensor(float(res), device=self.device))
        self.register_buffer("efficiency", torch.tensor(float(eff), device=self.device))
        self.n_cluster = n_cluster

        # Accept the span in either order; store it ascending.
        if init_xy_span[1] < init_xy_span[0]:
            init_xy_span = (init_xy_span[1], init_xy_span[0])
        self.register_buffer("xy_fix", torch.tensor(init_xyz[:2], device=self.device))
        self.register_buffer("xy_span_fix", torch.tensor(init_xy_span, device=self.device))
        self.delta_xy = init_xy_span[1] - init_xy_span[0]

        self.gmm = GMM(n_cluster=self.n_cluster, init_xy=init_xyz[:2], device=device, init_xy_span=self.delta_xy)
        # Aliases to the GMM parameters so they can be reached directly on the detector.
        self.mu = self.gmm.mu
        self.sig = self.gmm.sig
        self.norm = self.gmm.norm
        self.z = nn.Parameter(torch.tensor(init_xyz[2:3], device=self.device))

        self.range_mult = 1.2
        self.budget_scale = torch.ones(1, device=device)
        self.assign_budget(budget)

    def __repr__(self) -> str:
        return f"""{self.__class__} at av. xy={self.gmm.mu.T.mean(1)} with n_comp {self.n_cluster}, z={self.z.data}."""

    def get_xy_mask(self, xy: Tensor) -> Tensor:
        """Return a boolean mask of the xy positions that fall inside the detector.

        Currently always raises, because realistic validation is not supported
        for heatmap detectors.  The code after the ``raise`` is intentionally
        kept as the planned implementation.

        Raises:
            NotImplementedError: always.
        """

        raise NotImplementedError("Realistic validation isn't yet supported for heatmap detectors")
        # NOTE(review): unreachable until the raise above is removed. It uses xy_span_fix
        # whereas get_hits uses delta_xy for the same purpose - confirm which is intended.
        if not isinstance(self.xy_fix, Tensor):
            raise ValueError(f"{self.xy_fix} is not a Tensor for some reason.")  # To appease MyPy
        if not isinstance(self.xy_span_fix, Tensor):
            raise ValueError(f"{self.xy_span_fix} is not a Tensor for some reason.")  # To appease MyPy
        xy_low = self.xy_fix - self.range_mult * self.xy_span_fix
        xy_high = self.xy_fix + self.range_mult * self.xy_span_fix
        return (xy[:, 0] >= xy_low[0]) * (xy[:, 0] < xy_high[0]) * (xy[:, 1] >= xy_low[1]) * (xy[:, 1] < xy_high[1])

    def get_resolution(self, xy: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Return the resolution in x and y at each of the supplied xy positions.

        In training mode, or whenever ``realistic_validation`` is off, the result
        is ``resolution * gmm(xy)``.  Otherwise positions selected by ``mask`` get
        the nominal resolution and all others get zero.

        Arguments:
            xy: positions to evaluate; one row per position.
            mask: optional boolean mask of positions inside the detector; computed
                via :meth:`get_xy_mask` if omitted (realistic-validation path only).

        Returns:
            Tensor with one row per position and two columns.
        """

        if not isinstance(self.resolution, Tensor):
            raise ValueError(f"{self.resolution} is not a Tensor for some reason.")  # To appease MyPy

        if self.training or not self.realistic_validation:
            return self.resolution * self.gmm(xy)

        if mask is None:
            mask = self.get_xy_mask(xy)
        res = torch.zeros((len(xy), 2), device=self.device)  # Zero detection outside detector
        res[mask] = self.resolution
        return res

    def get_efficiency(self, xy: Tensor, mask: Optional[Tensor] = None, as_2d: bool = False) -> Tensor:
        """Return the efficiency at each of the supplied xy positions.

        In training mode, or whenever ``realistic_validation`` is off, the GMM
        output is capped at 1 and multiplied by the nominal efficiency.  Otherwise
        positions selected by ``mask`` get the nominal efficiency and all others
        get zero.

        Arguments:
            xy: positions to evaluate; one row per position.
            mask: optional boolean mask of positions inside the detector; computed
                via :meth:`get_xy_mask` if omitted (realistic-validation path only).
            as_2d: if False, the per-axis GMM values are multiplied together to give
                one value per position; if True the per-axis values are kept
                (GMM path) or a trailing axis of size one is added (mask path).

        Returns:
            Tensor of efficiencies, one entry (or row, if ``as_2d``) per position.
        """

        if not isinstance(self.efficiency, Tensor):
            raise ValueError(f"{self.efficiency} is not a Tensor for some reason.")  # To appease MyPy

        if self.training or not self.realistic_validation:
            # Cap at one so the scaled efficiency can never exceed the nominal value.
            # clamp() avoids allocating a fresh CPU tensor on every call.
            scale = self.gmm(xy).clamp(max=1.0)
            if not as_2d:
                scale = torch.prod(scale, dim=-1)  # Maybe weight product by xy distance?
            return self.efficiency * scale

        if mask is None:
            mask = self.get_xy_mask(xy)
        eff = torch.zeros(len(xy), device=self.device)  # Zero detection outside detector
        eff[mask] = self.efficiency
        if as_2d:
            eff = eff[:, None]
        return eff

    def get_cost(self) -> Tensor:
        """Return ``m2_cost`` times the mean over components of ``sig_x * sig_y``."""

        return self.m2_cost * self.sig.prod(1).mean()

    def assign_budget(self, budget: Optional[Tensor] = None) -> None:
        """Placeholder for budget assignment; only ``None`` is accepted.

        Raises:
            NotImplementedError: if ``budget`` is not ``None``.
        """

        if budget is not None:
            raise NotImplementedError("Please update me to work with a budget!")

    def get_hits(self, mu: MuonBatch) -> Dict[str, Tensor]:
        """Simulate the hits recorded by this detector for a batch of muons.

        The true muon xy positions are smeared with Gaussian noise of width
        ``1 / resolution`` and returned together with the true positions, the
        uncertainties and the per-muon efficiency.

        Arguments:
            mu: muon batch; its ``xy`` attribute, ``get_xy_mask`` method and
                ``len()`` are used.

        Returns:
            Dictionary with keys ``"reco_xyz"``, ``"gen_xyz"``, ``"unc_xyz"`` and ``"eff"``.
        """

        if not isinstance(self.xy_fix, Tensor):
            raise ValueError(f"{self.xy_fix} is not a Tensor for some reason.")  # To appease MyPy
        if not isinstance(self.xy_span_fix, Tensor):
            raise ValueError(f"{self.xy_span_fix} is not a Tensor for some reason.")  # To appease MyPy

        half_range = self.range_mult * self.delta_xy
        mask = mu.get_xy_mask(self.xy_fix - half_range, self.xy_fix + half_range)  # Muons in panel

        true_mu_xy = mu.xy.data
        xy0 = self.xy_fix - (self.delta_xy / 2)  # approx. Low-left of panel
        rel_xy = true_mu_xy - xy0
        res = self.get_resolution(true_mu_xy, mask)
        rel_xy = rel_xy + (torch.randn((len(mu), 2), device=self.device) / res)

        if not self.training and self.realistic_validation:  # Prevent reco hit from exiting panel
            # fix this?
            span = self.xy_span_fix.detach().cpu().numpy()
            rel_xy[mask] = torch.stack([torch.clamp(rel_xy[mask][:, 0], 0, span[0]), torch.clamp(rel_xy[mask][:, 1], 0, span[1])], dim=-1)
        reco_xy = xy0 + rel_xy

        # Append a z column and fill it with the layer's z position.
        reco_xyz = F.pad(reco_xy, (0, 1))
        reco_xyz[:, 2] = self.z
        gen_xyz = F.pad(true_mu_xy, (0, 1))
        gen_xyz[:, 2] = self.z

        hits = {
            "reco_xyz": reco_xyz,
            "gen_xyz": gen_xyz,
            "unc_xyz": F.pad(1 / res, (0, 1)),  # Add zero for z unc
            "eff": self.get_efficiency(true_mu_xy, mask)[:, None],
        }
        return hits

    def plot_map(self, bpixelate: bool = False, bsavefig: bool = False, filename: Optional[str] = None) -> None:
        """Plot the GMM response over a square of half-width ``2 * delta_xy`` around ``xy_fix``.

        The plotted quantity is the product of the two per-axis GMM outputs,
        capped at one (the same cap applied in :meth:`get_efficiency`).

        Arguments:
            bpixelate: if True draw square scatter markers, otherwise filled contours.
            bsavefig: if True save the figure and close it, otherwise show it.
            filename: output path used when ``bsavefig`` is True;
                defaults to ``"heatmap_plot.png"``.
        """

        if not isinstance(self.xy_fix, Tensor):
            raise ValueError(f"{self.xy_fix} is not a Tensor for some reason.")  # To appease MyPy
        if not isinstance(self.xy_span_fix, Tensor):
            raise ValueError(f"{self.xy_span_fix} is not a Tensor for some reason.")  # To appease MyPy

        def get_z_from_mesh(x: Tensor, y: Tensor) -> Tensor:
            # Stack on the last axis so that z[i, j] is the value at (x[i, j], y[i, j]),
            # which is the layout matplotlib expects. Reversing all axes (``.T``) would
            # mirror the map about the diagonal.
            coords = torch.stack([x, y], dim=-1).reshape(-1, 2)
            z = self.gmm(coords).prod(1)
            z = torch.clamp(z, max=1.0)  # the result must be assigned for the cap to take effect
            return z.reshape(x.shape)

        with sns.axes_style(style="whitegrid", rc={"patch.edgecolor": "none"}):
            x0 = float(self.xy_fix[0])
            y0 = float(self.xy_fix[1])
            half_width = 2 * self.delta_xy
            # Build the grid on the module's device so the GMM can be evaluated on it.
            xs = torch.linspace(x0 - half_width, x0 + half_width, steps=200, device=self.device)
            ys = torch.linspace(y0 - half_width, y0 + half_width, steps=200, device=self.device)
            grid_x, grid_y = torch.meshgrid(xs, ys)
            z = get_z_from_mesh(grid_x, grid_y).detach().cpu().numpy()
            x = grid_x.cpu().numpy()
            y = grid_y.cpu().numpy()

            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            if bpixelate:
                cs = ax.scatter(x, y, c=z, cmap="plasma", s=450.0, marker="s")
            else:
                cs = ax.contourf(x, y, z, cmap="plasma", vmin=0.0, vmax=0.9)
            ax.set_aspect("equal")
            fig.colorbar(cs)

            if bsavefig:
                if filename is None:
                    filename = "heatmap_plot.png"
                plt.savefig(filename, dpi=300)
                plt.close()
            else:
                plt.show()

    def clamp_params(self, musigz_low: Tuple[float, float, float], musigz_high: Tuple[float, float, float]) -> None:
        """Clamp the learnable parameters in place, without tracking gradients.

        Arguments:
            musigz_low: lower bounds for (mu, sig, z). The sig bound is divided by ``range_mult``.
            musigz_high: upper bounds for (mu, sig, z). The sig bound is multiplied by ``range_mult``.

        ``norm`` is always clamped to [0.01, 1.5].
        """

        with torch.no_grad():
            eps = np.random.uniform(0, 1e-3)  # prevent hits at same z due to clamping
            self.mu.clamp_(min=musigz_low[0], max=musigz_high[0])
            self.z.clamp_(min=musigz_low[2] + eps, max=musigz_high[2] - eps)
            self.sig.clamp_(min=musigz_low[1] / self.range_mult, max=self.range_mult * musigz_high[1])
            self.norm.clamp_(min=0.01, max=1.5)

    @property
    def x(self) -> Tensor:
        """Fixed initial x position of the detector (first element of ``xy_fix``)."""

        if not isinstance(self.xy_fix, Tensor):
            raise ValueError(f"{self.xy_fix} is not a Tensor for some reason.")  # To appease MyPy
        return self.xy_fix[0]

    @property
    def y(self) -> Tensor:
        """Fixed initial y position of the detector (second element of ``xy_fix``)."""

        if not isinstance(self.xy_fix, Tensor):
            raise ValueError(f"{self.xy_fix} is not a Tensor for some reason.")  # I just love MyPy
        return self.xy_fix[1]
class GMM(nn.Module):
    """Equal-weight 2D Gaussian mixture with learnable means, widths and overall norm.

    Component means are drawn uniformly in a square of side ``init_xy_span``
    centred on ``init_xy``. Component widths are drawn uniformly, floored at 0.2,
    and scaled by ``init_xy_span``.

    Arguments:
        n_cluster: number of Gaussian components.
        init_xy: centre of the region in which the means are initialised.
        init_xy_span: side length of that region; also scales the initial widths.
        init_norm: initial value of the overall normalisation parameter.
        device: device on which all tensors are created.
    """

    def __init__(
        self,
        n_cluster: int = 20,
        init_xy: Tuple[float, float] = (0.0, 0.0),
        init_xy_span: float = 10.0,
        init_norm: float = 1.0,
        device: torch.device = DEVICE,
    ) -> None:
        super().__init__()
        self.n_cluster = n_cluster
        self.device = device
        self._init_xy = torch.tensor(init_xy, device=self.device)
        self._init_xy_span = torch.tensor(init_xy_span, device=self.device)

        rand_mu = self._init_xy_span * (0.5 - torch.rand(self.n_cluster, 2, device=self.device))
        self.mu = nn.Parameter(rand_mu + self._init_xy)
        # Floor the random widths at 0.2 so that no component starts out degenerate.
        rand_sig = torch.rand(self.n_cluster, 2, device=self.device).clamp(min=0.2)
        self.sig = nn.Parameter(self._init_xy_span * rand_sig)
        self.norm = nn.Parameter(torch.tensor([float(init_norm)], device=self.device))

        # The mixture weights must live on the same device as the components,
        # otherwise log_prob combines tensors from different devices.
        mix = torch.distributions.Categorical(torch.ones(self.n_cluster, device=self.device))
        comp = torch.distributions.Independent(torch.distributions.Normal(self.mu, self.sig), 1)
        self.gmm = torch.distributions.MixtureSameFamily(mix, comp)

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the normalised mixture at the supplied positions.

        The mixture density is divided by its largest value over the component
        means and multiplied by ``norm``.  The square root is returned in two
        identical columns, so the product over the last axis recovers the
        scaled density.

        Arguments:
            x: positions to evaluate, with a trailing axis of size two.

        Returns:
            Tensor with one row per position and two identical columns.
        """

        log_peak = torch.max(self.gmm.log_prob(self.mu))
        res = self.norm * torch.exp(self.gmm.log_prob(x) - log_peak)
        res = res.reshape(res.shape[0], 1)
        res = res.expand(res.shape[0], 2)
        return torch.sqrt(res)

Docs

Access comprehensive developer and user documentation for TomOpt

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of TomOpt

View Tutorials