ptyrax.reconstruct
Functions
| Function | Description |
|---|---|
| make_batches | Create shuffled batch index arrays for one epoch. |
| post_training | Save model outputs after training completes. |
| reconstruct | Trains a model on a ptychogram dataset and saves the results. |
| train_batch | Perform a single gradient step on one mini-batch. |
| train_epoch | Run a single training epoch over all batches. |
| train_model | Top-level training loop that logs the initial state and delegates to train_session(). |
| train_session | Trains a model on a ptychogram dataset. |
- ptyrax.reconstruct.make_batches(n, batch_size, shuffle_key)
Create shuffled batch index arrays for one epoch.
- Parameters:
n (int) – Total number of samples (must be divisible by batch_size).
batch_size (int) – Number of samples per batch.
shuffle_key (Key) – JAX PRNG key for random permutation.
- Returns:
Integer array of shape (n // batch_size, batch_size) containing shuffled sample indices.
- Return type:
Integer[Array, 'num_batches batch_size']
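A minimal sketch of the documented behaviour follows, assuming the batches come from a single `jax.random.permutation` of the sample indices reshaped into rows. `make_batches_sketch` is a hypothetical name, not the ptyrax source, and the explicit divisibility check is an assumption (the docstring only states that `n` must be divisible by `batch_size`).

```python
import jax


def make_batches_sketch(n: int, batch_size: int, shuffle_key: jax.Array) -> jax.Array:
    """Hypothetical stand-in for make_batches (not the ptyrax source)."""
    # Assumption: fail loudly instead of silently dropping leftover samples.
    if n % batch_size != 0:
        raise ValueError(f"n={n} is not divisible by batch_size={batch_size}")
    # One random ordering of the indices 0..n-1 ...
    perm = jax.random.permutation(shuffle_key, n)
    # ... split into n // batch_size rows of batch_size indices each.
    return perm.reshape(n // batch_size, batch_size)


batches = make_batches_sketch(12, 4, jax.random.PRNGKey(0))
print(batches.shape)  # (3, 4)
```

Each row is one mini-batch, and every sample index appears exactly once per epoch.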
- ptyrax.reconstruct.post_training(model, log_dir, output_file, save_equinox_model=True)
Save model outputs after training completes.
Writes the model as HDF5 and optionally as an Equinox binary (.eqx), and converts the TensorBoard logs to an HDF5 archive.
- Parameters:
model (PtychographyModel) – Trained model to serialize.
log_dir (str) – Directory containing TensorBoard event files.
output_file (str) – Destination HDF5 file for model parameters.
save_equinox_model (bool) – Also save the raw .eqx binary.
- Return type:
None
- ptyrax.reconstruct.reconstruct(dataset_path, output_file, log_dir, dataset_load_fn=<function from_hdf5>, preprocess_functions=(), continue_from_reconstruction=None, sweep_id=None, *, key, model_type=<class 'ptyrax.models.ptychography.PtychographyModel'>, **kwargs)
Trains a model on a ptychogram dataset and saves the results.
- Parameters:
dataset_path (str)
output_file (Path)
log_dir (str)
dataset_load_fn (Callable[[str], ImageDataset])
preprocess_functions (list[Callable[[ImageDataset], ImageDataset]])
continue_from_reconstruction (str | None)
sweep_id (str)
key (Key)
model_type (type[ImagePredictionModel])
- Return type:
- ptyrax.reconstruct.train_batch(model, scanning_position_index, measured_diffraction_pattern, optimizer, optimizer_state)
Perform a single gradient step on one mini-batch.
Computes the loss and gradient, conjugates the Wirtinger derivative, applies the optimizer update, and returns the updated model.
- Parameters:
model (ImagePredictionModel) – Current model state.
scanning_position_index (Integer[Array, 'b']) – Batch of scanning position indices.
measured_diffraction_pattern (Float[Array, 'b w h']) – Measured images for this batch.
optimizer (GradientTransformation) – Optax gradient transformation.
optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Current optimizer state.
- Returns:
Tuple of ((total_loss, named_losses), model, optimizer_state).
- Return type:
Tuple[Tuple[float, Any], ImagePredictionModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
- ptyrax.reconstruct.train_epoch(model, dataset, optimizer, optimizer_state, step, epoch, epoch_logger, *, key)
Run a single training epoch over all batches.
Shuffles the dataset into batches and performs one gradient update per batch using jax.lax.scan().
- Parameters:
model (ImagePredictionModel) – Current model state.
dataset (ImageDataset) – Training dataset containing measured images.
optimizer (GradientTransformation) – Optax gradient transformation.
optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Current optimizer state.
step (int) – Global training step counter (unused, reserved).
epoch (int) – Current epoch index (0-based).
epoch_logger (Callable) – Callback invoked at the end of the epoch with (epoch, model, total_loss, losses, ordered=False).
key (Key) – JAX PRNG key for batch shuffling.
- Returns:
Tuple of (model, optimizer_state, step, epoch) with updated values.
- Return type:
Tuple[ImagePredictionModel, Any, int, int]
- ptyrax.reconstruct.train_model(model, dataset, optimizer, optimizer_state, writer, *, key=Array([0, 42], dtype=uint32), **kwargs)
Top-level training loop that logs the initial state and delegates to train_session().
- Parameters:
model (PtychographyModel) – Model to train.
dataset (ImageDataset) – Training dataset.
optimizer (GradientTransformation) – Optax gradient transformation.
optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Initial optimizer state.
writer (SummaryWriter) – TensorBoard SummaryWriter for logging.
key (Key) – JAX PRNG key.
**kwargs – Extra keyword arguments forwarded to train_session().
- Returns:
Tuple of (trained_model, trained_optimizer_state).
- Return type:
Tuple[PtychographyModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
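The default `key=Array([0, 42], dtype=uint32)` is a raw PRNG key, the same value `jax.random.PRNGKey(42)` produces under JAX's default threefry implementation. Because Python evaluates default arguments once, every call that omits `key` uses the same seed; passing an explicit key is the usual way to get independent runs. The snippet only checks that equivalence and does not call ptyrax.

```python
import jax
import jax.numpy as jnp

# Raw (legacy) key for seed 42 under the default threefry2x32 implementation.
default_key = jax.random.PRNGKey(42)

# A different seed gives a different key, e.g. for an independent reconstruction run.
other_key = jax.random.PRNGKey(7)

print(default_key.dtype, default_key.shape)  # uint32 (2,)
```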
- ptyrax.reconstruct.train_session(model, dataset, optimizer, optimizer_state, batch_logger, epoch_logger, num_epochs=30, epoch_callbacks=(), initial_epoch=0, *, key, **kwargs)
Trains a model on a ptychogram dataset. The model is trained for num_epochs epochs using the supplied optimizer, which carries the learning rate.
- Parameters:
model (PtychographyModel) – The model to train.
dataset (ImageDataset) – The dataset to train on.
optimizer (GradientTransformation) – Optax gradient transformation.
optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) – Current optimizer state.
batch_logger (Callable) – A function that logs the loss for each batch.
epoch_logger (Callable) – A function that logs the state of the model at the end of each epoch.
num_epochs (int) – The number of epochs to train for.
epoch_callbacks (Iterable[Callable]) – Callback functions that are run after each epoch.
initial_epoch (int) – The epoch to start training from.
key (Key) – JAX PRNG key.
- Return type:
Tuple[ImagePredictionModel, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
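The epoch loop described above can be outlined as follows. This is a sketch under assumptions, not the ptyrax implementation: `session_sketch`, the `run_epoch` placeholder and the `callback(epoch, state)` signature are hypothetical, and treating `num_epochs` as the total (so that resuming at `initial_epoch` runs only the remaining epochs) is a guess the docstring does not settle.

```python
import jax


def session_sketch(state, run_epoch, num_epochs=30, epoch_callbacks=(), initial_epoch=0, *, key):
    """Hypothetical outline of an epoch loop; run_epoch(state, epoch, key) is a placeholder."""
    # Assumption: num_epochs is the total, so a resumed session runs the remaining epochs.
    for epoch in range(initial_epoch, num_epochs):
        # Split off a fresh subkey so each epoch shuffles its batches differently.
        key, epoch_key = jax.random.split(key)
        state = run_epoch(state, epoch, epoch_key)
        # Hypothetical callback signature: (epoch, state).
        for callback in epoch_callbacks:
            callback(epoch, state)
    return state


seen = []
final = session_sketch(
    0,
    lambda state, epoch, k: state + 1,  # toy "epoch" that just counts
    num_epochs=5,
    epoch_callbacks=[lambda epoch, state: seen.append(epoch)],
    initial_epoch=2,
    key=jax.random.PRNGKey(0),
)
print(seen, final)  # [2, 3, 4] 3
```

Splitting the key once per epoch keeps the shuffling reproducible for a given starting key while still giving each epoch a different batch order.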