Source code for ptyrax.models.detector

from __future__ import annotations

from abc import abstractmethod
from typing import Callable, Dict, Literal, Optional, Union

import equinox as eqx
import gin
import jax.numpy as jnp
import numpy as np
from jax.scipy.spatial.transform import Rotation as JaxRotation
from jaxtyping import Array, Float, PyTree
from tensorboardX import SummaryWriter

from ptyrax.field import CoherentField
from ptyrax.logger import log_image
from ptyrax.models.illumination import NamedLoss
from ptyrax.models.propagation import (
    source_fourier_occupancy_from_tilt_interpolation,
    source_fourier_support_from_tilt_interpolation,
)
from ptyrax.parametrizations import IndexSliceParameter, resolve_array_parametrizations
from ptyrax.spatial import CoordinateSystem, Rotation, SamplingGrid
from ptyrax.state_geometry import state_get_with_candidates


@gin.configurable()
class Detector(eqx.Module):
    """Abstract detector interface mapping input fields to measured detector counts.

    Attributes:
        coordinates: Index-sliceable detector coordinate system (rotation and
            translation).
        sampling: Detector sampling grid; declared as a static field.
    """

    coordinates: IndexSliceParameter[CoordinateSystem]
    sampling: SamplingGrid = eqx.field(static=True)

    @abstractmethod
    def __call__(self, input_field: CoherentField, index: Optional[int] = None) -> Float[Array, "* m n"]:
        """Map an input field to detector counts.

        Args:
            input_field: Field arriving at the detector.
            index: Optional position index. Defaults to ``None`` because the
                concrete detectors in this module are called with the field
                argument only.

        Returns:
            The predicted detector readout.
        """

    @classmethod
    def load_from_hdf5_state(
        cls,
        state: Dict[str, np.ndarray],
        *,
        detector_path_prefix: str = "detector",
    ) -> "Detector":
        r"""Construct a concrete detector instance from an HDF5 state dictionary.

        Extracts rotation, translation, pixel size, and dark counts from the
        HDF5 state and builds the detector coordinate system and sampling grid.
        The dark-count array is required because its trailing two dimensions
        define the detector shape; it is only forwarded to the constructor when
        the subclass annotates a ``dark_counts`` field.

        Args:
            state: Dictionary mapping HDF5 dataset paths to NumPy arrays, as
                returned by reading an HDF5 checkpoint file.
            detector_path_prefix: Path prefix within the state dictionary under
                which detector parameters are stored.

        Returns:
            An instance of the concrete detector subclass populated from the
            state dictionary.

        Raises:
            TypeError: If called directly on the abstract
                :py:class:`~ptyrax.models.detector.Detector` base class, or if
                the subclass constructor rejects the assembled arguments.
            KeyError: If the required ``dark_counts`` key is missing from the
                state dictionary.
            ValueError: If the detector pixel size array is empty.

        Example:
            >>> state = load_hdf5_to_dict("reconstruction.hdf5")
            >>> detector = NoiselessEqualWeightDetector.load_from_hdf5_state(state, detector_path_prefix="detector")
        """
        if cls is Detector:
            raise TypeError("Call `load_from_hdf5_state` on a concrete Detector subclass.")

        prefix = detector_path_prefix.strip("/")
        pref = f"{prefix}/" if prefix else ""

        # Each quantity may live under one of several candidate dataset paths.
        rotation_6d = np.asarray(
            state_get_with_candidates(
                state,
                [
                    f"{pref}coordinates/parameters/rotation/_representation_6d",
                    f"{pref}coordinates/rotation/_representation_6d",
                ],
            )
        )
        translation = np.asarray(
            state_get_with_candidates(
                state,
                [
                    f"{pref}coordinates/parameters/_translation",
                    f"{pref}coordinates/_translation",
                    f"{pref}coordinates/translation",
                ],
            )
        )
        detector_pixel_size = np.asarray(
            state_get_with_candidates(
                state,
                [f"{pref}sampling/pixel_size"],
            )
        ).reshape(-1)

        # A single pixel size is used for both detector axes.
        if detector_pixel_size.size == 1:
            detector_pixel_size = np.repeat(detector_pixel_size, 2)
        if detector_pixel_size.size < 2:
            raise ValueError("Detector pixel size must have at least one value.")

        # Promote un-batched poses to a leading index axis of length one.
        if rotation_6d.ndim == 1:
            rotation_6d = rotation_6d[None, :]
        if translation.ndim == 1:
            translation = translation[None, :]

        dark_counts_key = f"{pref}dark_counts"
        if dark_counts_key not in state:
            raise KeyError(f"Missing required detector shape source '{pref}dark_counts' in state.")
        dark_counts = np.asarray(state[dark_counts_key])
        detector_shape = tuple(dark_counts.shape[-2:])

        coordinates = IndexSliceParameter(
            CoordinateSystem(
                rotation=Rotation(jnp.asarray(rotation_6d)),
                translation=jnp.asarray(translation),
            )
        )
        sampling = SamplingGrid.from_tuples(detector_shape, tuple(detector_pixel_size[:2]))

        kwargs = {
            "coordinates": coordinates,
            "sampling": sampling,
        }
        if "dark_counts" in getattr(cls, "__annotations__", {}):
            kwargs["dark_counts"] = jnp.asarray(dark_counts)

        try:
            return cls(**kwargs)
        except TypeError:
            # Retry without dark counts only when there is something to drop;
            # otherwise the constructor failure is genuine and must propagate.
            if "dark_counts" not in kwargs:
                raise
            kwargs.pop("dark_counts")
            return cls(**kwargs)

    @classmethod
    def from_hdf5_state(
        cls,
        state: Dict[str, np.ndarray],
        *,
        detector_path_prefix: str = "detector",
    ) -> "Detector":
        """Alias for :py:meth:`load_from_hdf5_state`.

        Args:
            state: Dictionary mapping HDF5 dataset paths to NumPy arrays.
            detector_path_prefix: Path prefix for detector parameters in the
                state dictionary.

        Returns:
            An instance of the concrete detector subclass.
        """
        return cls.load_from_hdf5_state(state, detector_path_prefix=detector_path_prefix)

    def _resolve_target_coordinate_system(self, index: int = 0):
        """Resolve the target coordinate system for a given index.

        Prefers ``at(index)`` when the resolved coordinates expose it, falls
        back to ``at_current_index()``, and otherwise returns the resolved
        coordinates unchanged.
        """
        coords = resolve_array_parametrizations(self.coordinates)
        if hasattr(coords, "at"):
            return coords.at(index)
        if hasattr(coords, "at_current_index"):
            return coords.at_current_index()
        return coords

    def fourier_mask(
        self,
        field: CoherentField,
        *,
        index: int = 0,
        return_details: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
        r"""Compute the Fourier-space occupancy mask for this detector.

        Determines which pixels in the source field's Fourier representation
        map onto the detector, given the relative tilt between the field
        coordinate system and the detector coordinate system.

        Args:
            field: The coherent source field whose Fourier occupancy is
                evaluated against the detector geometry.
            index: Position index into the detector coordinate parameter, used
                when the detector has multiple indexed orientations.
            return_details: If ``True``, return the full tuple of intermediate
                arrays instead of just the source occupancy.

        Returns:
            If ``return_details`` is ``False``, a 2-D array of the source
            Fourier occupancy (fraction of each source pixel captured by the
            detector). If ``True``, a tuple of
            ``(source_occupancy, detector_mask, field_frame_target_indices)``.
        """
        target_coordinate_system = self._resolve_target_coordinate_system(index)
        source_occupancy, detector_mask, field_frame_target_indices = source_fourier_occupancy_from_tilt_interpolation(
            field,
            target_coordinate_system,
            self.sampling,
        )
        if return_details:
            return source_occupancy, detector_mask, field_frame_target_indices
        return source_occupancy

    def fourier_support_mask(
        self,
        field: CoherentField,
        *,
        index: int = 0,
        return_details: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        r"""Compute the Fourier-space binary support mask for this detector.

        Similar to :py:meth:`fourier_mask`, but returns a binary support
        indicating which source Fourier pixels have *any* overlap with the
        detector, rather than a fractional occupancy.

        Args:
            field: The coherent source field whose Fourier support is evaluated
                against the detector geometry.
            index: Position index into the detector coordinate parameter.
            return_details: If ``True``, return the full tuple of intermediate
                arrays.

        Returns:
            If ``return_details`` is ``False``, a 2-D binary array of the
            source Fourier support. If ``True``, a tuple of
            ``(source_support, source_occupancy, detector_mask, field_frame_target_indices)``.
        """
        target_coordinate_system = self._resolve_target_coordinate_system(index)
        source_support, source_occupancy, detector_mask, field_frame_target_indices = (
            source_fourier_support_from_tilt_interpolation(
                field,
                target_coordinate_system,
                self.sampling,
            )
        )
        if return_details:
            return source_support, source_occupancy, detector_mask, field_frame_target_indices
        return source_support

    def __log_epoch__(
        self,
        writer: SummaryWriter,
        epoch: int,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Log detector pose parameters to TensorBoard.

        Writes the detector Euler angles (xyz convention, in degrees) and
        translation components of the index-0 pose as scalar summaries under
        the ``3_detector/`` tag group.

        Args:
            writer: TensorBoard summary writer instance.
            epoch: Current training epoch number.
            prefix: Optional prefix for TensorBoard tags (unused in this base
                implementation but accepted for interface consistency).
            **kwargs: Additional keyword arguments (ignored).
        """
        # Go through the shared resolver so coordinate objects without an
        # ``at`` method are handled the same way as in the Fourier-mask methods.
        coordinate = self._resolve_target_coordinate_system(0)
        position = coordinate.translation
        rotation_matrix = np.array(coordinate.rotation.as_matrix())
        euler_angles = JaxRotation.from_matrix(rotation_matrix).as_euler("xyz", degrees=True)

        for axis_index, axis in enumerate("xyz"):
            writer.add_scalar(f"3_detector/rotation/{axis}", euler_angles[axis_index], epoch)
        for axis_index, axis in enumerate("xyz"):
            writer.add_scalar(f"3_detector/position/{axis}", position[axis_index], epoch)
        writer.add_scalar(
            "3_detector/position/distance",
            np.linalg.norm(position),
            epoch,
        )
@gin.configurable()
class NoiselessEqualWeightDetector(Detector):
    r"""Ideal noiseless detector with equal weighting across coherent modes.

    Models a detector that sums the intensity contributions from all coherent
    modes with equal weight and no noise or background. The output can be
    returned as either amplitude (square root of summed intensity) or raw
    intensity, controlled by the ``mode`` parameter.

    The forward model is:

    .. math::

        I_{\text{det}} = \sum_k |\psi_k|^2

    where :math:`\psi_k` are the coherent mode fields at the detector plane.

    Attributes:
        coordinates: Detector coordinate system (position and orientation).
        sampling: Detector pixel grid specification.
        mode: Output mode — ``"amplitude"`` returns
            :math:`\sqrt{I_{\text{det}}}`, ``"intensity"`` returns
            :math:`I_{\text{det}}` directly.

    Example:
        >>> detector = NoiselessEqualWeightDetector(
        ...     coordinates=detector_coords,
        ...     sampling=detector_sampling,
        ...     mode="amplitude",
        ... )
        >>> predicted = detector(exit_fields)
    """

    coordinates: IndexSliceParameter[CoordinateSystem]
    sampling: SamplingGrid = eqx.field(static=True)
    mode: Union[Literal["amplitude"], Literal["intensity"]] = eqx.field(static=True, default="amplitude")

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the noiseless equal-weight detector.

        Args:
            *args: Positional arguments forwarded to
                :py:class:`~ptyrax.models.detector.Detector`.
            **kwargs: Keyword arguments forwarded to the parent class. The
                ``mode`` keyword is consumed here and not forwarded.

        Raises:
            ValueError: If ``mode`` is neither ``"amplitude"`` nor
                ``"intensity"``.
        """
        mode = kwargs.pop("mode", "amplitude")
        # Reject unknown modes up front: ``__call__`` treats anything that is
        # not "amplitude" as intensity, so a typo would otherwise silently
        # change the forward model.
        if mode not in ("amplitude", "intensity"):
            raise ValueError(f"Unknown detector mode {mode!r}; expected 'amplitude' or 'intensity'.")
        self.mode = mode
        super().__init__(*args, **kwargs)

    def __call__(self, input_fields: PyTree[CoherentField]) -> Float[Array, "* m n"]:
        r"""Compute the noiseless detector measurement from input fields.

        Sums intensity contributions from all coherent modes and returns either
        the amplitude or intensity depending on the detector mode.

        Args:
            input_fields: PyTree of coherent fields at the detector plane, with
                an ``.intensity`` attribute of shape ``(n_modes, m, n, 1)``.

        Returns:
            A 2-D array of shape ``(m, n)`` containing the predicted detector
            readout (amplitude or intensity).
        """
        # Drop the trailing singleton axis, then sum over the leading mode axis.
        coherent_counts = input_fields.intensity[..., 0]
        incoherent_counts = jnp.sum(coherent_counts, axis=0)
        if self.mode == "amplitude":
            return jnp.sqrt(incoherent_counts)
        return incoherent_counts
@gin.configurable()
class BackgroundEqualWeightDetector(Detector):
    r"""Detector model with learnable dark-count background.

    Extends the equal-weight detector by adding a trainable 2-D ``dark_counts``
    array that models detector background (e.g. thermal noise, stray light).
    The forward model averages amplitudes across coherent modes and adds the
    dark counts:

    .. math::

        I_{\text{det}} = \frac{1}{K}\sum_k |\psi_k| + D

    where :math:`D` is the dark-count background and :math:`K` is the number of
    coherent modes.

    Regularization functions can be attached to penalize the dark-count map
    during optimization (e.g. smoothness or sparsity priors).

    Attributes:
        coordinates: Detector coordinate system.
        sampling: Detector pixel grid specification.
        dark_counts: Trainable 2-D background array of shape ``(m, n)``.
        dynamic_range: Tuple ``(min, max)`` of the detector dynamic range.
        scale: Multiplicative scale factor (static). It is stored on the
            instance; ``__call__`` of this class does not apply it.
        regularization_functions: Tuple of callables applied to
            ``dark_counts`` to compute regularization losses.

    Example:
        >>> detector = BackgroundEqualWeightDetector(
        ...     coordinates=detector_coords,
        ...     sampling=detector_sampling,
        ...     dark_counts=jnp.zeros((256, 256)),
        ... )
        >>> predicted = detector(exit_fields)
    """

    coordinates: CoordinateSystem
    # NOTE(review): the base class declares `sampling` with static=True; this
    # redeclaration does not. Left unchanged — confirm whether that is intended.
    sampling: SamplingGrid
    dark_counts: Array
    dynamic_range: tuple[float, float] = eqx.field(static=True)
    scale: float = eqx.field(static=True)
    regularization_functions: tuple[Callable[[Float[Array, "m n"]], float], ...] = eqx.field(static=True)

    def __init__(
        self,
        *args,
        dynamic_range: Optional[tuple[float, float]] = None,
        dark_counts: Optional[Float[Array, "m n"]] = None,
        scale: float = 1.0,
        regularization_functions: tuple[Callable[[Float[Array, "m n"]], float], ...] = (),
        **kwargs,
    ) -> None:
        """Initialize the background equal-weight detector.

        Args:
            *args: Positional arguments forwarded to
                :py:class:`~ptyrax.models.detector.Detector`.
            dynamic_range: Tuple of ``(min, max)`` detector counts. Defaults to
                ``(0.0, 2**16)`` if not provided.
            dark_counts: Initial 2-D dark-count background array. If ``None``,
                initialized to zeros of shape
                ``(sampling.n_x, sampling.n_y)``.
            scale: Static multiplicative scale factor stored on the detector.
            regularization_functions: Sequence of callables that accept the
                dark-count array and return a scalar loss value. Stored as a
                tuple.
            **kwargs: Additional keyword arguments forwarded to the parent
                class.
        """
        super().__init__(*args, **kwargs)

        if dark_counts is None:
            dark_counts = jnp.zeros((self.sampling.n_x, self.sampling.n_y))
        self.dark_counts = dark_counts

        # Static fields must be hashable; coerce list inputs (e.g. from a
        # config file) to tuples.
        self.dynamic_range = (0.0, 2**16) if dynamic_range is None else tuple(dynamic_range)
        self.regularization_functions = (
            () if regularization_functions is None else tuple(regularization_functions)
        )
        self.scale = scale

    def __call__(self, input_fields: PyTree[CoherentField]) -> Float[Array, "* m n"]:
        """Compute the detector measurement including dark-count background.

        Averages the amplitude across coherent modes and adds the learned
        dark-count background.

        Args:
            input_fields: PyTree of coherent fields at the detector plane, with
                an ``.amplitude`` attribute of shape ``(n_modes, m, n, 1)``.

        Returns:
            A 2-D array of shape ``(m, n)`` containing the predicted detector
            readout (mean amplitude plus dark counts).
        """
        # Drop the trailing singleton axis, then average over the mode axis.
        coherent_counts = input_fields.amplitude[..., 0]
        incoherent_counts = jnp.mean(coherent_counts, axis=0)
        return incoherent_counts + self.dark_counts

    def __log_epoch__(
        self,
        writer: SummaryWriter,
        epoch: int,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Log detector pose and dark-count image to TensorBoard.

        Extends the base class logging by additionally writing the dark-count
        background as an image summary with gamma correction.

        Args:
            writer: TensorBoard summary writer instance.
            epoch: Current training epoch number.
            prefix: Optional prefix for TensorBoard tags.
            **kwargs: Additional keyword arguments (ignored).
        """
        super().__log_epoch__(writer, epoch, prefix)
        log_image(writer, "3_detector/dark_counts", self.dark_counts, epoch, gamma=0.33)

    @staticmethod
    def _regularizer_name(fn: Callable) -> str:
        """Return a readable name for a regularizer callable.

        Plain functions expose ``__name__``; ``functools.partial`` objects and
        other callable instances do not, so fall back to the wrapped function's
        name and finally to the callable's type name.
        """
        name = getattr(fn, "__name__", None)
        if name is None:
            name = getattr(getattr(fn, "func", None), "__name__", None)
        if name is None:
            name = type(fn).__name__
        return name

    def __regularize__(self) -> tuple[float, list[NamedLoss]]:
        """Compute regularization losses on the dark-count background.

        Applies each function in ``regularization_functions`` to the
        ``dark_counts`` array and accumulates a total scalar loss.

        Returns:
            A tuple of ``(total_loss, regularizations)`` where ``total_loss``
            is the sum of all regularization terms and ``regularizations`` is a
            list of :py:class:`~ptyrax.models.illumination.NamedLoss` entries.
        """
        regularizations = []
        total = 0.0
        for fn in self.regularization_functions:
            value = fn(self.dark_counts)
            total += value
            regularizations.append(NamedLoss(f"dark_counts.{self._regularizer_name(fn)}", value))
        return total, regularizations