ptyrax.reconstruct

Functions

make_batches(n, batch_size, shuffle_key)

Create shuffled batch index arrays for one epoch.

post_training(model, log_dir, output_file[, ...])

Save model outputs after training completes.

reconstruct(dataset_path, output_file, log_dir)

Train a model on a ptychogram dataset and save the results.

train_batch(model, scanning_position_index, ...)

Perform a single gradient step on one mini-batch.

train_epoch(model, dataset, optimizer, ...)

Run a single training epoch over all batches.

train_model(model, dataset, optimizer, ...)

Top-level training loop that logs initial state and delegates to train_session().

train_session(model, dataset, optimizer, ...)

Train a model on a ptychogram dataset.

ptyrax.reconstruct.make_batches(n, batch_size, shuffle_key)

Create shuffled batch index arrays for one epoch.

Parameters:
  • n (int) – Total number of samples (must be divisible by batch_size).

  • batch_size (int) – Number of samples per batch.

  • shuffle_key (Key) – JAX PRNG key for random permutation.

Returns:

Integer array of shape (n // batch_size, batch_size) containing shuffled sample indices.

Return type:

Integer[Array, 'num_batches batch_size']
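The batching described above can be sketched as a permutation followed by a reshape. This is a minimal illustration under stated assumptions, not the ptyrax source: the name make_batches_sketch is hypothetical, and raising ValueError for an indivisible n is an assumption about how the divisibility requirement might be enforced.

```python
import jax


def make_batches_sketch(n, batch_size, shuffle_key):
    """Hypothetical stand-in for make_batches; not the ptyrax implementation."""
    # Assumption: an indivisible n is rejected rather than padded or truncated.
    if n % batch_size != 0:
        raise ValueError("n must be divisible by batch_size")
    # Random permutation of the sample indices 0 .. n-1.
    permuted = jax.random.permutation(shuffle_key, n)
    # One row per batch: shape (n // batch_size, batch_size).
    return permuted.reshape(n // batch_size, batch_size)


batches = make_batches_sketch(12, 4, jax.random.PRNGKey(0))
print(batches.shape)  # (3, 4)
```

Each sample index appears exactly once across the rows, so iterating over the rows visits the whole dataset once per epoch.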

ptyrax.reconstruct.post_training(model, log_dir, output_file, save_equinox_model=True)

Save model outputs after training completes.

Writes the model as HDF5 and optionally as an equinox binary (.eqx), and converts TensorBoard logs to an HDF5 archive.

Parameters:
  • model (PtychographyModel) – Trained model to serialize.

  • log_dir (str) – Directory containing TensorBoard event files.

  • output_file (str) – Destination HDF5 file for model parameters.

  • save_equinox_model (bool) – Also save raw .eqx binary.

Return type:

None

ptyrax.reconstruct.reconstruct(dataset_path, output_file, log_dir, dataset_load_fn=<function from_hdf5>, preprocess_functions=(), continue_from_reconstruction=None, sweep_id=None, *, key, model_type=<class 'ptyrax.models.ptychography.PtychographyModel'>, **kwargs)

Train a model on a ptychogram dataset and save the results.

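Parameters:
  • dataset_path

  • output_file

  • log_dir

  • dataset_load_fn – Defaults to from_hdf5.

  • preprocess_functions – Defaults to ().

  • continue_from_reconstruction – Defaults to None.

  • sweep_id – Defaults to None.

  • key – Keyword-only; required.

  • model_type – Keyword-only; defaults to ptyrax.models.ptychography.PtychographyModel.

  • **kwargs

Return type: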

ImagePredictionModel

ptyrax.reconstruct.train_batch(model, scanning_position_index, measured_diffraction_pattern, optimizer, optimizer_state)

Perform a single gradient step on one mini-batch.

Computes the loss and gradient, conjugates the Wirtinger derivative, applies the optimizer update, and returns the updated model.

Parameters:
  • model (ImagePredictionModel) – Current model state.

  • scanning_position_index (Integer[Array, 'b']) – Batch of scanning position indices.

  • measured_diffraction_pattern (Float[Array, 'b w h']) – Measured images for this batch.

  • optimizer (GradientTransformation) – Optax gradient transformation.

  • optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Current optimizer state.

Returns:

Tuple of ((total_loss, named_losses), model, optimizer_state).

Return type:

Tuple[Tuple[float, Any], ImagePredictionModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
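The gradient step described above (compute loss and gradient, conjugate, update) can be illustrated with a toy problem. This is a sketch under assumptions, not the ptyrax implementation: loss_fn, train_batch_sketch and LEARNING_RATE are hypothetical names, the loss is a stand-in intensity mismatch, and a plain SGD update replaces the Optax transformation, so the return value is simplified to (loss, params).

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # hypothetical value chosen for this toy problem


def loss_fn(params, indices, measured):
    # Toy real-valued loss over complex parameters: compare the squared
    # magnitude of the selected entries with the measured values.
    selected = params[indices]
    intensity = selected.real ** 2 + selected.imag ** 2
    return jnp.mean((intensity - measured) ** 2)


@jax.jit
def train_batch_sketch(params, indices, measured):
    loss, grads = jax.value_and_grad(loss_fn)(params, indices, measured)
    # For a real loss of complex inputs, the descent direction is the
    # conjugate of what jax.grad returns.
    grads = jnp.conj(grads)
    # Plain SGD update, used here in place of an Optax transformation.
    return loss, params - LEARNING_RATE * grads


params = jnp.array([1 + 1j, 2 + 0j, 0.5 - 0.5j, 0 + 1.5j], dtype=jnp.complex64)
indices = jnp.array([1, 3])  # one mini-batch of position indices
measured = jnp.ones(2)
loss_before, new_params = train_batch_sketch(params, indices, measured)
loss_after = loss_fn(new_params, indices, measured)
print(float(loss_before), float(loss_after))  # the loss should drop
```

Only the entries selected by the batch indices receive a non-zero gradient, so the other parameters are left unchanged by the step.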

ptyrax.reconstruct.train_epoch(model, dataset, optimizer, optimizer_state, step, epoch, epoch_logger, *, key)

Run a single training epoch over all batches.

Shuffles the dataset into batches and performs one gradient update per batch using jax.lax.scan().

Parameters:
  • model (ImagePredictionModel) – Current model state.

  • dataset (ImageDataset) – Training dataset containing measured images.

  • optimizer (GradientTransformation) – Optax gradient transformation.

  • optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Current optimizer state.

  • step (int) – Global training step counter (unused, reserved).

  • epoch (int) – Current epoch index (0-based).

  • epoch_logger (Callable) – Callback invoked at end of epoch with (epoch, model, total_loss, losses, ordered=False).

  • key (Key) – JAX PRNG key for batch shuffling.

Returns:

Tuple of (model, optimizer_state, step, epoch) with updated values.

Return type:

Tuple[ImagePredictionModel, Any, int, int]
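The scan-based epoch described above can be sketched on a toy problem where a scalar parameter is fitted to the mean of the data. This is an illustration under assumptions, not the ptyrax source: train_epoch_sketch, batch_loss and LEARNING_RATE are hypothetical, a plain SGD step replaces the Optax update, and logging is omitted.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value


def batch_loss(w, targets):
    # Mean squared distance between the scalar parameter and the batch.
    return jnp.mean((w - targets) ** 2)


def train_epoch_sketch(w, data, batch_size, key):
    n = data.shape[0]
    # Shuffle sample indices into rows, one row per batch.
    batches = jax.random.permutation(key, n).reshape(n // batch_size, batch_size)

    def step(w, batch_indices):
        loss, grad = jax.value_and_grad(batch_loss)(w, data[batch_indices])
        # Return the new carry and the per-batch output.
        return w - LEARNING_RATE * grad, loss

    # scan threads w through every batch and stacks the per-batch losses.
    return jax.lax.scan(step, w, batches)


data = jnp.arange(8.0)
w, losses = train_epoch_sketch(jnp.asarray(0.0), data, 2, jax.random.PRNGKey(0))
print(losses.shape)  # (4,)
```

Using jax.lax.scan instead of a Python loop lets the whole epoch be traced once and compiled as a single loop, which is the usual motivation for this pattern.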

ptyrax.reconstruct.train_model(model, dataset, optimizer, optimizer_state, writer, *, key=Array([0, 42], dtype=uint32), **kwargs)

Top-level training loop that logs initial state and delegates to train_session().

Parameters:
  • model (PtychographyModel) – Model to train.

  • dataset (ImageDataset) – Training dataset.

  • optimizer (GradientTransformation) – Optax gradient transformation.

  • optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Initial optimizer state.

  • writer (SummaryWriter) – TensorBoard SummaryWriter for logging.

  • key (Key) – JAX PRNG key.

  • **kwargs – Extra keyword arguments forwarded to train_session.

Returns:

Tuple of (trained_model, trained_optimizer_state).

Return type:

Tuple[PtychographyModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
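The default key=Array([0, 42], dtype=uint32) in the signature matches what jax.random.PRNGKey(42) produces under JAX's default PRNG implementation. Because the default is a fixed value, calls that omit key presumably share the same random stream; passing an explicit key varies it. A short sketch (the seed 1234 is arbitrary):

```python
import jax

# Reproduce the default shown in the signature.
default_key = jax.random.PRNGKey(42)
print(default_key.tolist())  # [0, 42]

# An explicit key for a run that should differ from the default.
run_key = jax.random.PRNGKey(1234)
```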

ptyrax.reconstruct.train_session(model, dataset, optimizer, optimizer_state, batch_logger, epoch_logger, num_epochs=30, epoch_callbacks=(), initial_epoch=0, *, key, **kwargs)

Train a model on a ptychogram dataset. Training runs for num_epochs epochs, starting from initial_epoch, using the supplied optimizer.

Parameters:
  • model (PtychographyModel) – The model to train.

  • dataset (ImageDataset) – The dataset to train on.

  • optimizer (GradientTransformation) – Optax gradient transformation.

  • optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Initial optimizer state.

  • batch_logger (Callable) – A function that logs the loss for each batch.

  • epoch_logger (Callable) – A function that logs the state of the model at the end of each epoch.

  • num_epochs (int) – The number of epochs to train for.

  • epoch_callbacks (Iterable[Callable]) – Callback functions that are run after each epoch.

  • initial_epoch (int) – The epoch to start training from.

  • key (Key) – JAX PRNG key.

Return type:

Tuple[ImagePredictionModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
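The interplay of num_epochs, initial_epoch, epoch_callbacks and key can be sketched as a plain epoch loop. Everything here is hypothetical: train_session_sketch, the run_epoch argument and the callback signature (epoch, state) are illustrative names, and the sketch assumes that num_epochs counts epochs run after initial_epoch, which the source does not state.

```python
import jax


def train_session_sketch(state, run_epoch, num_epochs=30, epoch_callbacks=(),
                         initial_epoch=0, *, key):
    # Assumption: epochs are numbered
    # initial_epoch .. initial_epoch + num_epochs - 1.
    for epoch in range(initial_epoch, initial_epoch + num_epochs):
        # Derive a fresh subkey per epoch so each shuffle differs.
        key, epoch_key = jax.random.split(key)
        state = run_epoch(state, epoch, epoch_key)
        # Run every callback once the epoch has finished.
        for callback in epoch_callbacks:
            callback(epoch, state)
    return state


seen = []
final = train_session_sketch(
    0,
    lambda state, epoch, key: state + 1,  # stand-in for one training epoch
    num_epochs=3,
    epoch_callbacks=[lambda epoch, state: seen.append(epoch)],
    initial_epoch=5,
    key=jax.random.PRNGKey(0),
)
print(final, seen)  # 3 [5, 6, 7]
```

Splitting the key inside the loop keeps the session reproducible from a single key while giving each epoch its own shuffle.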