ptyrax.models.ptychography


Functions

add_wavelength_channels(model, initial_model)

Extend the illumination model with additional wavelength channels.

dropout(model, epoch, optimizer_state[, ...])

Apply random dropout to probe and interaction arrays.

initialize_3d_tilted_sampling(...[, n_dim, ...])

Compute interaction, probe, and forward sampling grids from the scattering geometry.

limit_illumination_na(model, max_na)

Apply a numerical aperture limit to the illumination probe.

limit_reflection_NA(model, max_na)

Apply a numerical aperture limit to the interaction reflection coefficient.

load_model_from_reconstruction(model, ...[, ...])

Load model weights/state from a previous reconstruction HDF5 into a new model.

make_multislice(interaction, ...[, ...])

Convert a single-slice interaction model into a multi-slice model.

multiply_interaction_xiz_function(model[, ...])

Multiply the interaction reflection coefficient by a depth-dependent transfer function.

normalize_illumination(model, **kwargs)

Normalize the probe illumination to unit total energy.

offset_displacement(model, offset)

Add a constant offset to the multi-slice inter-slice distances.

plot_model(model[, show, fig, ax])

Plot the probe and illumination fields of a PtychographyModel.

plot_scattering_vector_xy(scattering_vector, ...)

Plot Fourier scattering vectors in the sample-frame XY plane.

plot_spheres_xz(sample_rotation_matrix, ...)

Plot detector, illumination, and scattering vectors on the Ewald sphere in the XZ plane.

preprocess_model(model[, preprocess_functions])

Apply a sequence of preprocessing functions to the model.

reinitialize_interaction(interaction, ...[, ...])

Reinitialize the interaction model with a new data initializer.

reinitialize_interaction_inverted(...[, ...])

Reinitialize the interaction model with inverted amplitude.

remove_detector_darkcounts(model, **kwargs)

Zero out the detector dark-count (background) correction.

replace_illumination(model[, ...])

Replace the model's illumination with a transformed version.

replace_illumination_from_hdf5(model, hdf5_path)

Replace the model's illumination with parameters loaded from an HDF5 file.

replace_interaction(model[, ...])

Replace the model's interaction with a transformed version.

replace_interaction_from_hdf5(model, hdf5_path)

Replace the model's interaction parameters with values loaded from an HDF5 file.

replace_probe(model, new_probe_generator)

Replace the model's probe field with a transformed version.

reset_illumination(model, initial_model)

Reset the illumination model to its initial state.

reset_interaction(model, initial_model)

Reset the interaction model to its initial state.

scale_illumination_equal_pixel_size(model[, ...])

Rescale the probe illumination field by interpolating to a new pixel size.

scale_model_wavelength(model[, scale, ...])

Scale the model wavelength by a multiplicative factor.

scale_scan_range(model[, scale])

Scale the sample scan positions by a multiplicative factor.

set_illumination_total_energy(model[, ...])

Scale the probe so that its total integrated intensity equals a target value.

set_interaction_real_only(model, **kwargs)

Discard the imaginary part of the interaction reflection coefficient.

set_mean_phase(interaction[, target_mean_phase])

Shift the global phase of the interaction reflection coefficient.

set_model_constant_tilt_angle(model[, ...])

Override sample and detector orientations with those corresponding to a constant tilt angle.

set_model_wavelength(model[, wavelength, ...])

Set the model wavelength to an absolute value.

set_outside_scan_range_to(model, epoch, ...)

Suppress reconstruction artifacts outside the scanned area.

shift_probe_and_interaction(model, epoch, ...)

Center the probe by shifting both probe and interaction fields.

Classes

ImagePredictionModel()

Abstract base for models that predict images (or, equivalently, simulate detectors in an experiment).

PtychographyModel(illumination, interaction, ...)

Top-level ptychography model composed of illumination, interaction, propagator, and detector.

class ptyrax.models.ptychography.ImagePredictionModel

Bases: Module

Abstract base for models that predict images (or, equivalently, simulate detectors in an experiment). Since this class is an eqx.Module, any image prediction model is also a pytree, so instances can be passed directly to jitted functions.

To ensure functional purity, instances of this class are frozen after initialization. This means that no fields of the instance can be modified after initialization. Instead, to modify any fields, a new instance must be created with the desired changes. This can be done using the eqx.tree_at function. For more information, see the Equinox documentation.

Since this is a dataclass, fields may be defined using annotations. If a field holds JAX arrays that should not be optimized, do not mark it as static with eqx.field; doing so leads to errors when jitting functions that use the model. Instead, leave it as a normal field and make sure the optimizer specification does not include it. Static fields are allowed for non-array values such as static shapes and configuration constants.
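Equinox modules are frozen dataclasses, so the "create a new instance instead of mutating" pattern can be illustrated with the standard library alone. The sketch below uses `dataclasses.replace` as a stand-in for `eqx.tree_at`; `ToyModel` and `with_scaled_probe` are hypothetical names for illustration, not part of ptyrax.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyModel:
    # Array-valued field: kept as a normal (non-static) field, as recommended above.
    probe: np.ndarray
    # Non-array configuration constant: the kind of value that may be static.
    name: str = "toy"


def with_scaled_probe(model: ToyModel, factor: float) -> ToyModel:
    # Build a new instance rather than mutating the frozen one
    # (the role eqx.tree_at plays for Equinox modules).
    return dataclasses.replace(model, probe=model.probe * factor)


model = ToyModel(probe=np.ones((2, 2)))
scaled = with_scaled_probe(model, 2.0)

try:
    model.probe = np.zeros((2, 2))  # in-place assignment is rejected
except dataclasses.FrozenInstanceError:
    print("frozen: create a new instance instead")
```

With Equinox itself, the equivalent update would select the field with a `where` callable passed to `eqx.tree_at`; the original object stays untouched in both cases.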

abstractmethod __call__()

Predict a single image from the model. This method is called inside the jit boundary, so all operations it contains must be jittable. The output must be a single image of shape (m, n). Before this method is called, every parametrization in the model is evaluated to its underlying array, so inside this method all parametrizations can be treated as ordinary arrays. For example, any model parameter that is an IndexSliceParametrization (with leading dimension d for the dataset index) has its leading dimension removed inside ImagePredictionModel.__call__(). The model therefore need not handle batch dimensions at all: every prediction is for a single image.
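The index-slicing behaviour described above can be pictured with a flat dictionary of parameters. This is a minimal sketch under stated assumptions: `IndexSlice` and `resolve_at` are hypothetical stand-ins, and ptyrax performs the equivalent step on full pytrees inside the jit boundary rather than on a dict.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class IndexSlice:
    """Hypothetical stand-in for an index-sliced parameter: one row per dataset index."""
    values: np.ndarray  # shape (d, ...)


def resolve_at(params: dict, index: int) -> dict:
    """Replace every IndexSlice by its row for `index`; keep other entries unchanged."""
    return {
        name: p.values[index] if isinstance(p, IndexSlice) else p
        for name, p in params.items()
    }


params = {
    "sample_shift": IndexSlice(np.arange(6.0).reshape(3, 2)),  # (d=3, 2)
    "probe": np.ones((4, 4)),                                # shared by all indices
}
single = resolve_at(params, index=1)
# single["sample_shift"] has shape (2,): the leading dataset dimension is gone.
```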

Returns:

The predicted image of shape (m, n).

Return type:

Float[Array, "m n"]

abstractmethod classmethod from_image_dataset(dataset, *args, **kwargs)

Instantiates the ImagePredictionModel based on its corresponding dataset. This will be called outside the jit-boundary, so all fields of the ImageDataset are likely numpy arrays.

Parameters:
  • dataset (ImageDataset) – An instance of the ImageDataset which the model will attempt to predict. Implementing an ImagePredictionModel usually also means implementing a corresponding ImageDataset.

Returns:

The model that will be used in the optimization loop to predict images from the dataset.

Return type:

ImagePredictionModel

abstract property image_shape: tuple[jaxtyping.Integer, jaxtyping.Integer]

Returns:

The shape of a single image which the model predicts.

Return type:

tuple[int, int]

load(file_path)

Deserialize model leaves from disk into the current model structure.

The current model instance acts as the structural template (pytree skeleton) and its leaf values are replaced with the serialized values from file_path.

Parameters:

file_path (Path) – Path to a previously saved .eqx file.

Returns:

A new model instance with the same structure as self but with leaf values loaded from the file.

Return type:

PtychographyModel

property n_indices: jaxtyping.Integer

The number of dataset indices (scan positions) this model spans.

Determined by inspecting all IndexSliceParameter leaves in the model tree and returning the maximum leading dimension found. If no IndexSliceParameter fields exist, returns 1 with a warning.
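The counting rule above (maximum leading dimension over all index-sliced leaves, falling back to 1 with a warning) can be sketched over a flat list of leaves. `IndexSliceParameter` here is a hypothetical stand-in class and `n_indices` a free function, not the library's implementation, which walks the full model pytree.

```python
import warnings
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class IndexSliceParameter:
    """Hypothetical stand-in: wraps an array whose leading axis is the dataset index."""
    values: np.ndarray


def n_indices(leaves) -> int:
    """Largest leading dimension over all IndexSliceParameter leaves, or 1 if none."""
    sizes = [leaf.values.shape[0] for leaf in leaves if isinstance(leaf, IndexSliceParameter)]
    if not sizes:
        warnings.warn("No IndexSliceParameter leaves found; assuming a single index.")
        return 1
    return max(sizes)
```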

Returns:

The number of dataset indices the model is configured for.

print_parameter_paths(prefix='')

Prints all parameters of the model to the console.

This is useful for debugging and for getting an overview of the model’s parameters. It uses the ptyrax.utils.print_parameters function, which recursively prints all parameters in a readable format.

Parameters:

prefix (str)

Return type:

None

resolve(*args)

Resolves all parametrizations in the model to their underlying arrays.

Returns:

A new instance of the model with all parametrizations resolved to arrays.

Return type:

ImagePredictionModel

save(file_path)

Serialize the model’s leaves to disk using Equinox serialization.

Saves all array leaves of the model (parameters, buffers) to a binary file that can later be loaded with load().

Parameters:

file_path (Path) – Destination file path (typically with .eqx extension).

Return type:

None

abstractmethod to_image_dataset(predicted_images)

Converts predicted images back to an ImageDataset. This is mainly useful for evaluation and logging purposes.

Parameters:
  • predicted_images (Float[Array, "* m n"]) – The predicted images from the model. These will likely come from the output of a simulation.

Returns:

An instance of the ImageDataset containing the predicted images.

Return type:

ImageDataset

class ptyrax.models.ptychography.PtychographyModel(illumination, interaction, propagator, detector)

Bases: ImagePredictionModel

Top-level ptychography model composed of illumination, interaction, propagator, and detector.

The classmethod from_image_dataset initializes model components from a Ptychogram dataset.

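Parameters:
  • illumination (IlluminationModel)

  • interaction (InteractionModel)

  • propagator (Propagator)

  • detector (Detector)

__call__(**kwargs)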

Predict images from the model.

This method is called inside the jit boundary, so all operations it contains must be jittable. The output must be a single image of shape (m, n). Before this method is called, every parametrization in the model is evaluated to its underlying array, so inside this method all parametrizations can be treated as ordinary arrays. For example, any model parameter that is an IndexSliceParametrization (with leading dimension d for the dataset index) has its leading dimension removed inside __call__(). The model therefore need not handle batch dimensions at all: every prediction is for a single image.

Return type:

tuple[Float[Array, '* m n'], Bool[Array, '* d']]

__plot__(*args, **kwargs)

Display the geometry of the ptychography setup, including the detector and illumination directions on a unit sphere and the Fourier scattering vectors in the sample frame.

This is useful for visualizing the experimental geometry and understanding how the sample is probed. The function computes the necessary geometric quantities and creates three subplots: the detector and illumination directions on a unit sphere in the XZ plane, the Fourier scattering vectors in the XY plane, and the overall geometry of the sample and detector.

Return type:

None

detector: Detector

exit_field(index=0)

Return the field immediately after the interaction model for a given dataset index.

Parameters:

index (jaxtyping.Integer) – Dataset (scan position) index.

Return type:

CoherentField

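classmethod from_hdf5(file_path, *, params_root='params', **kwargs)

Instantiate a ptychography model from the params subtree of an HDF5 file or group.

Parameters:
  • file_path (str | PathLike | File | Group) – Path to an HDF5 file, or an already opened HDF5 file or group.

  • params_root (str) – Name of the subtree holding the model parameters.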

Return type:

PtychographyModel

classmethod from_hdf5_state(state, *, illumination_class=<class 'ptyrax.models.illumination.DirectIllumination'>, interaction_class=<class 'ptyrax.models.interaction.FresnelReflection'>, detector_class=<class 'ptyrax.models.detector.BackgroundEqualWeightDetector'>, propagator_class=<class 'ptyrax.models.propagation.FarfieldPropagator'>)

Reconstruct a PtychographyModel from a flat HDF5 state dictionary.

This method rebuilds the model’s component hierarchy (illumination, interaction, detector, propagator) from a dictionary of named arrays typically produced by load_hdf5_state().

Parameters:
  • state (Dict[str, ndarray]) – Dictionary mapping HDF5 dataset paths to numpy arrays, as returned by load_hdf5_state().

  • illumination_class (type[IlluminationModel]) – Class to use for the illumination model. Must implement a from_coherent_field classmethod.

  • interaction_class (type[InteractionModel]) – Class to use for the sample interaction model.

  • detector_class (type[Detector]) – Class to use for the detector model. Must implement a from_hdf5_state classmethod.

  • propagator_class (type[Propagator]) – Class to use for the field propagator.

Returns:

A PtychographyModel with parameters populated from the HDF5 state.

Raises:
  • TypeError – If illumination_class does not implement from_coherent_field.

  • ValueError – If required keys are missing from state.

Return type:

PtychographyModel

classmethod from_image_dataset(ptychogram, illumination_class=<class 'ptyrax.models.illumination.DirectIllumination'>, probe_initializer=<function aperture>, interaction_class=<class 'ptyrax.models.interaction.FresnelReflection'>, interaction_initializer=<function uniform>, detector_class=<class 'ptyrax.models.detector.BackgroundEqualWeightDetector'>, propagator_class=<class 'ptyrax.models.propagation.FarfieldPropagator'>, *, tensorboard_writer=None, key=Array([ 0, 42], dtype=uint32), fixed_sampling=None)

Construct a PtychographyModel from a ptychography dataset.

This factory method computes the experimental geometry (sample/detector coordinate systems), determines the Fourier sampling grids from the scattering geometry, initializes the illumination probe and interaction model, and assembles all components into a complete forward model.

Parameters:
  • ptychogram (Ptychogram) – The Ptychogram dataset containing diffraction patterns, scan positions, and experimental metadata.

  • illumination_class (type[IlluminationModel]) – Class to use for the illumination model.

  • probe_initializer (Callable[[SamplingGrid], Complex[Array, 'n d']]) – Callable that generates initial probe field data on a given SamplingGrid.

  • interaction_class (type[InteractionModel]) – Class to use for the sample interaction model.

  • interaction_initializer (Callable[[SamplingGrid], Complex[Array, 'n d']]) – Callable that generates initial interaction (e.g. reflection coefficient) data.

  • detector_class (type[Detector]) – Class to use for the detector model.

  • propagator_class (type[Propagator]) – Class to use for the field propagator.

  • tensorboard_writer (SummaryWriter) – Optional TensorBoard writer for logging sampling geometry during initialization.

  • key (Key) – JAX PRNG key for stochastic initialization.

  • fixed_sampling (tuple[jaxtyping.Integer, jaxtyping.Integer, jaxtyping.Integer, jaxtyping.Integer]) – If provided, a tuple (interaction_shape, interaction_pixel_size, probe_shape, probe_pixel_size) that bypasses automatic sampling computation.

Returns:

A fully initialized PtychographyModel.

Return type:

PtychographyModel

illumination: IlluminationModel

property image_shape: tuple[jaxtyping.Integer, jaxtyping.Integer]

Returns:

The shape of a single image which the model predicts.

Return type:

tuple[int, int]

static initialize_detector_coordinates(ptychogram)

Create a CoordinateSystem for detector positions and orientations from a Ptychogram.

The detector translations are normalized for better optimization performance. Output is wrapped in IndexSliceParametrization to specify indexing over the dataset dimension.

Parameters:

ptychogram (Ptychogram) – Source dataset containing detector_orientations and detector_positions.

Returns:

A CoordinateSystem with normalized translations describing the detector geometry.

Return type:

CoordinateSystem

static initialize_sample_coordinates(ptychogram)

Create a CoordinateSystem for sample positions and orientations from a Ptychogram. The sample translations are normalized for better optimization performance. Output is wrapped in IndexSliceParametrization to specify indexing over the dataset dimension.

Parameters:

ptychogram (Ptychogram) – Source dataset containing sample_orientations and sample_positions.

Returns:

A CoordinateSystem with normalized translations suitable for initializing interactions.

Return type:

CoordinateSystem

interaction: InteractionModel

load(file_path)

Deserialize model leaves from disk into the current model structure.

The current model instance acts as the structural template (pytree skeleton) and its leaf values are replaced with the serialized values from file_path.

Parameters:

file_path (Path) – Path to a previously saved .eqx file.

Returns:

A new model instance with the same structure as self but with leaf values loaded from the file.

Return type:

PtychographyModel

property n_indices: jaxtyping.Integer

The number of dataset indices (scan positions) this model spans.

Determined by inspecting all IndexSliceParameter leaves in the model tree and returning the maximum leading dimension found. If no IndexSliceParameter fields exist, returns 1 with a warning.

Returns:

The number of dataset indices the model is configured for.

print_parameter_paths(prefix='')

Prints all parameters of the model to the console.

This is useful for debugging and for getting an overview of the model’s parameters. It uses the ptyrax.utils.print_parameters function, which recursively prints all parameters in a readable format.

Parameters:

prefix (str)

Return type:

None

propagator: Propagator

resolve(*args)

Resolves all parametrizations in the model to their underlying arrays.

Returns:

A new instance of the model with all parametrizations resolved to arrays.

Return type:

ImagePredictionModel

save(file_path)

Serialize the model’s leaves to disk using Equinox serialization.

Saves all array leaves of the model (parameters, buffers) to a binary file that can later be loaded with load().

Parameters:

file_path (Path) – Destination file path (typically with .eqx extension).

Return type:

None

to_image_dataset(predicted_diffraction_patterns)

Convert model state and predicted diffraction patterns into a Ptychogram.

Packs the predicted images together with the model’s current geometric parameters (sample/detector positions, orientations, wavelength, pixel size) into a dataset suitable for saving or comparison with measured data.

Parameters:

predicted_diffraction_patterns (Float[Array, '* m n']) – Predicted intensity patterns with shape matching the number of scan positions and detector pixels.

Returns:

A Ptychogram populated with predictions and the model’s current geometry.

Return type:

Ptychogram

ptyrax.models.ptychography.add_wavelength_channels(model, initial_model, additional_wavelengths=0)

Extend the illumination model with additional wavelength channels.

Pads the leading (wavelength) dimension of the probe data array with additional_wavelengths new entries initialized to the defaults from initial_model. This enables polychromatic reconstructions by growing the spectral dimension during training.
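The padding step can be sketched on bare arrays. This is a minimal NumPy illustration, not the ptyrax implementation: `pad_wavelength_channels` is a hypothetical helper, and it assumes the "defaults from initial_model" are copies of the initial probe's first wavelength channel.

```python
import numpy as np


def pad_wavelength_channels(probe: np.ndarray, default_probe: np.ndarray, additional: int) -> np.ndarray:
    """Append `additional` channels along axis 0, copied from default_probe[0]."""
    if additional <= 0:
        return probe
    # Assumption: every new channel starts from the initial model's first channel.
    new = np.repeat(default_probe[:1], additional, axis=0)
    return np.concatenate([probe, new], axis=0)


probe = np.ones((2, 4, 4), dtype=complex)         # 2 wavelength channels
default = np.full((1, 4, 4), 0.5 + 0j)            # initial model's probe
grown = pad_wavelength_channels(probe, default, 3)  # now 5 channels
```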

Parameters:
  • model (PtychographyModel) – The current ptychography model.

  • initial_model (PtychographyModel) – Model providing default values for the new channels.

  • additional_wavelengths (jaxtyping.Integer) – Number of wavelength channels to add.

Returns:

A new PtychographyModel with extended wavelength dimension.

Return type:

PtychographyModel

ptyrax.models.ptychography.dropout(model, epoch, optimizer_state, apply_every=1, max_epoch=1800, probe_re='.*probe.*data', interaction_re='.*interaction.*data', fraction=0.5, fraction_decay=0.983, *, key=None)

Apply random dropout to probe and interaction arrays.

Randomly zeroes a fraction of elements in the probe and interaction data arrays (identified by regex on their pytree paths). The dropout fraction decays exponentially with epoch as fraction * fraction_decay ** epoch. This regularizer can help prevent overfitting.
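The schedule and masking described above can be sketched in NumPy. This is an illustration under assumptions, not the library code: `dropout_fraction` and `apply_dropout` are hypothetical helpers, a NumPy `Generator` replaces the JAX PRNG key, dropout is assumed to apply when `epoch % apply_every == 0` and `epoch <= max_epoch`, and kept elements are not rescaled (the text only says elements are zeroed).

```python
import numpy as np


def dropout_fraction(epoch, fraction=0.5, fraction_decay=0.983, apply_every=1, max_epoch=1800):
    """Dropout probability at `epoch`; 0.0 when dropout is skipped (assumed schedule)."""
    if epoch > max_epoch or epoch % apply_every != 0:
        return 0.0
    return fraction * fraction_decay**epoch


def apply_dropout(data, p, rng):
    """Zero each element of `data` independently with probability `p` (no rescaling)."""
    keep = rng.random(data.shape) >= p
    return np.where(keep, data, 0)


rng = np.random.default_rng(0)
probe = np.ones((64, 64), dtype=complex)
p = dropout_fraction(epoch=10)  # 0.5 * 0.983**10, roughly 0.42
dropped = apply_dropout(probe, p, rng)
```

With the default decay, the fraction falls to about 0.5 × 0.983^1800 (effectively zero) by `max_epoch`, so the regularizer fades out well before it is switched off.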

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • epoch (int) – Current epoch (used with apply_every, max_epoch, and decay).

  • optimizer_state (optax.OptState) – Current optimizer state (returned unchanged).

  • apply_every (int) – Apply dropout only every apply_every epochs.

  • max_epoch (int) – Stop applying after this epoch.

  • probe_re (str) – Regex pattern matching probe data paths in the pytree.

  • interaction_re (str) – Regex pattern matching interaction data paths.

  • fraction (float) – Base dropout fraction (probability of zeroing each element).

  • fraction_decay (float) – Exponential decay rate for the dropout fraction per epoch.

  • key (Optional[Key]) – JAX PRNG key for generating the random dropout mask.

Returns:

A tuple (model_with_dropout, optimizer_state).

Raises:

ValueError – If key is None.

Return type:

tuple[PtychographyModel, PyTree[PtychographyModel]]

ptyrax.models.ptychography.initialize_3d_tilted_sampling(detector_sampling, shape, sample_coordinates, detector_coordinates, wavelengths, n_dim=2, fourier_oversampling_factor=array([1., 1.]), real_oversampling_factor=array([1., 1.]), probe_fourier_oversampling_factor=None, writer=None, epoch=None, prefix='')

Compute interaction, probe, and forward sampling grids from the scattering geometry.

Given the detector and sample coordinate systems, this function determines the Fourier bounds of the scattering vectors in the sample frame and derives real-space pixel sizes that satisfy the Nyquist condition for the tilted geometry. It returns three SamplingGrid instances defining the discretization for the interaction (sample), probe, and forward propagation fields. For most use cases, the probe and forward sampling grids will be the same, but separate oversampling factors can be provided for flexibility.

The real-space pixel size is computed as:

\[\Delta x = \frac{\lambda_{\min}}{\xi_{\max}}\]

where \(\xi_{\max}\) is the maximum angular frequency extent in the sample frame and \(\lambda_{\min}\) is the shortest wavelength.
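A worked version of the pixel-size formula, assuming the sample-frame frequencies are dimensionless direction-cosine differences and that the "extent" is max minus min along each axis. `real_space_pixel_size` is a hypothetical helper; the library also applies oversampling factors, which are omitted here.

```python
import numpy as np


def real_space_pixel_size(wavelengths, xi):
    """Nyquist pixel size per axis: dx = lambda_min / xi_extent.

    `xi` holds dimensionless sample-frame frequency components with shape
    (N, n_axes). The extent is taken as max - min per axis (assumption).
    """
    xi = np.asarray(xi, dtype=float)
    extent = xi.max(axis=0) - xi.min(axis=0)
    return np.min(wavelengths) / extent


wavelengths = np.array([13.5e-9, 14.0e-9])  # metres
xi = np.array([[-0.10, -0.05], [0.10, 0.05], [0.0, 0.0]])
dx = real_space_pixel_size(wavelengths, xi)  # [6.75e-08, 1.35e-07] metres
```

Using the shortest wavelength gives the finest (most demanding) sampling, so the grid satisfies Nyquist for every channel.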

Parameters:
  • detector_sampling (SamplingGrid) – Pixel grid of the detector (shape and pixel size).

  • shape (tuple[jaxtyping.Integer, jaxtyping.Integer]) – Nominal detector shape (nx, ny) used as the base for grid dimensions.

  • sample_coordinates (CoordinateSystem) – Sample positions and orientations (may be wrapped in IndexSliceParameter).

  • detector_coordinates (CoordinateSystem) – Detector positions and orientations.

  • wavelengths (Array) – Array of illumination wavelengths.

  • n_dim (Literal[2, 3]) – Dimensionality of the scattering geometry: 2 (planar) or 3 (full 3-D Fourier bounds including \(\xi_z\)).

  • fourier_oversampling_factor (ndarray) – Multiplicative factor(s) applied to the Fourier bounds for the forward grid.

  • real_oversampling_factor (ndarray) – Multiplicative factor(s) for the number of grid points in the forward grid.

  • probe_fourier_oversampling_factor (ndarray) – If provided, separate oversampling factor(s) for the probe grid (otherwise matches forward grid).

  • writer (SummaryWriter | None) – Optional TensorBoard writer for logging geometry plots.

  • epoch (jaxtyping.Integer | None) – Epoch number for TensorBoard logging.

  • prefix (str) – Prefix string for TensorBoard tags.

Returns:

A tuple (interaction_sampling, probe_sampling, forward_sampling) of SamplingGrid instances.

Return type:

tuple[SamplingGrid, SamplingGrid, SamplingGrid]

ptyrax.models.ptychography.limit_illumination_na(model, max_na)

Apply a numerical aperture limit to the illumination probe.

Propagates the probe to Fourier (angular) space, applies an elliptical aperture mask defined by max_na, and propagates back. Frequencies beyond the NA limit are zeroed, effectively band-limiting the probe.
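The band-limiting step can be sketched with a plain FFT. This NumPy version is an assumption-laden illustration, not the ptyrax propagator: `limit_na` is a hypothetical helper, pixels are assumed square, and the aperture is defined in direction-cosine units (wavelength times spatial frequency).

```python
import numpy as np


def limit_na(probe, pixel_size, wavelength, max_na):
    """Band-limit a 2-D field with an elliptical aperture in angular (Fourier) space."""
    if np.isscalar(max_na):
        na_x = na_y = float(max_na)
    elif len(max_na) == 2:
        na_x, na_y = map(float, max_na)
    else:
        raise ValueError("max_na must be a scalar or a 2-element tuple")

    ny, nx = probe.shape
    # Direction cosines sampled by the FFT grid: wavelength * spatial frequency.
    ux = wavelength * np.fft.fftfreq(nx, d=pixel_size)
    uy = wavelength * np.fft.fftfreq(ny, d=pixel_size)
    UX, UY = np.meshgrid(ux, uy)  # both of shape (ny, nx)
    aperture = (UX / na_x) ** 2 + (UY / na_y) ** 2 <= 1.0

    # Zero all angular components outside the ellipse, then return to real space.
    return np.fft.ifft2(np.fft.fft2(probe) * aperture)
```

A flat field (pure DC component) passes unchanged, while a checkerboard at the Nyquist frequency is removed entirely for any NA below 0.5 when the wavelength equals the pixel size.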

Parameters:
  • model (PtychographyModel) – The ptychography model to modify. Must use DirectIllumination.

  • max_na (float | tuple[float, float]) – Maximum numerical aperture as a scalar (isotropic) or a tuple (na_x, na_y) for anisotropic limiting.

Returns:

A new PtychographyModel with the NA-limited probe.

Raises:

ValueError – If max_na is not a scalar or 2-element tuple, or if the illumination is not DirectIllumination.

Return type:

DirectIllumination

ptyrax.models.ptychography.limit_reflection_NA(model, max_na)

Apply a numerical aperture limit to the interaction reflection coefficient.

Multiplies the reflection coefficient by an elliptical aperture mask in real space (which corresponds to a Fourier-space NA limit for the reflected field). Frequencies beyond max_na are zeroed.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify. Must use FresnelReflection.

  • max_na (float | tuple[float, float]) – Maximum numerical aperture as a scalar (isotropic) or a tuple (na_x, na_y) for anisotropic limiting.

Returns:

A new PtychographyModel with the NA-limited interaction.

Raises:

ValueError – If max_na is not a scalar or 2-element tuple, or if the interaction is not FresnelReflection.

Return type:

InteractionModel

ptyrax.models.ptychography.load_model_from_reconstruction(model, reconstruction_path, policy_map={'.*': {'default': 'pad'}}, **kwargs)

Load model weights/state from a previous reconstruction HDF5 into a new model.

Parameters:
  • model (ImagePredictionModel) – Model whose structure (parameter shapes) serves as the template into which the saved parameters are loaded.

  • reconstruction_path (str) – Path to the HDF5 file containing saved model parameters.

  • policy_map (dict) – Mapping of regex patterns to the handling policies used when applying the HDF5 values to the model.

Returns:

An ImagePredictionModel instance with parameters loaded from reconstruction_path.

Return type:

ImagePredictionModel

ptyrax.models.ptychography.make_multislice(interaction, interaction_generators, slice_displacements, separable_in_z=False, inverted_bottom=True, symmetric=False)

Convert a single-slice interaction model into a multi-slice model.

Creates a MultiSlice by generating individual interaction slices from the given generators and stacking them at the specified z-displacements.

Parameters:
  • interaction (InteractionModel) – The base interaction model used as template for each slice.

  • interaction_generators (list[Callable[[InteractionModel], InteractionModel]]) – List of callables that each produce a new interaction model from the base. If a single generator is provided, it is reused for all slices.

  • slice_displacements (list[jaxtyping.Float]) – Z-positions (depths) of each slice. The length must match interaction_generators (or a single generator is broadcast).

  • separable_in_z (bool) – Whether to treat slices as separable in z during propagation.

  • inverted_bottom (bool) – Whether the bottom slice uses an inverted coordinate frame (reflection geometry).

  • symmetric (bool) – If True, centers the slice displacements around zero.

Returns:

A MultiSlice model.

Raises:

ValueError – If the number of generators does not match the number of slice displacements.

Return type:

MultiSlice
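The generator-broadcasting and length check described in the parameters above can be sketched independently of the interaction models. `plan_slices` is a hypothetical helper, and "centering around zero" for `symmetric=True` is assumed to mean subtracting the mean depth; the library may center differently.

```python
import numpy as np


def plan_slices(generators, slice_displacements, symmetric=False):
    """Pair each slice depth with a generator, broadcasting a single generator."""
    z = np.asarray(slice_displacements, dtype=float)
    if len(generators) == 1:
        generators = list(generators) * len(z)
    elif len(generators) != len(z):
        raise ValueError(f"{len(generators)} generators for {len(z)} slice displacements")
    if symmetric:
        z = z - z.mean()  # assumption: "centered" means zero mean depth
    return list(zip(generators, z))


def identity(interaction):
    # Hypothetical generator: returns the base interaction unchanged.
    return interaction


plan = plan_slices([identity], [0.0, 10.0, 20.0], symmetric=True)
# Depths become [-10.0, 0.0, 10.0], all paired with the same generator.
```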

ptyrax.models.ptychography.multiply_interaction_xiz_function(model, tilt_angle=0.0, thickness=0.0, xiz_function=<PjitFunction of <function cos>>)

Multiply the interaction reflection coefficient by a depth-dependent transfer function.

Applies a zero-order correction for sample depth by multiplying the reflection coefficient in Fourier space by a function of the out-of-plane spatial frequency \(\xi_z\). The transfer function models the effect of finite sample thickness on the scattered wave.

The \(\xi_z\) component is computed from the tilt geometry as:

\[\xi_z = k \left( \sqrt{1 - (\xi_x + s_x)^2 - (\xi_y + s_y)^2} - s_z \right)\]

where \(k = 2\pi / \lambda\) and \((s_x, s_y, s_z)\) is the specular direction.
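The formula above can be evaluated on a grid as follows. This is a sketch under stated assumptions: `xiz_transfer` is a hypothetical helper, \(\xi_x, \xi_y\) are treated as dimensionless direction-cosine offsets, the specular direction is assumed to be \((\sin\theta, 0, \cos\theta)\) for tilt angle \(\theta\), and evanescent components are simply clipped, which the library may handle differently.

```python
import numpy as np


def xiz_transfer(xi_x, xi_y, wavelength, tilt_angle_deg, thickness, fn=np.cos):
    """Evaluate fn(thickness / 2 * xi_z) on a grid of lateral direction-cosine offsets."""
    theta = np.deg2rad(tilt_angle_deg)
    # Assumption: specular direction lies in the XZ plane, tilted by theta from +z.
    s_x, s_y, s_z = np.sin(theta), 0.0, np.cos(theta)
    k = 2 * np.pi / wavelength
    arg = 1.0 - (xi_x + s_x) ** 2 - (xi_y + s_y) ** 2
    # Evanescent components (arg < 0) are clipped to 0 in this sketch.
    xi_z = k * (np.sqrt(np.clip(arg, 0.0, None)) - s_z)
    return fn(thickness / 2 * xi_z)


XX, YY = np.meshgrid(np.linspace(-0.1, 0.1, 5), np.linspace(-0.1, 0.1, 5))
T = xiz_transfer(XX, YY, wavelength=13.5e-9, tilt_angle_deg=10.0, thickness=50e-9)
```

At the specular point (\(\xi_x = \xi_y = 0\)) the square root equals \(s_z\), so \(\xi_z = 0\) and the cosine transfer function is exactly 1; the correction only acts away from specular.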

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • tilt_angle (jaxtyping.Float) – Sample tilt angle in degrees, used to compute the specular direction.

  • thickness (jaxtyping.Float) – Physical thickness of the sample layer. Controls the argument to xiz_function as thickness / 2 * xi_z.

  • xiz_function (Callable[[Inexact[Array, 'm n']], Inexact[Array, 'm n']]) – Function applied element-wise to the scaled \(\xi_z\) array. Common choices are jnp.cos (default, for two-layer models) or jnp.sinc (for rectangular models).

Returns:

A new PtychographyModel with the modified interaction.

Raises:

TypeError – If the interaction is not a FresnelReflection.

Return type:

PtychographyModel

ptyrax.models.ptychography.normalize_illumination(model, **kwargs)

Normalize the probe illumination to unit total energy.

Divides the probe data by its L2 norm so that \(\sum |\mathrm{probe}|^2 = 1\). This removes amplitude ambiguity between probe and interaction during optimization.
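The normalization amounts to one division. The NumPy sketch below (`normalize_probe` is a hypothetical helper acting on a bare array) shows it together with the reason it helps: in a multiplicative probe-times-object model, scaling the probe by c and the interaction by 1/c leaves the product unchanged, so fixing the probe energy removes that free factor.

```python
import numpy as np


def normalize_probe(probe):
    """Scale a complex probe so that sum(|probe|**2) == 1."""
    return probe / np.sqrt(np.sum(np.abs(probe) ** 2))


probe = np.full((4, 4), 3 + 4j)   # |probe|^2 = 25 per pixel, total energy 400
unit_probe = normalize_probe(probe)  # every pixel becomes (3 + 4j) / 20
```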

Parameters:

model (PtychographyModel) – The ptychography model to modify.

Returns:

A new PtychographyModel with the normalized probe.

Return type:

PtychographyModel

ptyrax.models.ptychography.offset_displacement(model, offset)

Add a constant offset to the multi-slice inter-slice distances.

Shifts all slice_distances in the interaction model by a fixed offset. This is useful for adjusting the nominal depth separation between slices in a multi-slice reconstruction.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify (must have a multi-slice interaction).

  • offset (tuple[jaxtyping.Float, jaxtyping.Float]) – Offset to add to each slice distance, as a tuple (dz_0, dz_1, ...).

Returns:

A new PtychographyModel with adjusted slice distances.

Return type:

PtychographyModel

ptyrax.models.ptychography.plot_model(model, show=True, fig=None, ax=None, **kwargs)

Plot the probe and illumination fields of a PtychographyModel.

Displays the model’s probe field and illumination state side by side. This is a gin-configurable convenience function for quick visual inspection of the current model state during reconstruction.

Parameters:
  • model (PtychographyModel) – The ptychography model to visualize.

  • show (bool) – Whether to call plt.show() after plotting.

  • fig (Figure) – Existing figure to draw into. If None, a new figure is created.

  • ax (Axes) – Axes array with at least two elements. If None, new axes are created.

  • **kwargs – Additional keyword arguments forwarded to plot().

Returns:

A tuple of (figure, axes, images) where images is the combined list of AxesImage objects from both subplots.

Return type:

tuple[Figure, SubplotSpec, list[AxesImage]]

ptyrax.models.ptychography.plot_scattering_vector_xy(scattering_vector, sample_rotation_matrix, fig=None, gs=None)

Plot Fourier scattering vectors in the sample-frame XY plane.

Projects the scattering vectors (momentum transfer) into the sample coordinate system and displays their \(\xi_x\) vs \(\xi_y\) components. This is useful for verifying the lateral Fourier coverage of the experiment.
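The projection step (without the plotting) can be written as a single `einsum` over the shapes listed in the parameters. This sketch assumes d = 3 and that each rotation matrix maps lab-frame vectors to the sample frame by left multiplication; `to_sample_frame` is a hypothetical helper.

```python
import numpy as np


def to_sample_frame(scattering_vector, sample_rotation_matrix):
    """Rotate lab-frame vectors (s, n, m, 3) by per-position matrices (s, 3, 3)."""
    return np.einsum("sij,snmj->snmi", sample_rotation_matrix, scattering_vector)


# One scan position, a 1x2 grid of vectors, rotated 90 degrees about z.
rot_z90 = np.array([[[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]])
q_lab = np.array([[[[1.0, 0.0, 0.2], [0.0, 1.0, 0.2]]]])  # shape (1, 1, 2, 3)
q_sample = to_sample_frame(q_lab, rot_z90)
xi_x, xi_y = q_sample[..., 0], q_sample[..., 1]  # the components shown in the XY view
```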

Parameters:
  • scattering_vector (Float[Array, 's n m d']) – Scattering vectors in lab frame, shape (s, n, m, d).

  • sample_rotation_matrix (Float[Array, 's 3 3']) – Rotation matrices from lab to sample frame, shape (s, 3, 3).

  • fig (Figure) – Existing matplotlib figure to draw into. If None, a new figure is created.

  • gs (GridSpec) – A SubplotSpec for axis placement within fig.

Returns:

A tuple of (figure, axes) containing the plotted geometry.

Return type:

tuple[Figure, Axes]

ptyrax.models.ptychography.plot_spheres_xz(sample_rotation_matrix, detector_coordinate_sphere, illumination_coordinate_sphere, scattering_vector, fig=None, gs=None)

Plot detector, illumination, and scattering vectors on the Ewald sphere in the XZ plane.

Visualizes the Fourier-space geometry of the ptychography experiment by projecting the detector and illumination direction cosines, as well as the resulting scattering vectors, into the sample-frame XZ plane ($\xi_x$ vs $\xi_z$).

Parameters:
  • sample_rotation_matrix (Float[Array, 'n 3 3']) – Rotation matrices transforming from lab to sample frame, shape (n, 3, 3).

  • detector_coordinate_sphere (Float[Array, 'n m 3']) – Unit vectors towards detector pixels in lab frame, shape (n, m, 3).

  • illumination_coordinate_sphere (Float[Array, 'n k 3']) – Unit vectors towards illumination directions in lab frame, shape (n, k, 3).

  • scattering_vector (Float[Array, 'n m k 3']) – Difference vectors (detector - illumination) representing momentum transfer, shape (n, m, k, 3).

  • fig (Figure) – Existing matplotlib figure to draw into. If None, a new figure is created.

  • gs (GridSpec) – A SubplotSpec for axis placement within fig.

Returns:

A tuple of (figure, axes) containing the plotted geometry.

Return type:

tuple[Figure, Axes]
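The scattering_vector shape (n, m, k, 3) follows from pairing every detector direction with every illumination direction. A minimal NumPy sketch of that broadcasting and of the XZ projection, assuming the same lab-to-sample rotation convention as above; scattering_vectors and project_xz are hypothetical helper names, not ptyrax functions.

```python
import numpy as np


def scattering_vectors(detector_sphere, illumination_sphere):
    """q = detector - illumination for every (detector pixel, illumination direction) pair.

    detector_sphere: (n, m, 3); illumination_sphere: (n, k, 3); result: (n, m, k, 3)
    """
    return detector_sphere[:, :, None, :] - illumination_sphere[:, None, :, :]


def project_xz(vectors, sample_rotation_matrix):
    """Rotate (n, ..., 3) lab-frame vectors into the sample frame; return (xi_x, xi_z)."""
    rotated = np.einsum("nij,n...j->n...i", sample_rotation_matrix, vectors)
    return rotated[..., 0], rotated[..., 2]


# One orientation, two detector directions (+z and +x), one illumination direction (+z).
detector = np.array([[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]])  # (1, 2, 3)
illumination = np.array([[[0.0, 0.0, 1.0]]])  # (1, 1, 3)
q = scattering_vectors(detector, illumination)  # shape (1, 2, 1, 3)
xi_x, xi_z = project_xz(q, np.eye(3)[None])
# forward direction gives q = 0; the +x detector pixel gives (xi_x, xi_z) = (1, -1)
```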

ptyrax.models.ptychography.preprocess_model(model, preprocess_functions=())

Apply a sequence of preprocessing functions to the model.

Parameters:
  • model (ImagePredictionModel) – The model to preprocess.

  • preprocess_functions (tuple[Callable[[ImagePredictionModel], ImagePredictionModel], ...]) – Tuple of callables applied in order. Each callable should take an ImagePredictionModel and return a processed ImagePredictionModel.

Returns:

The processed ImagePredictionModel.

Return type:

ImageDataset

ptyrax.models.ptychography.reinitialize_interaction(interaction, new_data_initializer, new_type=None)

Reinitialize the interaction model with a new data initializer.

Creates a fresh interaction model, of the same type or of new_type if given, that reuses the existing coordinate system and sampling grids but holds newly generated reflection-coefficient data.

Parameters:
  • interaction (InteractionModel) – The current interaction model (used for coordinates, sampling, and regularization functions).

  • new_data_initializer (Callable[[tuple[jaxtyping.Integer, jaxtyping.Integer]], InteractionModel]) – Callable that generates new coefficient data given a SamplingGrid.

  • new_type (Type[InteractionModel]) – Optional alternative interaction class. If None, uses the same type as the input.

Returns:

A new InteractionModel with reinitialized data.

Return type:

InteractionModel

ptyrax.models.ptychography.reinitialize_interaction_inverted(interaction, new_data_initializer, new_type=None)

Reinitialize the interaction model with inverted amplitude.

Like reinitialize_interaction(), but inverts the amplitude of the generated data: pixels with high amplitude in the initializer get low amplitude in the result and vice versa. Phase is set to zero. This is useful for initializing complementary or “negative” interaction patterns.

Parameters:
  • interaction (InteractionModel) – The current interaction model (used for coordinates, sampling, and regularization functions).

  • new_data_initializer (Callable[[tuple[jaxtyping.Integer, jaxtyping.Integer]], InteractionModel]) – Callable that generates coefficient data given a SamplingGrid.

  • new_type (Type[InteractionModel]) – Optional alternative interaction class. If None, uses the same type as the input.

Returns:

A new InteractionModel with amplitude-inverted data.

Return type:

InteractionModel
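The text does not state the exact inversion formula. One rule with the described behaviour (bright pixels become dark and vice versa, zero phase) is to mirror the amplitude within its own range, a′ = max|r| + min|r| − |r|. The NumPy sketch below uses that rule purely as an assumption; invert_amplitude is a hypothetical name, and ptyrax may use a different mapping.

```python
import numpy as np


def invert_amplitude(coefficient):
    """Hypothetical inversion rule: mirror |r| within [min, max] and set the phase to zero."""
    amplitude = np.abs(coefficient)
    inverted = amplitude.max() + amplitude.min() - amplitude  # bright <-> dark
    return inverted.astype(complex)  # real and non-negative, so the phase is zero


r = np.array([0.2 * np.exp(1j), 1.0 * np.exp(-2j), 0.6 + 0j])
inverted_r = invert_amplitude(r)
# amplitudes 0.2, 1.0, 0.6 become 1.0, 0.2, 0.6, all with zero phase
```

Mirroring within [min, max] keeps the amplitude range unchanged, so the inverted start does not need renormalizing; a reciprocal rule such as 1/|r| would not have that property.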

ptyrax.models.ptychography.remove_detector_darkcounts(model, **kwargs)

Zero out the detector dark-count (background) correction.

Replaces the detector’s dark_counts array with zeros, effectively disabling dark-frame subtraction. Useful when reusing a model that was initialized with measured dark counts but the current reconstruction should not apply that correction.

Parameters:

model (PtychographyModel) – The ptychography model to modify.

Returns:

A new PtychographyModel with zeroed dark counts.

Return type:

PtychographyModel

ptyrax.models.ptychography.replace_illumination(model, new_illumination_generator=<function identity>)

Replace the model’s illumination with a transformed version.

Applies a generator function to the current illumination model and substitutes the result into the ptychography model.

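Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • new_illumination_generator – Callable applied to the current illumination model; its result replaces the model’s illumination. Defaults to the identity function.

Returns: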

A new PtychographyModel with the replaced illumination.

Return type:

PtychographyModel

ptyrax.models.ptychography.replace_illumination_from_hdf5(model, hdf5_path, hdf5_illumination_path='illumination', illumination_adjustment_fns=None, data_only=False, normalize=False, **kwargs)

Replace the model’s illumination with parameters loaded from an HDF5 file.

Loads previously saved illumination state from an HDF5 reconstruction file. Optionally applies a sequence of adjustment functions after loading, and can restrict the replacement to data-only (interpolated to the current grid) with optional normalization.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • hdf5_path (str) – Path to the HDF5 file containing saved illumination parameters. If None or empty, the function returns the model unchanged.

  • hdf5_illumination_path (str) – Group path prefix within the HDF5 file.

  • illumination_adjustment_fns (list[Callable[[IlluminationModel], IlluminationModel]]) – Optional list of callables applied sequentially to the loaded illumination.

  • data_only (bool) – If True, only the probe data array is replaced (interpolated to match the current model’s grid), leaving coordinates and metadata unchanged.

  • normalize (bool) – If True and data_only=True, normalize the loaded probe data to unit energy.

  • **kwargs – Additional keyword arguments forwarded to apply_hdf5_to_model().

Returns:

A new PtychographyModel with illumination loaded from HDF5.

Return type:

PtychographyModel

ptyrax.models.ptychography.replace_interaction(model, new_interaction_generator=<function identity>, **kwargs)

Replace the model’s interaction with a transformed version.

Applies one or more generator functions sequentially to the current interaction model and replaces it in the ptychography model. This is the primary mechanism for swapping interaction types or applying structural changes (e.g. converting to multi-slice).

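Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • new_interaction_generator – A callable, or a sequence of callables applied in order, that transforms the current interaction model. Defaults to the identity function.

Returns: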

A new PtychographyModel with the replaced interaction.

Return type:

PtychographyModel
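The "one or more generators, applied sequentially, then substituted into a new model" pattern can be sketched with the standard library alone. ToyModel and replace_interaction_sketch are hypothetical stand-ins: a frozen dataclass mimics the functional style of these functions, which return a new model instead of mutating the input.

```python
import dataclasses
from functools import reduce


@dataclasses.dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for PtychographyModel (immutable)."""
    illumination: str
    interaction: str


def replace_interaction_sketch(model, generators):
    """Apply one callable or a sequence of callables, in order, to model.interaction."""
    if callable(generators):
        generators = (generators,)
    new_interaction = reduce(lambda current, fn: fn(current), generators, model.interaction)
    # Return a new model; the input model is left untouched.
    return dataclasses.replace(model, interaction=new_interaction)


model = ToyModel(illumination="probe", interaction="slice")
updated = replace_interaction_sketch(model, (lambda s: s + "x2", str.upper))
# updated.interaction == "SLICEX2", while model.interaction is still "slice"
```

Normalizing a single callable to a one-element tuple lets both call styles share the same fold.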

ptyrax.models.ptychography.replace_interaction_from_hdf5(model, hdf5_path, hdf5_interaction_path='interaction', **kwargs)

Replace the model’s interaction parameters with values loaded from an HDF5 file.

Loads previously saved interaction model state from an HDF5 reconstruction file and applies it to the current model’s interaction subtree.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • hdf5_path (str) – Path to the HDF5 file containing saved interaction parameters.

  • hdf5_interaction_path (str) – Group path prefix within the HDF5 file where the interaction parameters are stored.

  • **kwargs – Additional keyword arguments forwarded to apply_hdf5_to_model().

Returns:

A new PtychographyModel with interaction loaded from HDF5.

Return type:

PtychographyModel

ptyrax.models.ptychography.replace_probe(model, new_probe_generator)

Replace the model’s probe field with a transformed version.

Applies one or more generator functions sequentially to the current probe (CoherentField) and updates the model. Only works with DirectIllumination.

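Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • new_probe_generator – A callable, or a sequence of callables applied in order, that transforms the current probe (CoherentField).

Returns: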

A new PtychographyModel with the replaced probe.

Raises:

TypeError – If the illumination model is not DirectIllumination.

Return type:

PtychographyModel

ptyrax.models.ptychography.reset_illumination(model, initial_model)

Reset the illumination model to its initial state.

Replaces the current (optimized) illumination with the one from initial_model. This is useful for restarting the illumination optimization from scratch while preserving other model components.

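Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • initial_model – Model whose illumination replaces the current one.

Returns: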

A new PtychographyModel with the initial illumination.

Return type:

PtychographyModel

ptyrax.models.ptychography.reset_interaction(model, initial_model)

Reset the interaction model to its initial state.

Replaces the current (optimized) interaction model with the one from initial_model. This is useful for restarting the interaction optimization from scratch while preserving other model components.

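Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • initial_model – Model whose interaction replaces the current one.

Returns: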

A new PtychographyModel with the initial interaction.

Return type:

PtychographyModel

ptyrax.models.ptychography.scale_illumination_equal_pixel_size(model, scale=1.0)

Rescale the probe illumination field by interpolating to a new pixel size.

Multiplies the probe’s pixel size by scale and interpolates the probe data onto the new grid, preserving the number of pixels while changing the physical extent of the field of view. This is typically used when the wavelength is modified and the probe must be adjusted to maintain consistent Fourier sampling.

Parameters:
  • model (PtychographyModel) – The ptychography model whose illumination will be rescaled.

  • scale (jaxtyping.Float) – Multiplicative factor for the pixel size. Values > 1 enlarge pixels (zoom out); values < 1 shrink pixels (zoom in).

Returns:

A new PtychographyModel with the rescaled probe data.

Return type:

PtychographyModel
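The resampling can be pictured as follows: with the pixel count fixed and the grid centred, new pixel i sits at physical offset (i − c)·Δx·scale, which is fractional index c + (i − c)·scale on the old grid. This NumPy sketch assumes a centred grid, bilinear interpolation (analogous to order=1) and zero fill outside the old grid; rescale_pixel_size and bilinear_sample are hypothetical helpers, not the ptyrax implementation.

```python
import numpy as np


def bilinear_sample(image, rows, cols):
    """Sample a 2-D (possibly complex) array at fractional pixel coordinates; zero outside."""
    m, n = image.shape
    r0, c0 = np.floor(rows).astype(int), np.floor(cols).astype(int)
    fr, fc = rows - r0, cols - c0
    out = np.zeros(rows.shape, dtype=np.result_type(image, float))
    for dr, wr in ((0, 1 - fr), (1, fr)):
        for dc, wc in ((0, 1 - fc), (1, fc)):
            rr, cc = r0 + dr, c0 + dc
            valid = (rr >= 0) & (rr < m) & (cc >= 0) & (cc < n)
            vals = image[np.clip(rr, 0, m - 1), np.clip(cc, 0, n - 1)]
            out += wr * wc * np.where(valid, vals, 0)
    return out


def rescale_pixel_size(probe, scale):
    """Resample probe onto a grid whose pixel size is multiplied by scale (same pixel count).

    New pixel i lies at physical offset (i - c) * dx * scale from the grid centre c,
    i.e. at fractional index c + (i - c) * scale of the old grid.
    """
    m, n = probe.shape
    cr, cc = (m - 1) / 2, (n - 1) / 2
    rows, cols = np.indices((m, n), dtype=float)
    return bilinear_sample(probe, cr + (rows - cr) * scale, cc + (cols - cc) * scale)


probe = (np.arange(25) + 1j * np.arange(25)[::-1]).reshape(5, 5)
zoomed_out = rescale_pixel_size(probe, 2.0)
# zoomed_out[3, 3] == probe[4, 4]; zoomed_out[4, 4] == 0 because it maps outside the old grid
```

With scale > 1 the same physical probe covers fewer pixels, which matches the "zoom out" wording above.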

ptyrax.models.ptychography.scale_model_wavelength(model, scale=1.0, rescale_illumination=True, **kwargs)

Scale the model wavelength by a multiplicative factor.

Multiplies all wavelength entries by scale. Optionally rescales the illumination probe pixel size to maintain consistent Fourier sampling after the wavelength change.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • scale (jaxtyping.Float) – Multiplicative factor applied to the wavelength.

  • rescale_illumination (bool) – If True, also rescales the probe via scale_illumination_equal_pixel_size() to compensate for the wavelength change.

Returns:

A new PtychographyModel with the scaled wavelength.

Return type:

PtychographyModel
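Why a wavelength change calls for a pixel-size change: in a simple far-field (Fraunhofer) geometry, the real-space pixel size is Δx = λz/(Np) for propagation distance z, N detector pixels and detector pitch p, so Δx scales linearly with λ. The sketch assumes that textbook geometry only. ptyrax’s tilted-sampling model may compute its grids differently, and the numbers are arbitrary.

```python
def far_field_pixel_size(wavelength, distance, n_pixels, detector_pitch):
    """Real-space pixel size of a Fraunhofer (far-field) reconstruction: lambda * z / (N * p)."""
    return wavelength * distance / (n_pixels * detector_pitch)


# Arbitrary illustrative numbers (SI units).
dx = far_field_pixel_size(wavelength=1e-9, distance=0.1, n_pixels=256, detector_pitch=50e-6)
dx_scaled = far_field_pixel_size(wavelength=2e-9, distance=0.1, n_pixels=256, detector_pitch=50e-6)
ratio = dx_scaled / dx  # about 2: doubling the wavelength doubles the pixel size
```

Under this assumption, passing the same scale to the probe pixel size keeps the probe consistent with the new Fourier sampling.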

ptyrax.models.ptychography.scale_scan_range(model, scale=1.0)

Scale the sample scan positions by a multiplicative factor.

Multiplies all interaction coordinate translations by scale. This is useful for correcting miscalibrated scan step sizes or converting between units.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • scale (jaxtyping.Float) – Multiplicative factor for all scan translations.

Returns:

A new PtychographyModel with rescaled scan positions.

Return type:

PtychographyModel

ptyrax.models.ptychography.set_illumination_total_energy(model, total_energy=1.0)

Scale the probe so that its total integrated intensity equals a target value.

Computes the current total energy (sum of probe intensities) and scales the probe data by $\sqrt{E_{\text{target}} / E_{\text{current}}}$ to achieve the desired total energy.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • total_energy (jaxtyping.Float) – Desired total probe energy (sum of pixel intensities).

Returns:

A new PtychographyModel with the scaled probe.

Return type:

PtychographyModel
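Because intensity is |field|², multiplying the field by √(E_target/E_current) multiplies the summed intensity by E_target/E_current. A NumPy sketch of that scaling on a bare array; set_total_energy is a hypothetical helper operating on the probe data, not on a PtychographyModel.

```python
import numpy as np


def set_total_energy(probe, total_energy=1.0):
    """Scale probe so that sum(|probe|^2) == total_energy; phases are unchanged."""
    current_energy = np.sum(np.abs(probe) ** 2)
    return probe * np.sqrt(total_energy / current_energy)


rng = np.random.default_rng(0)
probe = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
scaled = set_total_energy(probe, total_energy=5.0)
# np.sum(np.abs(scaled) ** 2) is 5.0 up to rounding
```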

ptyrax.models.ptychography.set_interaction_real_only(model, **kwargs)

Discard the imaginary part of the interaction reflection coefficient.

Replaces the complex reflection coefficient with its real part only. This is useful for enforcing a purely absorptive sample model (the phase is then restricted to 0 or π, i.e. a sign) or for resetting phase artifacts.

Parameters:

model (PtychographyModel) – The ptychography model to modify.

Returns:

A new PtychographyModel with a real-valued reflection coefficient.

Return type:

PtychographyModel

ptyrax.models.ptychography.set_mean_phase(interaction, target_mean_phase=0.0)

Shift the global phase of the interaction reflection coefficient.

Applies a constant phase rotation to the reflection coefficient so that its mean phase matches target_mean_phase. This removes the global phase ambiguity that can accumulate during optimization.

Parameters:
  • interaction (InteractionModel) – The interaction model to modify.

  • target_mean_phase (jaxtyping.Float) – Desired mean phase angle (radians) of the reflection coefficient.

Returns:

A new InteractionModel with the adjusted phase.

Return type:

InteractionModel
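A constant phase rotation multiplies every pixel by the same factor exp(iφ), so amplitudes are untouched and only the global phase moves. The sketch below assumes "mean phase" means the angle of the complex mean, which sidesteps 2π wrapping. The text does not say which definition ptyrax uses, and rotate_to_mean_phase is a hypothetical name.

```python
import numpy as np


def rotate_to_mean_phase(coefficient, target_mean_phase=0.0):
    """Multiply every pixel by one phase factor so the phase of the mean equals the target.

    Assumption: "mean phase" is taken as the angle of the complex mean, which avoids
    2*pi wrapping problems that a plain average of np.angle values would have.
    """
    current = np.angle(np.mean(coefficient))
    return coefficient * np.exp(1j * (target_mean_phase - current))


rng = np.random.default_rng(1)
amplitude = 1 + 0.1 * rng.standard_normal((16, 16))
phase = 2.0 + 0.1 * rng.standard_normal((16, 16))
r = amplitude * np.exp(1j * phase)
adjusted = rotate_to_mean_phase(r, target_mean_phase=0.5)
# np.angle(np.mean(adjusted)) is 0.5; np.abs(adjusted) equals np.abs(r)
```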

ptyrax.models.ptychography.set_model_constant_tilt_angle(model, tilt_angle=0.0, detector_tilt_angle=None, **kwargs)

Override sample and detector orientations with those corresponding to a constant tilt angle.

Sets all sample orientations to a uniform rotation about the y-axis by tilt_angle degrees, recomputes sample positions in the new frame, and derives the specular detector orientation and position. An independent detector_tilt_angle can be specified when the detector does not sit exactly at the specular reflection.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • tilt_angle (float) – Sample tilt angle in degrees (rotation about y-axis).

  • detector_tilt_angle (float | None) – If provided, overrides the detector orientation independently from the sample tilt.

Returns:

A new PtychographyModel with updated coordinate systems.

Return type:

PtychographyModel
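The two geometric ingredients named above are a rotation about the y-axis and the specular direction, obtained by mirroring the incident direction about the surface: k_out = k_in − 2(k_in·n)n. This NumPy sketch assumes a beam along +z and an untilted surface normal along +x. Those are conventions chosen for illustration, not ptyrax’s, and the helper names are hypothetical.

```python
import numpy as np


def rotation_about_y(angle_deg):
    """Right-handed rotation matrix about the y-axis by angle_deg degrees."""
    a = np.deg2rad(angle_deg)
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def specular_direction(k_in, normal):
    """Mirror the incident direction about the surface: k_out = k_in - 2 (k_in . n) n."""
    n = normal / np.linalg.norm(normal)
    return k_in - 2.0 * np.dot(k_in, n) * n


# Assumed convention: beam along +z, untilted surface normal along +x.
tilt = rotation_about_y(10.0)
normal = tilt @ np.array([1.0, 0.0, 0.0])
k_in = np.array([0.0, 0.0, 1.0])
k_out = specular_direction(k_in, normal)
deflection = np.degrees(np.arccos(np.dot(k_in, k_out)))  # 2 * tilt = 20 degrees
```

The specular detector therefore sits at twice the sample tilt from the incident beam, which is why an independent detector_tilt_angle is only needed when the detector is deliberately placed off-specular.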

ptyrax.models.ptychography.set_model_wavelength(model, wavelength=1.0, rescale_illumination=True, **kwargs)

Set the model wavelength to an absolute value.

Replaces the current wavelength(s) with the given value (broadcast to match shape). Optionally rescales the illumination probe to maintain consistent Fourier sampling at the new wavelength.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • wavelength (jaxtyping.Float) – The new wavelength value in the same units as the model.

  • rescale_illumination (bool) – If True, rescales the probe via scale_illumination_equal_pixel_size() to compensate.

Returns:

A new PtychographyModel with the updated wavelength.

Return type:

PtychographyModel

ptyrax.models.ptychography.set_outside_scan_range_to(model, epoch, optimizer_state, extra_range_factor=1.0, set_to=<function mean>, deviation_scale=0.2, apply_every=1)

Suppress reconstruction artifacts outside the scanned area.

Replaces the interaction (reflection coefficient) values at positions beyond the scan range with a constant or averaged value. A soft suppression term further damps deviations from the replacement value to prevent edge artifacts from growing during optimization.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • epoch (jaxtyping.Integer) – Current optimization epoch (used with apply_every).

  • optimizer_state (PyTree[ptyrax.models.ptychography.PtychographyModel]) – Current optimizer state (unused but required by the adjuster signature).

  • extra_range_factor (tuple[jaxtyping.Float] | jaxtyping.Float) – Multiplicative factor applied to scan displacements when computing the boundary radius.

  • set_to (Callable | Shaped[Array, ''] | Shaped[ndarray, ''] | bool | number | int | float | complex | Shaped[LiteralArray, '']) – Value or callable to compute the replacement value. If callable, it is applied to the full coefficient array (e.g. jnp.mean).

  • deviation_scale (jaxtyping.Float) – Fraction controlling how strongly existing deviations outside the boundary are suppressed.

  • apply_every (jaxtyping.Integer) – Only apply the adjustment every apply_every epochs.

Returns:

A new PtychographyModel with suppressed out-of-range values.

Return type:

PtychographyModel
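A simplified picture of this adjuster, under several assumptions not confirmed by the text: the boundary is a circle around the scan centroid whose radius is the largest scan displacement times extra_range_factor (a scalar here, not a tuple), and "soft suppression" keeps only a deviation_scale fraction of each outside pixel’s deviation from the fill value. suppress_outside_scan is a hypothetical NumPy helper, not the ptyrax implementation.

```python
import numpy as np


def suppress_outside_scan(coefficient, x, y, scan_xy, epoch, extra_range_factor=1.0,
                          set_to=np.mean, deviation_scale=0.2, apply_every=1):
    """Hypothetical sketch: pull values outside a circular scan region towards a fill value.

    coefficient, x, y: 2-D arrays (values and physical pixel coordinates)
    scan_xy: (p, 2) scan positions in the same units as x and y
    """
    if epoch % apply_every != 0:
        return coefficient  # only act every apply_every epochs
    centre = scan_xy.mean(axis=0)
    radius = extra_range_factor * np.max(np.linalg.norm(scan_xy - centre, axis=1))
    outside = np.hypot(x - centre[0], y - centre[1]) > radius
    fill = set_to(coefficient) if callable(set_to) else set_to
    # Keep only a deviation_scale fraction of each outside pixel's deviation from fill.
    damped = fill + deviation_scale * (coefficient - fill)
    return np.where(outside, damped, coefficient)


x, y = np.meshgrid(np.arange(-2.0, 3.0), np.arange(-2.0, 3.0))
scan = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])  # radius 1 around 0
values = np.ones((5, 5))
values[0, 0] = 6.0  # a bright artifact in the corner, outside the scan
cleaned = suppress_outside_scan(values, x, y, scan, epoch=0)
# fill = mean = 1.2, so the corner becomes 1.2 + 0.2 * (6 - 1.2) = 2.16; the centre stays 1.0
```

Damping rather than hard replacement lets outside pixels relax gradually across epochs instead of jumping, which is gentler on the optimizer’s momentum terms.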

ptyrax.models.ptychography.shift_probe_and_interaction(model, epoch, optimizer_state, apply_every=1, max_epoch=1800, compute_shift_fn=<function compute_center_of_mass_shift>, perform_shift_fn=<function shift_image>, probe_re='.*probe.*data', interaction_re='.*interaction.*data', order=1, **kwargs)

Center the probe by shifting both probe and interaction fields.

Computes the center-of-mass offset of the probe amplitude and applies the corresponding sub-pixel shift to both the probe and interaction arrays (identified by regex on their pytree paths). The optimizer state arrays matching those paths are shifted as well to maintain consistency.

This adjuster helps avoid a local minimum in which the probe starts off-center and grows until it reaches the edge of the reconstruction grid.

Parameters:
  • model (PtychographyModel) – The ptychography model to modify.

  • epoch (int) – Current epoch (used with apply_every and max_epoch).

  • optimizer_state (optax.OptState) – Current optimizer state; matching leaves are also shifted.

  • apply_every (int) – Only apply every apply_every epochs.

  • max_epoch (int) – Stop applying after this epoch.

  • compute_shift_fn (Callable[[Inexact[Array, 'm n']], Float[Array, '2']]) – Function to compute the 2-D shift from a spatial array.

  • perform_shift_fn (Callable[[Inexact[Array, 'm n']], Inexact[Array, '... m n']]) – Function to apply the shift to a spatial array.

  • probe_re (str) – Regex pattern matching probe data paths in the pytree.

  • interaction_re (str) – Regex pattern matching interaction data paths in the pytree.

  • order (int) – Interpolation order for the shift.

Returns:

A tuple (shifted_model, shifted_optimizer_state).

Return type:

tuple[PtychographyModel, PyTree[PtychographyModel]]
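The core of the centering step is an amplitude-weighted center of mass, a shift that moves it to the grid center, and the same shift applied to every array whose path matches the probe or interaction regex. The NumPy sketch below stands in for the model pytree with a flat dict of path strings and uses bilinear interpolation (analogous to order=1). All helper names are hypothetical, and matching with re.fullmatch is an assumption. With JAX pytrees, jax.tree_util.tree_map_with_path together with jax.tree_util.keystr could play the role of the dict comprehension, and the same mapping would be repeated over the matching optimizer-state leaves.

```python
import re

import numpy as np


def center_of_mass(image):
    """Amplitude-weighted centre of mass (row, col) of a 2-D array."""
    weights = np.abs(image)
    rows, cols = np.indices(weights.shape)
    total = weights.sum()
    return np.array([(rows * weights).sum() / total, (cols * weights).sum() / total])


def linear_shift(image, shift):
    """Shift a 2-D (possibly complex) array by (d_row, d_col) pixels; bilinear, zero fill."""
    m, n = image.shape
    rows, cols = np.indices((m, n), dtype=float)
    src_r, src_c = rows - shift[0], cols - shift[1]  # where each output pixel comes from
    r0, c0 = np.floor(src_r).astype(int), np.floor(src_c).astype(int)
    fr, fc = src_r - r0, src_c - c0
    out = np.zeros((m, n), dtype=np.result_type(image, float))
    for dr, wr in ((0, 1 - fr), (1, fr)):
        for dc, wc in ((0, 1 - fc), (1, fc)):
            rr, cc = r0 + dr, c0 + dc
            valid = (rr >= 0) & (rr < m) & (cc >= 0) & (cc < n)
            vals = image[np.clip(rr, 0, m - 1), np.clip(cc, 0, n - 1)]
            out += wr * wc * np.where(valid, vals, 0)
    return out


def center_probe(arrays, probe_re=r".*probe.*data", interaction_re=r".*interaction.*data"):
    """Shift the probe to the grid centre and apply the same shift to the interaction.

    arrays: flat dict of path string -> 2-D array (hypothetical stand-in for the model pytree).
    """
    probe_key = next(k for k in arrays if re.fullmatch(probe_re, k))
    m, n = arrays[probe_key].shape
    shift = np.array([(m - 1) / 2, (n - 1) / 2]) - center_of_mass(arrays[probe_key])
    shifted = {
        key: linear_shift(value, shift)
        if re.fullmatch(probe_re, key) or re.fullmatch(interaction_re, key)
        else value
        for key, value in arrays.items()
    }
    return shifted, shift


probe = np.zeros((7, 7), dtype=complex)
probe[1, 2] = 1.0  # off-centre probe
interaction = np.zeros((7, 7))
interaction[0, 0] = 1.0
dark = np.ones((7, 7))
state = {"illumination/probe/data": probe, "interaction/data": interaction, "detector/dark_counts": dark}
new_state, shift = center_probe(state)
# shift == (2, 1): the probe peak moves to (3, 3) and the interaction marker to (2, 1)
```

Shifting probe and interaction together leaves their product, and hence the predicted exit wave, essentially unchanged, so the recentering does not undo the fit.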