# Source code for ptyrax.field

from __future__ import annotations

import warnings
from typing import Callable, Self

import equinox as eqx
import gin
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, Bool, Complex, Float, PRNGKeyArray
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.gridspec import SubplotSpec
from matplotlib.image import AxesImage
from tensorboardX import SummaryWriter

from ptyrax.initializers import gaussian
from ptyrax.logger import log_image
from ptyrax.parametrizations import ArrayParametrization
from ptyrax.spatial import CoordinateSystem, SamplingGrid, interpolate_grid_to_grid
from ptyrax.utils import fft, ifft, plot

# Cartesian unit vectors, taken as the rows of the 3x3 identity matrix
# (x = [1, 0, 0], y = [0, 1, 0], z = [0, 0, 1]).
X_AXIS, Y_AXIS, Z_AXIS = jnp.eye(3)


@gin.configurable
class CoherentField(eqx.Module):
    """Representation of a coherent field (probe or object) with sampling and wavelength metadata.

    Instances are rebuilt positionally in ``_apply_fftshift`` and ``__getitem__``, so the
    field order below is part of the constructor contract and must not change.
    """

    # Field values with layout ``(*, m, n, d)``, or a parametrization that evaluates to them
    # (resolved by ``__call__``).
    data: Float[Array, "* m n d"] | ArrayParametrization
    # Wavelength(s); the converter turns any input into a JAX array.
    wavelength: Float[Array, ""] = eqx.field(converter=lambda x: jnp.array(x))
    # Spatial sampling metadata (shape, pixel size, origin shift, FFT-shift state).
    sampling: SamplingGrid
    coordinate_system: CoordinateSystem = eqx.field(default_factory=lambda: CoordinateSystem())
    # Propagation direction, +z by default; ``k_z`` rotates it into the internal frame.
    propagation_direction: Array = eqx.field(default_factory=lambda: Z_AXIS.copy())
    # Axes of ``data`` that hold the two spatial dimensions (static metadata).
    spatial_dims: tuple = eqx.field(static=True, default=(-2, -3))
    # Axis of ``data`` that holds the vector components (static metadata).
    vector_dim: int = eqx.field(static=True, default=-1)

    def __call__(self) -> Float[Array, "* m n d"]:
        """Return the underlying field data, resolving any parametrization.

        If the data is wrapped in an
        :py:class:`~ptyrax.parametrizations.ArrayParametrization`, this evaluates it to
        produce the raw array. Otherwise returns the data directly.

        Returns:
            The complex-valued field array.
        """
        if isinstance(self.data, ArrayParametrization):
            return self.data()
        return self.data

    @property
    def _data(self) -> Float[Array, "* m n d"]:
        """Deprecated alias of :py:attr:`data`."""
        # stacklevel=2 attributes the warning to the caller that accessed ``_data``.
        warnings.warn(
            "Using _data is deprecated. Use the .data property instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.data

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the underlying data array."""
        return self.data.shape

    def _apply_fftshift(self, target_shifted: bool) -> CoherentField:
        """Apply or undo FFT-shift along spatial dimensions.

        Args:
            target_shifted: Whether the result should be FFT-shifted.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with (un)shifted data and
            updated sampling metadata, or ``self`` if it is already in the target state.
        """
        if self.sampling.fftshifted == target_shifted:
            return self
        shift_fn = jnp.fft.fftshift if target_shifted else jnp.fft.ifftshift
        # Shift the resolved field values, not the (possibly parametrized) container.
        shifted = shift_fn(self(), axes=self.spatial_dims)
        if isinstance(self.data, ArrayParametrization):
            # Keep a parametrized field parametrized by rebuilding it from the new values.
            new_data = type(self.data)(shifted)
        else:
            new_data = jnp.asarray(shifted)
        new_sampling = SamplingGrid(
            self.sampling.shape,
            self.sampling.pixel_size,
            self.sampling.origin_shift,
            fftshifted=target_shifted,
        )
        return CoherentField(
            new_data,
            self.wavelength,
            new_sampling,
            self.coordinate_system,
            self.propagation_direction,
            self.spatial_dims,
            self.vector_dim,
        )

    @property
    def to_fft_shifted(self) -> CoherentField:
        """Return an FFT-shifted version of this field.

        Applies ``jnp.fft.fftshift`` along the spatial dimensions so that the
        zero-frequency component is centered. If the field is already shifted, returns
        ``self`` unchanged.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with shifted data and updated
            sampling metadata.
        """
        return self._apply_fftshift(target_shifted=True)

    @property
    def to_fft_unshifted(self) -> CoherentField:
        """Return a non-FFT-shifted version of this field.

        Applies ``jnp.fft.ifftshift`` along the spatial dimensions to restore the standard
        array ordering. If the field is already unshifted, returns ``self`` unchanged.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with unshifted data and updated
            sampling metadata.
        """
        return self._apply_fftshift(target_shifted=False)

    @property
    def amplitude(self) -> Float[Array, "* m n d"]:
        """Absolute value (amplitude) of the field, FFT-shifted if applicable."""
        amplitude = jnp.abs(self())
        if self.sampling.fftshifted:
            amplitude = jnp.fft.fftshift(amplitude, axes=self.spatial_dims)
        return amplitude

    @property
    def intensity(self) -> Float[Array, "* m n d"]:
        r"""Squared magnitude (intensity) of the field: :math:`|E|^2`.

        FFT-shifted along the spatial dimensions if the sampling is flagged as shifted.
        The result is real-valued.
        """
        values = self()
        intensity = values * values.conj()
        if self.sampling.fftshifted:
            intensity = jnp.fft.fftshift(intensity, axes=self.spatial_dims)
        return intensity.real

    def __getitem__(self, item: slice) -> Self:
        """Slice the field along its leading (mode/batch) dimensions.

        All sub-components (data, wavelength, sampling, coordinate_system,
        propagation_direction) are sliced consistently.

        Args:
            item: Index or slice to apply along the leading dimensions.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` containing the selected subset.
        """
        return CoherentField(
            self.data[item],
            self.wavelength[item],
            self.sampling[item],
            self.coordinate_system[item],
            self.propagation_direction[item],
            self.spatial_dims,
            self.vector_dim,
        )

    def __plot__(
        self,
        show: bool = True,
        gs: SubplotSpec | None = None,
        fig: Figure | None = None,
        **kwargs,
    ) -> tuple[Figure, Axes, SubplotSpec]:
        """Visualize the field as amplitude and phase images.

        Uses the first vector component (index 0 along the last axis) and shows
        amplitude/phase in the same plot using a bivariate colormap via
        :py:func:`~ptyrax.utils.plot`.

        Args:
            show: Whether to call ``plt.show()`` after plotting.
            gs: Optional matplotlib ``SubplotSpec`` to draw into. When omitted, a new
                figure with a single subplot is created.
            fig: Optional matplotlib ``Figure`` to use.
            **kwargs: Additional keyword arguments forwarded to
                :py:func:`~ptyrax.utils.plot`.

        Returns:
            A tuple of (Figure, Axes, SubplotSpec).
        """
        if gs is None:
            fig = plt.figure()
            gs = fig.add_gridspec(1, 1)[0]
        data = self()
        if self.sampling.fftshifted:
            # Shift only the spatial axes; a fixed (0, 1) would hit the mode axis
            # for (modes, m, n, d) data.
            data = np.fft.fftshift(data, axes=self.spatial_dims)
        extent = self.sampling.extent
        return plot(data[..., 0], show=show, extent=extent, gs=gs, fig=fig, **kwargs)

    def __log_epoch__(self, writer: SummaryWriter, epoch: int, prefix: str = "", **kwargs) -> None:
        """Log field state to TensorBoard at a given epoch.

        Logs per-mode wavelength scalars, a probe image, and the field norm.

        Args:
            writer: TensorBoardX summary writer instance.
            epoch: Current training epoch number.
            prefix: Optional string prefix for logged tags.
            **kwargs: Additional keyword arguments (unused).
        """
        # atleast_1d lets a single scalar wavelength be logged as mode 0 instead of
        # failing to iterate over a 0-d array.
        for i, wavelength in enumerate(jnp.atleast_1d(self.wavelength)):
            writer.add_scalar(f"2_probe/{prefix}wavelength/{i}", wavelength, epoch)
        log_image(writer, f"0/2_probe/{prefix}image", self, epoch, title="probe")
        writer.add_scalar(f"2_probe/{prefix}norm", jnp.linalg.norm(self()), epoch)

    def propagate_fraunhofer(self, distance: float, inverse: bool = False, fftshift: bool = True) -> CoherentField:
        r"""Propagate the field to the far field using the Fraunhofer approximation.

        Computes the Fourier transform (or inverse) of the first vector component and
        updates the sampling grid to far-field coordinates.

        Args:
            distance: Propagation distance used to compute the far-field pixel size.
            inverse: If ``True``, use the inverse FFT (propagate back to near field).
            fftshift: Whether to apply ``fftshift`` to the result. The new sampling is
                flagged ``fftshifted=not fftshift``.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with propagated data and
            far-field sampling.
        """
        first_component = self()[..., 0]
        transform = ifft if inverse else fft
        propagated_probe_data = transform(first_component, fftshift=fftshift)[..., None]
        propagated_probe_sampling = self.sampling.to_far_field(self.wavelength, distance, fftshifted=not fftshift)
        propagated_probe = eqx.tree_at(lambda p: p.data, self, propagated_probe_data)
        return eqx.tree_at(lambda p: p.sampling, propagated_probe, propagated_probe_sampling)

    def propagate_tilted_nearfield(
        self, displacement: Float[Array, "3"], n_medium: complex = 1.0 + 0j
    ) -> CoherentField:
        r"""Propagate the field in the near-field regime along a tilted direction.

        Performs angular-spectrum propagation by applying a phase factor in Fourier space
        built from :py:meth:`k_z`. Components flagged invalid by :py:meth:`k_z`
        (evanescent components) are passed as ``valid`` to the propagation coefficient.

        Args:
            displacement: 3-D displacement vector ``[dx, dy, dz]`` in real-space units.
                Only the z-component contributes to the propagation phase; the transverse
                components are currently ignored.
            n_medium: Refractive index of the propagation medium, forwarded to
                :py:meth:`k_z`.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with propagated data.
        """
        # Local import, presumably to avoid an import cycle with ptyrax.models.
        from ptyrax.models.propagation import nearfield_propagation_coefficient_fourier

        pupil_xi_z, valid = self.k_z(n_medium)
        propagation_factor = nearfield_propagation_coefficient_fourier(
            k_z=pupil_xi_z,
            z_distance=displacement[-1],
            valid=valid,
        )
        return self.multiply_fourier(propagation_factor)

    def k_z(self, n_medium: complex = 1.0 + 0j) -> tuple[Float[Array, "m n"], Bool[Array, "m n"]]:
        r"""Compute the z-component of the wave vector in the field's internal coordinate system.

        The far-field (pupil) frequency grid of this field's sampling is offset by the
        transverse part of the propagation direction, rotated into the field's
        coordinate system and divided by the wavelength. Then
        :math:`\sqrt{n^2/\lambda^2 - \xi_x^2 - \xi_y^2}` is evaluated.

        Args:
            n_medium: Refractive index of the medium.

        Returns:
            A tuple ``(k_z, valid)``. ``valid`` is ``False`` wherever the square root is
            NaN or has a nonzero imaginary part (evanescent components). NaN entries of
            ``k_z`` are replaced by zero.
        """
        pupil_sampling = self.sampling.to_far_field(
            wavelength=self.wavelength,
            propagation_distance=1.0,
            fftshifted=False,
        )
        pupil_meshgrid = pupil_sampling.meshgrid
        pupil_rotation_matrix = self.coordinate_system.rotation.as_matrix()
        internal_propagation_direction = pupil_rotation_matrix @ self.propagation_direction
        pupil_meshgrid = pupil_meshgrid + internal_propagation_direction[:2, None, None] / self.wavelength
        pupil_xi_z = jnp.sqrt(n_medium**2 / (self.wavelength) ** 2 - jnp.sum(pupil_meshgrid**2, axis=0))
        # A real n_medium yields NaN for evanescent components; a complex one yields a
        # nonzero imaginary part. Both are flagged as invalid.
        valid = ~jnp.isnan(pupil_xi_z) & (pupil_xi_z.imag == 0)
        pupil_xi_z = jnp.nan_to_num(pupil_xi_z, nan=0.0)
        return pupil_xi_z, valid

    def multiply_fourier(self, factor: Complex[Array, "m n"]) -> CoherentField:
        """Multiply the field by a Fourier-space factor.

        The factor is applied to the first vector component in Fourier space, and the
        result is transformed back to real space.

        Args:
            factor: Complex array of shape ``(m, n)`` representing the Fourier space
                factor to apply.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with the modified data.
        """
        pupil = self.propagate_fraunhofer(1.0, fftshift=True)
        modified_pupil_data = pupil()[..., 0] * factor
        modified_probe_data = ifft(modified_pupil_data, fftshift=True)[..., None]
        return eqx.tree_at(lambda p: p.data, self, modified_probe_data)

    def multiply_real(self, factor: Float[Array, "m n"]) -> CoherentField:
        """Multiply the field by a real-valued factor in real space.

        Args:
            factor: Array of shape ``(m, n)``; a trailing axis is added so it broadcasts
                over the vector dimension.

        Returns:
            A new :py:class:`~ptyrax.field.CoherentField` with the scaled data.
        """
        modified_probe_data = self() * factor[..., None]
        return eqx.tree_at(lambda p: p.data, self, modified_probe_data)

    def __add__(self, other: CoherentField) -> CoherentField:
        """Return a copy of ``self`` whose data is the sum of both fields' resolved values.

        All metadata (wavelength, sampling, ...) is taken from ``self``.
        """
        return eqx.tree_at(lambda p: p.data, self, self() + other())
@gin.configurable
def plot_field(field: CoherentField, show: bool = True, **kwargs) -> tuple[Figure, SubplotSpec, list[AxesImage]]:
    """Plot a coherent field showing amplitude and phase.

    Flattens all leading (mode/batch) axes, takes the first entry and its first vector
    component, applies FFT-shift if the sampling is flagged as shifted, and delegates to
    :py:func:`~ptyrax.utils.plot`.

    Args:
        field: The coherent field to visualize. Its data is expected to have the
            ``(*, m, n, d)`` layout used by :py:class:`CoherentField`.
        show: Whether to call ``plt.show()`` after plotting.
        **kwargs: Additional keyword arguments forwarded to :py:func:`~ptyrax.utils.plot`.

    Returns:
        A tuple of (Figure, SubplotSpec, list of AxesImage).
    """
    values = field()
    # Collapse every leading axis into one, then keep entry 0 and vector component 0.
    data = values.reshape((-1, *values.shape[-3:]))[0, :, :, 0]
    if field.sampling.fftshifted:
        data = np.fft.fftshift(data)
    # CoherentField has no ``flattened``/``field_of_view``; use the same extent source as
    # CoherentField.__plot__.
    # NOTE(review): assumes ``sampling.extent`` is a valid imshow extent for multi-mode
    # samplings too — confirm.
    return plot(data, show=show, extent=field.sampling.extent, **kwargs)
@gin.configurable
def initialize_new_probe(
    old_probe: CoherentField,
    new_initializer_functions: tuple[Callable[..., Array], ...],
    *,
    key: PRNGKeyArray = jax.random.PRNGKey(0),
) -> CoherentField:
    """Create a new probe by reinitializing each mode with a given function.

    Each mode of ``old_probe`` (leading axis of its data) is reinitialized using the
    corresponding initializer function. If a single initializer is provided it is
    broadcast to all modes. Modes are built in a plain Python loop because a different
    callable may be used per mode, which cannot be selected inside ``vmap``.

    Args:
        old_probe: The existing probe whose metadata is preserved.
        new_initializer_functions: A sequence of callables, each accepting the per-mode
            spatial shape (the data shape without the mode and vector axes) and a ``key``
            keyword argument, and returning an array for one mode. A single-element
            sequence is broadcast to all modes.
        key: JAX PRNG key used for random initialization; split into one key per mode.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with reinitialized data of layout
        ``(modes, *spatial, vector)``.

    Raises:
        ValueError: If the number of initializer functions does not match the number of
            probe modes (and is not exactly one).
    """
    old_data = old_probe()
    n_modes = old_data.shape[0]
    initializers = tuple(new_initializer_functions)
    if len(initializers) == 1:
        initializers = initializers * n_modes
    if len(initializers) != n_modes:
        raise ValueError(
            f"Number of new initializers {len(initializers)} does not match "
            f"number of probe modes {n_modes}"
        )

    # Shape of a single mode without its trailing vector axis.
    mode_shape = old_data.shape[1:-1]
    mode_keys = jax.random.split(key, n_modes)
    new_modes = []
    for initializer, mode_key in zip(initializers, mode_keys):
        mode = jnp.asarray(initializer(mode_shape, key=mode_key))
        if mode.ndim == len(mode_shape):
            # The initializer returned only the spatial part: append the vector axis.
            mode = mode[..., None]
        new_modes.append(mode)
    new_data = jnp.stack(new_modes)

    if isinstance(old_probe.data, ArrayParametrization):
        # Keep a parametrized probe parametrized.
        new_data = type(old_probe.data)(new_data)
    return eqx.tree_at(lambda p: p.data, old_probe, new_data)
@gin.configurable
def transpose_probe(old_probe: CoherentField) -> CoherentField:
    """Swap the spatial axes of a probe field.

    Transposes the two spatial dimensions and interpolates between the original
    :py:class:`~ptyrax.spatial.SamplingGrid` and one with swapped shape and pixel size.
    Each mode (leading axis) is processed independently, and only the first vector
    component is used.

    Args:
        old_probe: The probe field to transpose. Its data is expected to have the
            ``(modes, m, n, d)`` layout.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with transposed spatial dimensions
        and a single vector component.
    """
    old_data = old_probe()
    # Swap the spatial axes (-3, -2), matching CoherentField.spatial_dims; the previous
    # (-1, -2) swap exchanged the vector axis with a spatial axis instead.
    swapped = jnp.swapaxes(old_data, -3, -2)
    sampled_grid = old_probe.sampling
    transpose_grid = SamplingGrid.from_tuples(sampled_grid.shape[::-1], sampled_grid.pixel_size[0, ::-1])
    interpolated = eqx.filter_vmap(
        lambda mode: interpolate_grid_to_grid(
            mode[..., 0], transpose_grid, CoordinateSystem(), sampled_grid, CoordinateSystem()
        )
    )(swapped)
    # Restore the trailing vector axis; the leading axis is already the mode axis.
    new_data = interpolated[..., None]
    if isinstance(old_probe.data, ArrayParametrization):
        new_data = type(old_probe.data)(new_data)
    return eqx.tree_at(lambda p: p.data, old_probe, new_data)
@gin.configurable
def remove_field_phase(old_probe: CoherentField) -> CoherentField:
    """Remove the phase of a field, keeping only the amplitude.

    The data shape is preserved. A parametrized field is rebuilt with its
    parametrization type; otherwise the plain amplitude array is stored.

    Args:
        old_probe: The coherent field whose phase is to be removed.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with real-valued (amplitude-only)
        data.
    """
    # jnp (not np) keeps this traceable under jit/grad.
    amplitude = jnp.abs(old_probe())
    if isinstance(old_probe.data, ArrayParametrization):
        new_data = type(old_probe.data)(amplitude)
    else:
        new_data = amplitude
    return eqx.tree_at(lambda p: p.data, old_probe, new_data)
@gin.configurable
def replace_defocus(
    old_field: CoherentField,
    defocus_amount: float | tuple[float, float],
    axis: str = "xy",
) -> CoherentField:
    """Replace the field phase with a defocus along the specified axis or axes.

    Args:
        old_field: The coherent field to modify.
        defocus_amount: Defocus strength in meters. A single float for single-axis
            defocus, or a pair ``(defocus_x, defocus_y)`` (tuple or list, e.g. as parsed
            by gin) for ``axis="xy"``. A single float with ``axis="xy"`` is applied to
            both axes.
        axis: Which axis to defocus: ``"x"``, ``"y"``, or ``"xy"``.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with the specified defocus phase.

    Raises:
        ValueError: If ``axis`` is not one of ``"x"``, ``"y"``, or ``"xy"``, or if a
            sequence ``defocus_amount`` for ``axis="xy"`` does not have two entries.
    """
    if axis == "x":
        defocus_tuple = (defocus_amount, 0.0)
    elif axis == "y":
        defocus_tuple = (0.0, defocus_amount)
    elif axis == "xy":
        if isinstance(defocus_amount, (tuple, list)):
            if len(defocus_amount) != 2:
                raise ValueError(
                    f"defocus_amount for axis='xy' must have two entries, got {len(defocus_amount)}."
                )
            defocus_tuple = tuple(defocus_amount)
        else:
            defocus_tuple = (defocus_amount, defocus_amount)
    else:
        raise ValueError(f"Invalid axis: {axis!r}. Must be 'x', 'y', or 'xy'.")
    return replace_field_phase(
        old_field,
        defocus_phase_generator(old_field, defocus_amount=defocus_tuple),
    )
@gin.configurable
def replace_defocus_x(old_field: CoherentField, defocus_amount: float) -> CoherentField:
    """Replace the field phase with a defocus along the x-axis only.

    Convenience wrapper: equivalent to ``replace_defocus(old_field, defocus_amount, axis="x")``.
    """
    return replace_defocus(old_field, defocus_amount=defocus_amount, axis="x")
@gin.configurable
def replace_defocus_y(old_field: CoherentField, defocus_amount: float) -> CoherentField:
    """Replace the field phase with a defocus along the y-axis only.

    Convenience wrapper: equivalent to ``replace_defocus(old_field, defocus_amount, axis="y")``.
    """
    return replace_defocus(old_field, defocus_amount=defocus_amount, axis="y")
@gin.configurable
def replace_defocus_xy(old_field: CoherentField, defocus_amount: tuple[float, float]) -> CoherentField:
    """Replace the field phase with independent defocus along both axes.

    Convenience wrapper: equivalent to ``replace_defocus(old_field, defocus_amount, axis="xy")``.
    """
    return replace_defocus(old_field, defocus_amount=defocus_amount, axis="xy")
def defocus_phase_generator(old_field: CoherentField, defocus_amount: tuple[float, float]) -> Float[Array, "n m 1"]:
    r"""Generate a quadratic defocus phase profile on a field's sampling grid.

    Delegates to :py:func:`~ptyrax.initializers.gaussian`, passing a negative radius so
    that (per that initializer's convention) only the defocus (quadratic phase) is
    produced, with no amplitude modulation.

    Args:
        old_field: Field whose sampling grid defines the coordinate space.
        defocus_amount: Pair ``(defocus_x, defocus_y)`` in meters.

    Returns:
        A complex-valued array of shape ``(n, m, 1)`` holding the defocus phase factor.
    """
    grid = old_field.sampling
    no_envelope_radius = (-1, -1)
    return gaussian(grid, radius=no_envelope_radius, defocus=defocus_amount)
def replace_field_phase(
    old_probe: CoherentField,
    phase_or_generator: Complex[Array, "n m 1"]
    | Callable[[CoherentField], Complex[Array, "n m 1"]]
    | CoherentField,
) -> CoherentField:
    """Replace the phase of a field with a new phase profile.

    The amplitude of the original field is preserved while its phase is replaced by the
    phase of the provided array, callable, or field.

    Args:
        old_probe: The coherent field whose phase will be replaced.
        phase_or_generator: The new phase source as either a complex array, a callable
            from :py:class:`~ptyrax.field.CoherentField` to a complex array, or another
            :py:class:`~ptyrax.field.CoherentField`. A 2-D array gets a trailing vector
            axis; lower-rank phases broadcast over leading (mode) axes.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with the original amplitude and the
        new phase. A parametrized field is rebuilt with its parametrization type.
    """
    # The CoherentField check must precede callable(): CoherentField defines __call__.
    if isinstance(phase_or_generator, CoherentField):
        phase = phase_or_generator()
    elif callable(phase_or_generator):
        phase = phase_or_generator(old_probe)
    else:
        phase = phase_or_generator
    phase = jnp.asarray(phase)
    if phase.ndim == 2:
        phase = phase[..., None]

    # Broadcasting aligns an (m, n, d) phase with (modes, m, n, d) data. A full-rank
    # phase (e.g. from another CoherentField) is used as is rather than gaining an
    # extra leading axis.
    old_probe_data = old_probe()
    probe_new = jnp.abs(old_probe_data) * phase / jnp.abs(phase)

    if isinstance(old_probe.data, ArrayParametrization):
        # Instantiate the parametrization from the new values (previously the class
        # object itself was stored as data).
        new_data = type(old_probe.data)(probe_new)
    else:
        new_data = probe_new
    return eqx.tree_at(lambda p: p.data, old_probe, new_data)
def multiply_field_phase(
    old_probe: CoherentField,
    phase_or_generator: Complex[Array, "n m"] | Callable[[CoherentField], Complex[Array, "n m"]] | CoherentField,
) -> CoherentField:
    """Multiply the field by a phase-only factor.

    Unlike :py:func:`replace_field_phase`, this adds phase on top of the existing field
    phase (and amplitude) by multiplying the field with a unit-amplitude complex factor.

    Args:
        old_probe: The coherent field to modify.
        phase_or_generator: Phase factor to multiply, provided as a complex array, a
            callable that accepts a :py:class:`~ptyrax.field.CoherentField` and returns a
            complex array, or another :py:class:`~ptyrax.field.CoherentField`. A 2-D
            array applied to a field with a vector axis gets a trailing axis added so it
            aligns with the spatial axes.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with the additional phase applied.
        A parametrized field is rebuilt with its parametrization type.
    """
    # The CoherentField check must precede callable(): CoherentField defines __call__.
    if isinstance(phase_or_generator, CoherentField):
        phase = phase_or_generator()
    elif callable(phase_or_generator):
        phase = phase_or_generator(old_probe)
    else:
        phase = phase_or_generator
    phase = jnp.asarray(phase)

    old_probe_data = old_probe()
    if phase.ndim == 2 and old_probe_data.ndim >= 3:
        # (m, n) would otherwise broadcast against the trailing (n, d) axes of
        # (*, m, n, d) data.
        phase = phase[..., None]
    probe_new = old_probe_data * phase / jnp.abs(phase)

    if isinstance(old_probe.data, ArrayParametrization):
        new_data = type(old_probe.data)(probe_new)
    else:
        # Store the plain array; wrapping it in CoherentField(...) both lacked required
        # arguments and put a field inside ``.data``.
        new_data = probe_new
    return eqx.tree_at(lambda p: p.data, old_probe, new_data)
@gin.configurable
def set_probe_wavelength(
    old_probe: CoherentField, wavelength: float | list | tuple | np.ndarray | jnp.ndarray
) -> CoherentField:
    """Set the wavelength(s) of a probe field.

    A scalar (Python float/int, NumPy scalar, or 0-d array) is broadcast to all modes.
    A sequence or array of wavelengths must match the number of probe modes.

    Args:
        old_probe: The probe field to update.
        wavelength: Wavelength value(s) in meters. Either a scalar (applied to all modes)
            or a sequence matching the mode count.

    Returns:
        A new :py:class:`~ptyrax.field.CoherentField` with updated wavelength.

    Raises:
        ValueError: If the number of provided wavelengths does not match the number of
            probe modes.
    """
    n_modes = old_probe.shape[0]
    # jnp.ndim covers every scalar form. The old float-only check sent ints and 0-d
    # arrays to len(), which fails on unsized objects.
    if jnp.ndim(wavelength) == 0:
        wavelength = jnp.array([wavelength] * n_modes)
    else:
        wavelength = jnp.array(wavelength)
    if len(wavelength) != n_modes:
        raise ValueError(
            f"Number of new wavelengths {len(wavelength)} does not match number of probe modes {old_probe.shape[0]}"
        )
    return eqx.tree_at(lambda p: p.wavelength, old_probe, wavelength)