Custom model: Through-focus wavefront sensing
This notebook demonstrates through-focus wavefront sensing using the ptyrax framework. We will build a custom ImagePredictionModel that predicts intensity images at different focal planes, following the same structure as the custom_prediction_model tutorial.
Physical model
In through-focus wavefront sensing, we measure intensity images at a series of axial positions relative to the nominal focus. At each defocus position \(z\) (measured in waves), the field at the detector is the Fourier transform of the aberrated pupil multiplied by a quadratic defocus phase:
\[ E_z(x,y) = \mathcal{F}\left\{ P(u,v)\, e^{i 2\pi z (u^2+v^2)/R^2} \right\}(x,y), \]
where \(P(u,v)\) is the complex pupil function containing the aberrations we wish to recover, \(R\) is the aperture radius, and \(z\) is the defocus in waves. The measured intensity is:
\[ I_z(x,y) = \left| E_z(x,y) \right|^2. \]
By collecting images across a range of defocus values, the phase diversity introduced by the defocus masks makes the problem well-conditioned enough to recover the pupil aberrations \(P\) via gradient-based optimization.
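One concrete way to see why defocus diversity helps: a single in-focus image cannot distinguish a pupil \(P(u,v)\) from its "twin" \(P^*(-u,-v)\) (complex conjugate, reflected through the origin), because both give exactly the same intensity. Defocus breaks this tie, because the twin measured at \(+z\) reproduces the original pupil measured at \(-z\), so a stack containing both signs of defocus tells them apart. The sketch below checks this numerically with plain jax.numpy, independent of ptyrax; it assumes an unshifted FFT grid (pupil center at index 0), and the helper names and the random-phase test pupil are illustrative choices.

```python
import jax
import jax.numpy as jnp

# Plain jax.numpy sketch (not ptyrax): unshifted FFT grid, pupil center at index 0.


def squared_radius(n: int) -> jnp.ndarray:
    """u^2 + v^2 on an n x n unshifted FFT grid, with coordinates 0, 1, ..., -2, -1."""
    k = jnp.arange(n)
    c = jnp.where(k < n // 2, k, k - n).astype(jnp.float32)
    return c[:, None] ** 2 + c[None, :] ** 2


def intensity(pupil, z, rr, R):
    """I_z = |F{P exp(i 2 pi z (u^2 + v^2) / R^2)}|^2 for a defocus of z waves."""
    defocus_mask = jnp.exp(1j * 2 * jnp.pi * z * rr / R**2)
    return jnp.abs(jnp.fft.fft2(pupil * defocus_mask)) ** 2


def max_rel_diff(a, b):
    """Largest absolute difference between two images, relative to the peak of a."""
    return float(jnp.max(jnp.abs(a - b)) / jnp.max(a))


n, R = 64, 16.0
rr = squared_radius(n)
# Random phase inside a circular aperture, standing in for arbitrary aberrations.
phase = 2 * jnp.pi * jax.random.uniform(jax.random.PRNGKey(0), (n, n))
pupil = jnp.where(rr <= R**2, jnp.exp(1j * phase), 0)

# Twin pupil P*(-u, -v): conjugate and reflect through the origin (indices modulo n).
idx = (-jnp.arange(n)) % n
twin = jnp.conj(pupil[idx][:, idx])

z = 1.5
print(max_rel_diff(intensity(pupil, 0.0, rr, R), intensity(twin, 0.0, rr, R)))  # ~0: identical in focus
print(max_rel_diff(intensity(pupil, -z, rr, R), intensity(twin, z, rr, R)))  # ~0: twin at +z = pupil at -z
print(max_rel_diff(intensity(pupil, z, rr, R), intensity(twin, z, rr, R)))  # large: defocus separates them
```

The defocus phase is even in \((u,v)\), so conjugating and reflecting the pupil flips the sign of the effective defocus; measuring both \(+z\) and \(-z\), as the dataset below does, removes this ambiguity.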
Initialization
As with any custom model in ptyrax, we start by defining the dataset. Our ThroughFocusDataset stores the measured intensity images alongside the corresponding defocus distances. Under the hood, ptyrax will call from_image_dataset() to initialize the model from this dataset, so all dataset information (here: defocus distances) is available at model construction time.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import logging
import sys
import gin
# Set up gin and logging only once, so re-running this cell does not add duplicate log handlers.
try:
    if INITIALIZED:  # type: ignore
        pass
except NameError:
    gin.enter_interactive_mode()
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    root.addHandler(handler)
    INITIALIZED = True
import dataclasses
import pathlib
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float
import ptyrax.dataset
from ptyrax.utils import load_hdf5
@dataclasses.dataclass
class ThroughFocusDataset(ptyrax.dataset.ImageDataset):
    intensity_images: Float[Array, "d m n"]
    defocus_distances: Float[Array, "d"]
    @property
    def images(self) -> jnp.ndarray:
        return self.intensity_images
    def to_gpu(self) -> None:
        self.intensity_images = jnp.asarray(self.intensity_images)
        self.defocus_distances = jnp.asarray(self.defocus_distances)
    @classmethod
    def load_from(cls, path: pathlib.Path) -> "ThroughFocusDataset":
        data = load_hdf5(path)
        return cls(**data)
n_planes = 50
example_intensity_data = np.random.random((n_planes, 128, 128)).astype(np.float32)
example_defocus_data = np.linspace(-5, 5, n_planes).astype(np.float32)
dataset = ThroughFocusDataset(
    intensity_images=example_intensity_data,
    defocus_distances=example_defocus_data,
)
If desired, we could add other properties to the dataset for later use by the prediction model, but for now let's keep things simple.
By default, ImageDataset implements saving datasets to, and loading them from, the HDF5 format. This lets us store the dataset on disk, which is useful if we want to run reconstructions through the command-line interface.
dataset.save("through_focus_dataset.hdf5")
dataset = ThroughFocusDataset.load_from("through_focus_dataset.hdf5")
Next, we implement our prediction model. We have to implement the following abstract members: the methods __call__(), from_image_dataset(), and to_image_dataset(), and the image_shape property.
Image prediction modules are dataclasses, so we specify the parameters of the model as field annotations; the corresponding __init__ is generated automatically.
The from_image_dataset() classmethod is called first. It receives an instance of our dataset and must return an initialized model. The defocus distances are known quantities (not to be optimized), so we store them as an IndexSliceParameter; this allows the forward model to retrieve the correct defocus scalar for each dataset index during training.
The __call__() method is the heart of the model: it computes the actual prediction. It is compiled with JAX, so the Python code runs only once (during tracing), after which the compiled version runs on the GPU. Consequently, all operations must be jittable, and Python debugging tools such as print() or breakpoints will not behave as expected, since they only see abstract tracers. Use JAX's debugging tools (for example jax.debug.print) instead; see the JAX documentation for details. For each dataset index, __call__():
Retrieves the current defocus scalar \(z\) from self.defocus_distance.at_current_index().
Constructs the defocus phase mask \(e^{i 2\pi z (u^2+v^2)/R^2}\) in pupil coordinates.
Multiplies the aberrated pupil by the defocus mask and Fourier-transforms to the focal plane.
Returns the intensity (squared magnitude).
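Applied to every dataset index, these four steps produce a stack of intensity images. The standalone sketch below mimics that with jax.vmap over an array of defocus values. It is not how ptyrax iterates internally (that is handled by the framework and IndexSliceParameter), and jnp.exp and jnp.fft.fft2 stand in for ptyrax's phase_only_exp and fft, whose exact conventions (for example shifting) are assumptions here. The aperture radius shape[0] / 4 and the jnp.indices coordinates match the model.

```python
import jax
import jax.numpy as jnp


def predict_plane(pupil: jnp.ndarray, defocus: jnp.ndarray) -> jnp.ndarray:
    """Intensity for one defocus value (in waves), mirroring the four steps above.

    Assumption: jnp.exp and jnp.fft.fft2 stand in for ptyrax's phase_only_exp and fft.
    """
    shape = pupil.shape
    x, y = jnp.indices(shape)
    aperture_radius = shape[0] / 4
    rr_normalized = ((x - shape[0] / 2) ** 2 + (y - shape[1] / 2) ** 2) / aperture_radius**2
    defocus_mask = jnp.exp(1j * 2 * jnp.pi * defocus * rr_normalized)
    # fftshift only moves the PSF to the image center; intensity values are unchanged.
    field = jnp.fft.fftshift(jnp.fft.fft2(pupil * defocus_mask))
    return jnp.abs(field) ** 2


# Flat pupil: unit amplitude inside radius shape[0] / 4, zero outside (as in from_image_dataset).
n = 128
x, y = jnp.indices((n, n))
inside = jnp.sqrt((x - n / 2) ** 2 + (y - n / 2) ** 2) <= n / 4
pupil = jnp.where(inside, 1.0, 0.0).astype(jnp.complex64)

defocus_distances = jnp.linspace(-5.0, 5.0, 11)
# vmap over the defocus axis: one predicted image per "dataset index".
stack = jax.vmap(predict_plane, in_axes=(None, 0))(pupil, defocus_distances)
print(stack.shape)  # (11, 128, 128)
```

Because the defocus mask has unit modulus, every plane carries the same total energy (Parseval's theorem); defocus only redistributes light within the image.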
Warning
Do not use NumPy functions on traced values inside __call__()! Converting a JAX tracer to a NumPy array raises a jax.errors.TracerArrayConversionError; use the jax.numpy equivalents instead.
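A minimal illustration in plain JAX (the function names are illustrative, not part of ptyrax): the first function calls NumPy on a traced array and fails when jitted, while the second uses jax.numpy and inspects a traced value with jax.debug.print, which prints the runtime value rather than an abstract tracer.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def numpy_intensity(field):
    # np.abs tries to turn the tracer into a concrete NumPy array: this fails under jit.
    return np.abs(field) ** 2


@jax.jit
def jax_intensity(field):
    # jax.debug.print prints runtime values of traced arrays; plain print() would show a tracer.
    jax.debug.print("peak amplitude: {m}", m=jnp.max(jnp.abs(field)))
    return jnp.abs(field) ** 2


field = jnp.array([1.0 + 1.0j, 2.0j], dtype=jnp.complex64)
try:
    numpy_intensity(field)
except jax.errors.TracerArrayConversionError as err:
    print(f"NumPy inside jit failed with {type(err).__name__}")
print(jax_intensity(field))
```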
import equinox as eqx
from jaxtyping import Complex
from ptyrax.models.ptychography import ImagePredictionModel
from ptyrax.parametrizations import IndexSliceParameter
from ptyrax.utils import fft, phase_only_exp
class ThroughFocusPredictionModel(ImagePredictionModel):
    # The complex pupil function encoding the aberrations we wish to recover.
    aberrations: Complex[Array, "m n"]
    # The defocus distance (in waves) for each dataset index. These are known quantities
    # and should not be optimized; simply ensure no optimizer is matched to them.
    # Inside __call__ this resolves to a scalar Float array at the current index.
    defocus_distance: IndexSliceParameter[Float[Array, ""]]
    @classmethod
    def from_image_dataset(
        cls,
        image_dataset: ThroughFocusDataset,
    ) -> "ThroughFocusPredictionModel":
        # Initialize aberrations to unit amplitude with a circular aperture (no aberration).
        aberrations = jnp.ones(image_dataset.image_shape, dtype=jnp.complex64)
        x, y = jnp.indices(image_dataset.image_shape)
        rr = (x - image_dataset.image_shape[0] / 2) ** 2 + (y - image_dataset.image_shape[1] / 2) ** 2
        aberrations = aberrations.at[jnp.sqrt(rr) > (image_dataset.image_shape[0] / 4)].set(0)
        # Store the defocus distances as an IndexSliceParameter so __call__ can
        # retrieve the scalar defocus value for the current dataset index.
        defocus_distance = IndexSliceParameter(jnp.asarray(image_dataset.defocus_distances, dtype=jnp.float32))
        return cls(aberrations=aberrations, defocus_distance=defocus_distance)
    def to_image_dataset(self, predicted_images: Float[Array, "d m n"]) -> ThroughFocusDataset:
        return ThroughFocusDataset(
            intensity_images=predicted_images,
            defocus_distances=self.defocus_distance.all,
        )
    def __call__(self) -> Float[Array, " m n"]:
        # Retrieve the defocus (in waves) for the current dataset index.
        defocus = self.defocus_distance.at_current_index()
        # Build normalised radial coordinate squared in the pupil plane.
        # The aperture radius is shape[0]/4, consistent with from_image_dataset.
        shape = self.aberrations.shape
        x, y = jnp.indices(shape)
        aperture_radius = shape[0] / 4
        rr_normalized = ((x - shape[0] / 2) ** 2 + (y - shape[1] / 2) ** 2) / aperture_radius**2
        # Defocus phase mask: phi(u,v) = 2*pi*z * (r/R)^2
        defocus_mask = phase_only_exp(2 * jnp.pi * defocus * rr_normalized)
        # Propagate the aberrated, defocused pupil to the focal plane.
        field = fft(self.aberrations * defocus_mask)
        return jnp.abs(field) ** 2
    @property
    def image_shape(self) -> tuple[int, int]:
        return self.aberrations.shape
We can check that we can instantiate our prediction model from the dataset we created:
model = ThroughFocusPredictionModel.from_image_dataset(dataset)
Now that we have our model and dataset, we can start a training session using train_model(). We need to specify an optimizer. Ptyrax supports per-parameter optimizers matched via regex patterns acting on the pytree path of each leaf. To view all paths we can use a small helper function.
import jax.tree
def print_all_paths(pytree: eqx.Module, prefix: str = "") -> None:
    jax.tree.map_with_path(  # noqa: F821
        lambda path, x: logging.info(f"{prefix}{'.'.join([p.name for p in path])}"),
        pytree,
    )
print_all_paths(model)
2026-06-08 11:26:30,672 - root - INFO - aberrations
2026-06-08 11:26:30,707 - root - INFO - defocus_distance.parameters
As we can see, defocus_distance became a nested structure when wrapped as an IndexSliceParameter. To target it with an optimizer, use the wildcard pattern "defocus_distance.*". In this tutorial, however, the defocus distances are known and fixed, so we only optimize the aberrations.
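To check which leaves a pattern would select, the sketch below matches the two paths printed above against a few regexes. It assumes full-string matching with Python's re.fullmatch; whether ptyrax uses full or partial matching is an assumption here, and matching_paths is a hypothetical helper. Note that in a regex "." matches any character, so "defocus_distance.*" reads as "defocus_distance followed by anything".

```python
import re

# Leaf paths as printed by print_all_paths(model) above.
paths = ["aberrations", "defocus_distance.parameters"]


def matching_paths(patterns: list[str], paths: list[str]) -> list[str]:
    """Paths that fully match at least one regex (hypothetical helper; assumes fullmatch semantics)."""
    return [p for p in paths if any(re.fullmatch(pat, p) for pat in patterns)]


print(matching_paths(["aberrations"], paths))  # ['aberrations']
print(matching_paths(["defocus_distance.*"], paths))  # ['defocus_distance.parameters']
print(matching_paths(["defocus"], paths))  # []
```

Under full matching, a bare "defocus" selects nothing, whereas partial matching (re.search) would select the defocus leaf; knowing which semantics apply avoids silently optimizing, or freezing, the wrong parameters.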
from datetime import datetime
import optax
from jax import random
from tensorboardX import SummaryWriter
from ptyrax.reconstruct import train_model
from ptyrax.training import OptimizerSpecification, initialize_optimizer_and_state
log_dir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
writer = SummaryWriter(log_dir=log_dir)
key = random.PRNGKey(0)
optimizer_specs = [
OptimizerSpecification(
name="aberrations",
match_patterns=["aberrations"],
optimizer=optax.adam(1e-4),
learn_rate_schedule=optax.constant_schedule(value=1.0),
)
]
optimizer_state, optimizer = initialize_optimizer_and_state(model, optimizers=optimizer_specs)
trained_model, _ = train_model(
model=model,
dataset=dataset,
optimizer=optimizer,
optimizer_state=optimizer_state,
num_epochs=10,
writer=writer,
key=key,
)
Epoch: 100%|██████████| 10/10 [00:04<00:00, 2.15it/s]
The training progress is logged to the logs/ directory, and we can inspect the training curves in TensorBoard.
# %load_ext tensorboard
# %tensorboard --logdir logs
Custom logging
By default, the loss (epoch total) and the data fidelity (per position) are logged as scalars, and an intensity image at index 0 is logged at every epoch. To extend the logging, we can add an optional __log_epoch__ method to our model. The log_image() function provides a convenient way to log complex-valued images to TensorBoard. Below we add logging of the current aberration estimate at every epoch.
from ptyrax.logger import log_image
class ThroughFocusPredictionModel(ImagePredictionModel):
    # The complex pupil function encoding the aberrations we wish to recover.
    aberrations: Complex[Array, "m n"]
    # The defocus distance (in waves) for each dataset index.
    defocus_distance: IndexSliceParameter[Float[Array, ""]]
    @classmethod
    def from_image_dataset(
        cls,
        image_dataset: ThroughFocusDataset,
    ) -> "ThroughFocusPredictionModel":
        aberrations = jnp.ones(image_dataset.image_shape, dtype=jnp.complex64)
        x, y = jnp.indices(image_dataset.image_shape)
        rr = (x - image_dataset.image_shape[0] / 2) ** 2 + (y - image_dataset.image_shape[1] / 2) ** 2
        aberrations = aberrations.at[jnp.sqrt(rr) > (image_dataset.image_shape[0] / 4)].set(0)
        defocus_distance = IndexSliceParameter(jnp.asarray(image_dataset.defocus_distances, dtype=jnp.float32))
        return cls(aberrations=aberrations, defocus_distance=defocus_distance)
    def to_image_dataset(self, predicted_images: Float[Array, "d m n"]) -> ThroughFocusDataset:
        return ThroughFocusDataset(
            intensity_images=predicted_images,
            defocus_distances=self.defocus_distance.all,
        )
    def __call__(self) -> Float[Array, " m n"]:
        defocus = self.defocus_distance.at_current_index()
        shape = self.aberrations.shape
        x, y = jnp.indices(shape)
        aperture_radius = shape[0] / 4
        rr_normalized = ((x - shape[0] / 2) ** 2 + (y - shape[1] / 2) ** 2) / aperture_radius**2
        defocus_mask = phase_only_exp(2 * jnp.pi * defocus * rr_normalized)
        field = fft(self.aberrations * defocus_mask)
        return jnp.abs(field) ** 2
    @property
    def image_shape(self) -> tuple[int, int]:
        return self.aberrations.shape
    def __log_epoch__(self, writer: SummaryWriter, epoch: int, prefix: str = "", **kwargs) -> None:
        log_image(writer, "aberrations", self.aberrations, epoch)
model = ThroughFocusPredictionModel.from_image_dataset(dataset)
optimizer_state, optimizer = initialize_optimizer_and_state(model, optimizers=optimizer_specs)
log_dir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
writer = SummaryWriter(log_dir=log_dir)
trained_model, _ = train_model(
model=model,
dataset=dataset,
optimizer=optimizer,
optimizer_state=optimizer_state,
num_epochs=20,
writer=writer,
key=key,
)
Epoch: 100%|██████████| 20/20 [00:15<00:00, 1.31it/s]
If we check the logs now, we see that the current aberration estimate is logged every epoch.
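The log_image() function takes care of turning a complex image into something TensorBoard can display. As an illustration of one common encoding (not necessarily what log_image() does internally), the sketch below maps phase to hue and normalized amplitude to brightness with matplotlib.colors.hsv_to_rgb. The helper complex_to_rgb is hypothetical. Clipping to [0, 1] guards against floating-point overshoot, which matplotlib's imshow would otherwise clip with a warning.

```python
import numpy as np
from matplotlib.colors import hsv_to_rgb


def complex_to_rgb(image) -> np.ndarray:
    """Map a complex (m, n) image to an (m, n, 3) float RGB image in [0, 1] (hypothetical helper).

    Hue encodes the phase, value (brightness) encodes the amplitude normalized
    by its maximum, and saturation is fixed at 1.
    """
    image = np.asarray(image)
    amplitude = np.abs(image)
    hue = (np.angle(image) + np.pi) / (2 * np.pi)  # phase in [-pi, pi] -> [0, 1]
    value = amplitude / amplitude.max() if amplitude.max() > 0 else amplitude
    hsv = np.stack([hue % 1.0, np.ones_like(hue), value], axis=-1)
    return np.clip(hsv_to_rgb(hsv), 0.0, 1.0)
```

Assuming __log_epoch__ runs outside of jit, the result could be written with tensorboardX's writer.add_image(f"{prefix}aberrations_rgb", complex_to_rgb(self.aberrations), epoch, dataformats="HWC").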
Simulation
Because the dataset above consists of random noise, the optimization does not converge to any meaningful solution. To verify that the model can recover aberrations, we generate a simulated dataset from a known ground-truth pupil, using a combination of horizontal coma and vertical astigmatism. To replace values in an existing (immutable) equinox module, we use eqx.tree_at. The syntax might look involved, but it boils down to passing a small function that selects the field to replace, the module itself, and the replacement value. For more info, see the equinox documentation.
from matplotlib import pyplot as plt
from ptyrax.utils import plot
x = jnp.linspace(-64, 64, 128, endpoint=False)[jnp.newaxis, :]
y = jnp.linspace(-64, 64, 128, endpoint=False)[:, jnp.newaxis]
rr = jnp.square(x) + jnp.square(y)
# Horizontal primary coma for amplitude A: phase = A * (x/R) * (3*(r/R)^2 - 2), where x/R = r/R cos(theta).
# Vertical astigmatism: phase = A * (y^2 - x^2) / R^2
aperture_radius = 32
x_norm = x / aperture_radius
y_norm = y / aperture_radius
rho_sq = rr / aperture_radius**2
coma_phase = 10.0 * x_norm * (3 * rho_sq - 2)
astig_phase = 5.0 * (y_norm**2 - x_norm**2)
ground_truth_aberrations = jnp.exp(1j * (coma_phase + astig_phase))
ground_truth_aberrations = ground_truth_aberrations.at[jnp.sqrt(rr) > aperture_radius].set(0)
fig = plt.figure(figsize=(5, 5))
gs = fig.add_gridspec(1, 1)
plot(ground_truth_aberrations, title="Ground Truth Aberrations", fig=fig, gs=gs[0, 0])
ground_truth_model = eqx.tree_at(
lambda m: m.aberrations,
model,
ground_truth_aberrations,
)
from ptyrax.simulate import simulate_model
simulated_dataset = simulate_model(ground_truth_model)
simulated_dataset.save("simulated_through_focus_dataset.hdf5")
100%|██████████| 50/50 [00:00<00:00, 105.17it/s]
Now let’s see whether the model can recover the coma and astigmatism, starting from the flat (no-aberration) initial pupil.
optimizer_specs = [
OptimizerSpecification(
name="aberrations",
match_patterns=["aberrations"],
optimizer=optax.adam(1e-2),
learn_rate_schedule=optax.constant_schedule(1e-2),
)
]
optimizer_state, optimizer = initialize_optimizer_and_state(model, optimizers=optimizer_specs)
log_dir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
writer = SummaryWriter(log_dir=log_dir)
trained_model, _ = train_model(
model=model,
dataset=simulated_dataset,
optimizer=optimizer,
optimizer_state=optimizer_state,
num_epochs=100,
writer=writer,
key=key,
)
Epoch: 100%|██████████| 100/100 [01:25<00:00, 1.17it/s]
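Looks like the loss went down to machine precision! Since JAX uses single precision (float32) by default, machine precision is much coarser than the double-precision value of about \(2.2 \times 10^{-16}\) you may be used to: float32 has a machine epsilon of about \(1.2 \times 10^{-7}\), so the loss bottoms out at roughly \(10^{-6}\). Let's compare the recovered pupil to the ground truth.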
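To see where the single-precision floor comes from, this short check prints the machine epsilon of float32 and float64 and shows that, in float32, an increment of half an epsilon added to 1.0 is lost entirely.

```python
import jax.numpy as jnp
import numpy as np

eps32 = float(jnp.finfo(jnp.float32).eps)
eps64 = float(np.finfo(np.float64).eps)
print(f"float32 eps: {eps32:.3e}")  # 1.192e-07
print(f"float64 eps: {eps64:.3e}")  # 2.220e-16

# Increments of half an ulp (or less) vanish in float32 arithmetic.
one = np.float32(1.0)
print(one + np.float32(eps32 / 2) == one)  # True
print(one + np.float32(eps32) == one)  # False
```

If a lower floor is needed, one option is to enable double precision in JAX with jax.config.update("jax_enable_x64", True) at startup, at the cost of speed and memory; whether ptyrax's components all support float64 is not covered here.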
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
fig = plt.figure(figsize=(10, 5))
gs = GridSpec(1, 3, figure=fig)
_ = plot(model.aberrations, title="Initial aberrations", fig=fig, gs=gs[:, 0])
# Dividing by the complex value at the pupil center removes the global phase ambiguity (and normalizes the amplitude)
_ = plot(
trained_model.aberrations / trained_model.aberrations[64, 64], title="Recovered aberrations", fig=fig, gs=gs[:, 1]
)
_ = plot(ground_truth_model.aberrations, title="True aberrations", fig=fig, gs=gs[:, 2])
plt.show()
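Dividing by the center pixel fixes the global phase but depends on a single, possibly noisy, value. For a quantitative comparison, one option is to fit the best global complex factor over the whole aperture by least squares and report the remaining relative RMS error. The helper aligned_relative_error below is hypothetical, not part of ptyrax.

```python
import jax.numpy as jnp


def aligned_relative_error(recovered, reference, mask):
    """Relative RMS error after removing the best global complex factor (hypothetical helper).

    c = <recovered, reference> / <recovered, recovered> minimizes
    ||c * recovered - reference||^2 over the masked pixels, which removes
    the global phase (and scale) ambiguity in a least-squares sense.
    """
    r = jnp.where(mask, recovered, 0)
    g = jnp.where(mask, reference, 0)
    c = jnp.vdot(r, g) / jnp.vdot(r, r)  # vdot conjugates its first argument
    return jnp.linalg.norm(c * r - g) / jnp.linalg.norm(g)


# Synthetic check: a globally phase-shifted, rescaled copy of a pupil gives ~0 error.
yy, xx = jnp.indices((32, 32)) - 16
mask = xx**2 + yy**2 <= 100
truth = jnp.where(mask, jnp.exp(1j * 0.05 * (xx**2 - yy**2)), 0)
print(aligned_relative_error(2.0 * jnp.exp(1j * 1.2) * truth, truth, mask))
```

In the notebook, it could be called as aligned_relative_error(trained_model.aberrations, ground_truth_model.aberrations, jnp.abs(ground_truth_model.aberrations) > 0).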
Next steps
We have seen how to use ptyrax to recover pupil aberrations from through-focus intensity measurements. The model could be extended, for example, with non-paraxial propagation.
Another option is to run ablation studies on noise robustness and on the experimental conditions needed for a successful reconstruction. For this, we recommend checking out the tutorial on running experiments with ptyrax using data version control.
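As a starting point for a noise-robustness study, the sketch below adds shot (Poisson) noise to an intensity stack for a given photon budget per image, using jax.random.poisson. The helper add_shot_noise and the photon-budget parametrization are illustrative assumptions, not part of ptyrax.

```python
import jax
import jax.numpy as jnp


def add_shot_noise(intensities, photons_per_image, key):
    """Scale each image to a photon budget and draw Poisson counts (hypothetical helper).

    Assumes every image has a nonzero total intensity.
    """
    totals = intensities.sum(axis=(-2, -1), keepdims=True)
    expected_counts = intensities / totals * photons_per_image
    return jax.random.poisson(key, expected_counts).astype(jnp.float32)


key = jax.random.PRNGKey(0)
clean = jnp.ones((3, 8, 8))
noisy = add_shot_noise(clean, 6400.0, key)
print(noisy.shape, noisy.sum(axis=(1, 2)))
```

Assuming simulate_model returns a ThroughFocusDataset, the noisy stack could be wrapped in a new ThroughFocusDataset together with the original defocus distances and reconstructed exactly as above, repeating the experiment for several photon budgets.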