Source code for ptyrax.models.interaction

from __future__ import annotations

from abc import abstractmethod
from typing import Callable, Type, Union

import equinox as eqx
import gin
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from jax.scipy.spatial.transform import Rotation as JaxRotation
from jaxtyping import Array, Complex, Float, Shaped
from matplotlib.figure import Figure
from matplotlib.gridspec import SubplotSpec
from tensorboardX import SummaryWriter

from ptyrax.field import CoherentField
from ptyrax.initializers import uniform
from ptyrax.logger import log_image
from ptyrax.models.illumination import NamedLoss
from ptyrax.parametrizations import (
    ArrayParametrization,
    DirectArrayParametrization,
    IndexSliceParameter,
    resolve_parametrizations,
)
from ptyrax.spatial import CoordinateSystem, SamplingGrid, reflect
from ptyrax.utils import phase_only_exp, plot, tree_slice_first, unstack_tree

# Unit basis vectors of the Cartesian frame, taken as the rows of the 3x3 identity.
X_AXIS, Y_AXIS, Z_AXIS = jnp.eye(3)


class InteractionModel(eqx.Module):
    """Abstract base class for interaction models in ptychographic reconstruction.

    An interaction model describes how an incoming coherent illumination field
    interacts with a sample to produce an outgoing field. Subclasses implement
    specific physical interaction mechanisms such as thin-object Fresnel
    reflection or multi-slice propagation. All interaction models are
    JAX-compatible Equinox modules that can be differentiated through and used
    in JIT-compiled pipelines.

    Attributes:
        coordinates: The coordinate system defining the sample position and
            orientation in the global reference frame.
    """

    coordinates: CoordinateSystem

    @abstractmethod
    def __init__(
        self,
        coordinates: CoordinateSystem,
        sampling: SamplingGrid,
        initializer: Callable[[SamplingGrid], Complex[Array, "m n"]] = None,
    ) -> None:
        pass

    @abstractmethod
    def __call__(self, input_field: CoherentField) -> CoherentField:
        """Compute the interaction of an illumination field with the sample.

        Args:
            input_field: The incoming coherent illumination field incident on the sample.

        Returns:
            The outgoing coherent field after interaction with the sample,
            including updated propagation direction and coordinate system.
        """
        pass

    def __log_epoch__(
        self,
        writer: SummaryWriter,
        epoch: int,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Log interaction model state to TensorBoard at the end of an epoch.

        Logs sample positions, mean rotation Euler angles, and distance to the
        first scan position.

        Args:
            writer: TensorBoard summary writer instance.
            epoch: Current training epoch number.
            prefix: String inserted after ``1_sample/`` in every logged
                scalar/figure name, so that several models (e.g. the slices of a
                multi-slice model) do not write to the same tags. With the
                default ``""`` the tag names are ``1_sample/interaction_positions``,
                ``1_sample/rotation/{x,y,z}`` and ``1_sample/position/distance``.
            **kwargs: Additional keyword arguments passed by the training loop (unused).
        """
        translation_internal = resolve_parametrizations(self.coordinates.all.translation_internal)

        # Scatter plot of the first two translation components of every scan position.
        fig, ax = plt.subplots(1, 1)
        ax.plot(*translation_internal[:, :2].T, ".")
        writer.add_figure(f"1_sample/{prefix}interaction_positions", fig, epoch)

        # Element-wise mean of the per-position rotation matrices, converted to Euler
        # angles. NOTE(review): the mean is generally not exactly orthonormal; how
        # JaxRotation.from_matrix treats that is left to JAX.
        mean_rotation = np.mean(np.array(self.coordinates.all.rotation.as_matrix()), axis=0)
        euler_angles = JaxRotation.from_matrix(mean_rotation).as_euler("xyz", degrees=True)
        for axis_index, axis_name in enumerate(("x", "y", "z")):
            writer.add_scalar(f"1_sample/{prefix}rotation/{axis_name}", euler_angles[axis_index], epoch)

        # Euclidean norm of the first scan position's internal translation.
        writer.add_scalar(
            f"1_sample/{prefix}position/distance",
            np.linalg.norm(translation_internal[0]),
            epoch,
        )

    @abstractmethod
    def __regularize__(self) -> tuple[float, list[NamedLoss]]:
        """Compute regularization losses for this interaction model.

        Returns:
            A tuple of (total_loss, named_losses) where total_loss is the scalar
            sum of all regularization terms, and named_losses is a list of
            :py:class:`~ptyrax.models.illumination.NamedLoss` entries identifying
            each contribution.
        """
        pass
@gin.configurable()
class FresnelReflection(InteractionModel):
    r"""Thin-object Fresnel reflection interaction model.

    Models the interaction of the illumination beam with a thin planar sample
    using a multiplicative approximation. The exit field is computed as:

    $$
    \psi_{\text{exit}}(\mathbf{r}) = O(\mathbf{r}) \cdot P(\mathbf{r})
    $$

    where $O$ is the complex reflection coefficient of the sample and $P$ is the
    probe illumination. The reflected field propagation direction is computed by
    reflecting the incident direction about the sample surface normal.

    Attributes:
        coordinates: Sample coordinate system with position and rotation for each
            scan point, wrapped as an
            :py:class:`~ptyrax.parametrizations.IndexSliceParameter`.
        surface_normal: Unit vector normal to the sample surface in the sample's
            local coordinate frame.
        reflection_coefficient: Complex-valued 2D array (or an array
            parametrization wrapping one) representing the sample's
            spatially-resolved reflection coefficient.
        sampling: Pixel grid defining the sample discretization.
        forward_sampling: Optional separate grid for the forward model
            computation. If ``None``, uses the illumination field's sampling.
        regularization_functions: Tuple of callables that compute scalar
            regularization penalties from the reflection coefficient.
        normalize: Default for whether :py:meth:`__call__` multiplies the
            interpolated object by the square root of the number of sample pixels.
    """

    coordinates: IndexSliceParameter[CoordinateSystem]
    surface_normal: Float[Array, "3"]
    reflection_coefficient: Shaped[Array, "* m n"]
    sampling: SamplingGrid
    forward_sampling: SamplingGrid
    regularization_functions: tuple[Callable[[Complex[Array, "* m n"]], float], ...] = eqx.field(static=True)
    normalize: bool = eqx.field(static=True, default=True)

    @property
    def shape(self) -> tuple[int, int]:
        """Spatial shape (m, n) of the sample reflection coefficient array."""
        return self.sampling.shape

    def __init__(
        self,
        coordinates: IndexSliceParameter[CoordinateSystem],
        sampling: SamplingGrid,
        forward_sampling: SamplingGrid = None,
        initializer: Callable[[SamplingGrid], Complex[Array, "m n"]] = uniform,
        parametrization_type: Type[ArrayParametrization] = None,
        regularization_functions: tuple[Callable[[Complex[Array, "* m n"]], float], ...] = (),
        normalize: bool = True,
    ) -> None:
        """Initialize a thin-object Fresnel reflection model.

        Args:
            coordinates: Coordinate system for the sample, defining translation
                and rotation at each scan position. Wrapped in an
                :py:class:`~ptyrax.parametrizations.IndexSliceParameter` if it is
                not one already.
            sampling: Sampling grid defining the pixel layout of the sample.
            forward_sampling: Optional separate sampling grid for the forward
                model. If ``None``, the illumination field's sampling is used.
            initializer: Callable that takes a
                :py:class:`~ptyrax.spatial.SamplingGrid` and returns an initial
                complex reflection coefficient array.
            parametrization_type: Optional array parametrization class to wrap the
                reflection coefficient (e.g., for enforcing constraints). If
                ``None``, the initializer output is stored as a plain array.
            regularization_functions: Tuple of functions computing scalar
                regularization penalties from the reflection coefficient.
            normalize: Default used by :py:meth:`__call__` when its own
                ``normalize`` argument is not given. If ``True``, the
                interpolated object is multiplied by ``sqrt(m * n)`` to maintain
                energy normalization.
        """
        if not isinstance(coordinates, IndexSliceParameter):
            self.coordinates = IndexSliceParameter(coordinates)
        else:
            self.coordinates = coordinates
        self.surface_normal = jnp.array([0.0, 0.0, 1.0])
        self.sampling = sampling
        self.forward_sampling = forward_sampling
        self.reflection_coefficient = (
            parametrization_type(initializer(self.sampling))
            if parametrization_type is not None
            else initializer(self.sampling)
        )
        self.regularization_functions = regularization_functions
        self.normalize = normalize

    def __call__(self, illumination_field: CoherentField, normalize: bool | None = None) -> CoherentField:
        """Compute the thin-object multiplicative interaction.

        Interpolates the sample reflection coefficient onto the forward grid,
        multiplies it element-wise with the probe, and constructs the reflected
        output field with updated propagation direction.

        Args:
            illumination_field: Incoming coherent illumination field.
            normalize: Whether to multiply the interpolated object by
                ``sqrt(m * n)`` (the sample pixel count). If the original array is
                normalized, this keeps the output approximately normalized after
                interpolation. ``None`` (the default) uses the ``normalize`` value
                given at construction time (``True`` unless overridden).

        Returns:
            The reflected coherent field after sample interaction.
        """
        # NOTE(review): function-scope import, presumably to avoid an import cycle — confirm.
        from ptyrax.spatial import interpolate_grid_to_grid

        if normalize is None:
            normalize = self.normalize

        sample_coordinates = self.coordinates.at_current_index()

        # Without a dedicated forward grid the computation happens directly on the
        # illumination grid (same pixel size), and the probe needs no resampling.
        uses_illumination_grid = self.forward_sampling is None
        forward_sampling = illumination_field.sampling if uses_illumination_grid else self.forward_sampling

        shifted_object = interpolate_grid_to_grid(
            self.reflection_coefficient,
            self.sampling,
            sample_coordinates,
            forward_sampling,
            illumination_field.coordinate_system,
            equal_pixel_size=uses_illumination_grid,
        )

        # Drop the trailing vector axis of the field data.
        probe_data = illumination_field()[..., 0]
        if not uses_illumination_grid:
            probe_data = interpolate_grid_to_grid(
                probe_data,
                illumination_field.sampling,
                illumination_field.coordinate_system,
                forward_sampling,
                illumination_field.coordinate_system,
            )

        if normalize:
            m0, n0 = self.sampling.shape[:2]
            shifted_object = shifted_object * jnp.sqrt(m0 * n0)

        new_data = probe_data * shifted_object

        # Apply the transposed sample rotation to the normalized local surface
        # normal and reflect the incident direction about the result.
        surface_normal = self.surface_normal / jnp.linalg.norm(self.surface_normal)
        sample_z_axis = sample_coordinates.rotation.as_matrix().T @ surface_normal
        new_propagation_direction = reflect(illumination_field.propagation_direction, sample_z_axis)

        new_coordinate_system = CoordinateSystem(
            rotation=sample_coordinates.rotation,
            translation=illumination_field.coordinate_system.translation,
        )
        return CoherentField(
            new_data[..., jnp.newaxis],
            illumination_field.wavelength,
            forward_sampling,
            new_coordinate_system,
            new_propagation_direction,
            illumination_field.spatial_dims,
            illumination_field.vector_dim,
        )

    def __regularize__(self) -> tuple[float, list[NamedLoss]]:
        """Compute regularization losses for the reflection coefficient.

        Evaluates each registered regularization function on the current
        reflection coefficient and accumulates the results.

        Returns:
            A tuple of (total_loss, named_losses) where total_loss is the scalar
            sum of all regularization terms, and named_losses is a list of
            :py:class:`~ptyrax.models.illumination.NamedLoss` entries.
        """
        coefficient = self.reflection_coefficient
        # A parametrization is called to obtain the array; a plain array (e.g. when
        # no parametrization_type was given) is used as-is.
        if callable(coefficient):
            coefficient = coefficient()

        regularizations = []
        total = 0.0
        for fn in self.regularization_functions:
            value = fn(coefficient)
            total += value
            regularizations.append(NamedLoss(f"reflection_coefficient.{fn.__name__}", value))
        return total, regularizations

    def __plot__(self, *args, **kwargs) -> tuple[Figure, SubplotSpec, Union[Array, list[Array]]]:
        """Plot the reflection coefficient as a complex-valued image.

        Args:
            *args: Positional arguments forwarded to :py:func:`~ptyrax.utils.plot`.
            **kwargs: Keyword arguments forwarded to :py:func:`~ptyrax.utils.plot`.

        Returns:
            A tuple of (figure, subplot_spec, plotted_array) as produced by the
            plotting utility.
        """
        return plot(self.reflection_coefficient, *args, extent=self.sampling.extent, **kwargs)

    def __log_epoch__(
        self,
        writer: SummaryWriter,
        epoch: int,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Log reflection coefficient images and diagnostics to TensorBoard.

        Logs a real-space image (with gamma correction), the FFT magnitude
        (log-scaled), the norm of the reflection coefficient, and the surface
        normal components. The images are made from a copy resized to 512x512.

        Args:
            writer: TensorBoard summary writer instance.
            epoch: Current training epoch number.
            prefix: String prefix for all logged entries.
            **kwargs: Additional keyword arguments passed by the training loop (unused).
        """
        super(FresnelReflection, self).__log_epoch__(writer, epoch, prefix=prefix)

        coefficient = self.reflection_coefficient
        # Norm of the actual coefficient, not of the resampled display copy, whose
        # norm depends on the resize ratio.
        coefficient_norm = jnp.linalg.norm(coefficient)
        display_coefficient = jax.image.resize(coefficient, (512, 512), method="bilinear")

        log_image(
            writer,
            f"0/1_sample/{prefix}image",
            display_coefficient,
            epoch,
            gamma=0.5,
            title="sample",
            extent=self.sampling.extent,
        )
        log_image(
            writer,
            f"0/1_sample/{prefix}fft_image",
            jnp.abs(jnp.fft.fftshift(jnp.fft.fft2(display_coefficient))),
            epoch,
            log10=True,
            title="sample fft",
            extent=self.sampling.to_far_field(1.0, 1.0).extent,
        )
        writer.add_scalar(f"1_sample/{prefix}coefficient_norm", coefficient_norm, epoch)
        for axis, name in enumerate(("x", "y", "z")):
            writer.add_scalar(
                f"1_sample/{prefix}surface_normal/surface_normal_{name}",
                self.surface_normal[axis],
                epoch,
            )
@gin.configurable()
class MultiSlice(InteractionModel):
    r"""Multi-slice interaction model for thick samples.

    Divides the sample into multiple axial slices, each modeled as an independent
    :py:class:`~ptyrax.models.interaction.FresnelReflection` interaction. The
    illumination is near-field propagated to each slice depth, interacted, and
    propagated back. The total exit field is the coherent sum of contributions
    from all slices:

    .. math::

        \psi_{\text{exit}} = \sum_{j=1}^{N} \mathcal{P}(z_j)\left[ O_j \cdot \mathcal{P}(z_j)[P] \right]

    where $\mathcal{P}(z)$ denotes near-field propagation by distance $z$ and
    $O_j$ is the reflection coefficient of the $j$-th slice.

    Attributes:
        coordinates: Sample coordinate system shared across all slices.
        inner_interactions: Vmapped stack of per-slice interaction models (each
            array leaf carries a leading slice axis).
        slice_distances: Array of axial distances from the reference plane to each slice.
        separable_in_z: If ``True``, uses a single shared interaction model (the
            first slice) for all slices, only varying the propagation distance.
        inverted_bottom: If ``True``, inverts the ordering for the bottom half of
            the slices. NOTE(review): this flag is stored but not read by any
            method of this class.
    """

    coordinates: IndexSliceParameter[CoordinateSystem]
    inner_interactions: FresnelReflection
    slice_distances: Float[Array, " n_slices"]
    # TODO remove this parameter, it is a bit of a hack for a specific use case and we can refactor the forward logic
    # to be more composable rather than having this separate path.
    separable_in_z: bool = eqx.field(static=True, default=False)
    inverted_bottom: bool = eqx.field(static=True, default=False)

    def __init__(
        self,
        coordinates: CoordinateSystem,
        sampling: SamplingGrid,
        slice_distances: Float[Array, " n_slices"],
        initializer: Callable[[SamplingGrid], Complex[Array, "m n"]] = uniform,
        parametrization_type: Type[ArrayParametrization] = DirectArrayParametrization,
        regularization_functions: tuple[Callable[[Complex[Array, "* m n"]], float], ...] = (),
        inner_interactions: tuple[InteractionModel, ...] = None,
        **kwargs,
    ) -> None:
        """Initialize a multi-slice interaction model.

        Either provide pre-built ``inner_interactions`` or specify initialization
        parameters to construct identical slices.

        Args:
            coordinates: Coordinate system for the sample.
            sampling: Sampling grid for the slice reflection coefficients.
            slice_distances: 1D array of axial distances from the reference plane
                to each slice.
            initializer: Callable producing an initial reflection coefficient from
                a sampling grid.
            parametrization_type: Parametrization class for the reflection coefficients.
            regularization_functions: Regularization functions applied to each slice.
            inner_interactions: Optional tuple of pre-built interaction models, one
                per slice. If provided, ``sampling``, ``initializer``,
                ``parametrization_type`` and ``regularization_functions`` are ignored.
            **kwargs: Supports ``separable_in_z`` and ``inverted_bottom``
                overrides; any other key is silently ignored.
        """
        if inner_interactions is not None:
            # Stack every leaf of the per-slice models along a new leading slice axis.
            self.inner_interactions = jax.tree.map(
                lambda *v: jnp.stack(v, axis=0),
                *inner_interactions,
            )
        else:

            def make_single_interaction(index: int) -> FresnelReflection:
                # `index` is unused: every slice is produced by the same call.
                return FresnelReflection(
                    coordinates=coordinates,
                    sampling=sampling,
                    initializer=initializer,
                    parametrization_type=parametrization_type,
                    regularization_functions=regularization_functions,
                )

            self.inner_interactions = vmap(make_single_interaction)(jnp.arange(len(slice_distances)))
        self.slice_distances = jnp.asarray(slice_distances)
        self.coordinates = coordinates
        # Reading self.<flag> here yields the class-level default (False).
        self.separable_in_z = kwargs.get("separable_in_z", self.separable_in_z)
        self.inverted_bottom = kwargs.get("inverted_bottom", self.inverted_bottom)

    @classmethod
    def from_interactions(
        cls, interactions: tuple[InteractionModel, ...], distances: Float[Array, " n_slices"], **kwargs
    ) -> "MultiSlice":
        """Construct a MultiSlice from pre-built per-slice interaction models.

        Args:
            interactions: Tuple of interaction models, one per slice.
            distances: Array of axial slice distances.
            **kwargs: Additional keyword arguments forwarded to the constructor
                (e.g., ``separable_in_z``, ``inverted_bottom``).

        Returns:
            A new :py:class:`~ptyrax.models.interaction.MultiSlice` instance
            wrapping the provided interactions. Coordinates and sampling are taken
            from the first interaction.
        """
        return cls(
            coordinates=interactions[0].coordinates,
            sampling=interactions[0].sampling,
            slice_distances=distances,
            inner_interactions=interactions,
            **kwargs,
        )

    @property
    def n_slices(self) -> int:
        """Number of axial slices in the model."""
        return len(self.slice_distances)

    def __regularize__(self) -> tuple[float, list[NamedLoss]]:
        """Compute regularization losses aggregated over all slices.

        Returns:
            A tuple of (total_loss, named_losses) delegated to the inner vmapped
            interaction stack.
        """
        return self.inner_interactions.__regularize__()

    def _vmap_slices(
        self,
        illumination_field: CoherentField,
        interactions: InteractionModel,
        interaction_axis: int | None,
        normalize: bool,
    ) -> CoherentField:
        """Run :py:meth:`single_slice_forward` for every slice distance under ``eqx.filter_vmap``."""
        slice_forward = self.single_slice_forward

        def forward_one(field: CoherentField, z_displacement: float, interaction: InteractionModel) -> CoherentField:
            return slice_forward(field, z_displacement, interaction, normalize=normalize)

        return eqx.filter_vmap(forward_one, in_axes=(None, 0, interaction_axis))(
            illumination_field,
            self.slice_distances,
            interactions,
        )

    # TODO this method is a bit of a hack for a specific use case. We may want to refactor the forward logic to be more
    # composable rather than having this separate path.
    def separable_forward_fields(self, illumination_field: CoherentField, normalize: bool = True) -> CoherentField:
        """Compute per-slice reflected fields using a single shared interaction.

        Uses only the first slice's interaction model for all depths, varying only
        the propagation distance. This is the forward pass when
        ``separable_in_z=True``.

        Args:
            illumination_field: Incoming coherent illumination field.
            normalize: Forwarded to each per-slice interaction call.

        Returns:
            A vmapped stack of reflected fields, one per slice.
        """
        top_interaction = tree_slice_first(self.inner_interactions)
        return self._vmap_slices(illumination_field, top_interaction, None, normalize)

    def forward_fields(self, illumination_field: CoherentField, normalize: bool = True) -> CoherentField:
        """Compute per-slice reflected fields using independent interactions.

        Each slice uses its own interaction model. This is the forward pass when
        ``separable_in_z=False``.

        Args:
            illumination_field: Incoming coherent illumination field.
            normalize: Forwarded to each per-slice interaction call.

        Returns:
            A vmapped stack of reflected fields, one per slice.
        """
        return self._vmap_slices(illumination_field, self.inner_interactions, 0, normalize)

    @staticmethod
    def single_slice_forward(
        illumination_field: CoherentField,
        z_displacement: float,
        interaction: InteractionModel,
        normalize: bool = True,
    ) -> CoherentField:
        """Forward model for a single slice: propagate, interact, propagate back.

        Args:
            illumination_field: Incoming coherent illumination field.
            z_displacement: Axial distance from the reference plane to this slice.
            interaction: The interaction model for this slice.
            normalize: Passed as ``normalize`` to ``interaction``.

        Returns:
            The back-reflected field after round-trip propagation and interaction
            with the slice.
        """
        propagated_probe = illumination_field.propagate_tilted_nearfield(z_displacement * Z_AXIS)
        post_interaction_field = interaction(propagated_probe, normalize=normalize)
        return post_interaction_field.propagate_tilted_nearfield(z_displacement * Z_AXIS)

    def __call__(self, illumination_field: CoherentField, normalize: bool = True) -> CoherentField:
        """Compute the multi-slice interaction as a coherent sum over slices.

        Propagates the illumination to each slice depth, computes the per-slice
        interaction, propagates back, and sums all contributions coherently.

        Args:
            illumination_field: Incoming coherent illumination field.
            normalize: Whether to apply normalization in the per-slice interactions.

        Returns:
            The total reflected coherent field (coherent sum of all slices).
        """
        if self.separable_in_z:
            backreflected_fields = self.separable_forward_fields(illumination_field, normalize=normalize)
        else:
            backreflected_fields = self.forward_fields(illumination_field, normalize=normalize)

        # Coherent sum over the leading slice axis.
        total_field_data = jnp.sum(backreflected_fields(), axis=0)
        # Reuse the first slice's field (metadata) and swap in the summed data.
        single_field = jax.tree.map(lambda leaf: leaf[0], backreflected_fields)
        return eqx.tree_at(lambda f: f.data, single_field, total_field_data)

    def __log_epoch__(self, writer: SummaryWriter, epoch: int, prefix: str = "", **kwargs) -> None:
        """Log per-slice distances and inner interaction diagnostics.

        Args:
            writer: TensorBoard summary writer instance.
            epoch: Current training epoch number.
            prefix: String prefix for all logged entries.
            **kwargs: Additional keyword arguments forwarded to inner interaction logging.
        """
        for i in range(self.n_slices):
            writer.add_scalar(f"1_sample/{prefix}slice_distance_{i}", self.slice_distances[i], epoch)

        # The spacing between the first two slices only exists with two or more slices.
        if self.n_slices >= 2:
            writer.add_scalar(
                f"1_sample/{prefix}_distance12_abs",
                jnp.abs(self.slice_distances[1] - self.slice_distances[0]),
                epoch,
            )

        for i in range(self.n_slices):
            inner = jax.tree.map(lambda leaf: leaf[i], self.inner_interactions)
            inner.__log_epoch__(writer, epoch, f"{prefix}inner_{i}_", **kwargs)
@gin.configurable
def to_single_slice_approximation(
    multislice: MultiSlice,
    tilt_angle: float = 0.0,
    wavelength: float = 1.0,
    interaction_type: type[InteractionModel] = FresnelReflection,
) -> FresnelReflection:
    r"""Convert a multi-slice model to a single-slice approximation.

    Collapses the multi-slice reflection coefficients into a single effective
    reflection coefficient by coherently summing the per-slice contributions with
    depth-dependent phase factors:

    .. math::

        O_{\text{eff}}(\mathbf{r}) = \sum_j O_j(\mathbf{r})
        \exp\!\left(i \frac{4\pi}{\lambda} z_j \cos\theta\right)

    where $z_j$ is the slice distance and $\theta$ is the tilt angle.

    Args:
        multislice: The multi-slice interaction model to collapse.
        tilt_angle: Beam tilt angle in degrees relative to the surface normal.
            Affects the depth-dependent phase factor.
        wavelength: Illumination wavelength, in the same length unit as the
            multislice's ``slice_distances``.
        interaction_type: Class of the output single-slice interaction model.

    Returns:
        A single-slice interaction model with the coherently summed reflection
        coefficient, built from the first slice's coordinates and sampling.

    Raises:
        TypeError: If ``multislice`` is not a
            :py:class:`~ptyrax.models.interaction.MultiSlice` instance.
    """
    if not isinstance(multislice, MultiSlice):
        raise TypeError(f"Expected MultiSlice, got {type(multislice)}")

    interaction_data = multislice.inner_interactions.reflection_coefficient
    # A parametrization is called to obtain the stacked array; a plain array
    # (e.g. slices built without a parametrization_type) is used as-is.
    if callable(interaction_data):
        interaction_data = interaction_data()

    # 4*pi/lambda * z_j * cos(theta): one phase per slice (the extra factor 2
    # presumably accounts for the down-and-back path in reflection).
    phases = 2 * np.pi / wavelength * multislice.slice_distances * 2 * jnp.cos(jnp.deg2rad(tilt_angle))
    # Align the per-slice phase with the leading slice axis and broadcast it over
    # every remaining axis, whatever the rank of the coefficients.
    phases = phases.reshape((-1,) + (1,) * (interaction_data.ndim - 1))
    new_interaction_data = jnp.sum(interaction_data * phase_only_exp(phases), axis=0)

    first_coordinates = unstack_tree(multislice.inner_interactions.coordinates)[0]
    first_sampling = unstack_tree(multislice.inner_interactions.sampling)[0]
    return interaction_type(
        coordinates=first_coordinates,
        sampling=first_sampling,
        initializer=lambda _sampling: new_interaction_data,
    )