Using ptyrax together with chromatix
Chromatix is a great toolbox for computational imaging that provides many optical elements through both a functional API and an Equinox-based API. Since Chromatix and ptyrax are both built on Equinox, they go hand in hand quite well! We will demonstrate this by taking the Zernike fitting example from Chromatix and optimizing it in ptyrax!
import equinox as eqx
import jax.numpy as jnp
from jax import Array
This is the ZernikePSF element from the Zernike fitting tutorial:
import chromatix.functional as cxf
import numpy as np
from chromatix.ops import shot_noise
from chromatix.utils import zernike_aberrations
from jaxtyping import Key
class ZernikePSF(eqx.Module):
    """Simulated point-spread function of an objective with Zernike phase aberrations.

    Calling the module does the following:
    1. Builds a point-source field with chromatix.
    2. Applies a phase aberration made of the Zernike terms listed in
       ``ansi_indices``, weighted by ``coefficients``.
    3. Propagates the field through a far-field lens.
    4. Returns the resulting intensity, optionally with shot noise.
    """

    coefficients: Array  # This is what we want to optimize!
    # ANSI indices of the Zernike terms weighted by ``coefficients``. Lists and
    # numpy arrays are converted to a tuple (presumably so the static field is hashable).
    ansi_indices: tuple[int, ...] = eqx.field(
        static=True,
        converter=lambda x: tuple(x) if isinstance(x, (list, np.ndarray)) else x,
        default=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    )
    camera_shape: tuple[int, int] = eqx.field(static=True, default=(256, 256))
    camera_pixel_pitch: float = eqx.field(static=True, default=0.125)
    f: float = eqx.field(static=True, default=100.0)
    NA: float = eqx.field(static=True, default=0.8)
    n: float = eqx.field(static=True, default=1.33)
    wavelength: float = eqx.field(static=True, default=0.532)
    # Factor by which each camera axis is upsampled for the simulation grid.
    upsample: int = eqx.field(static=True, default=4)
    # Number of extra samples added to each (upsampled) simulation axis.
    pad: int = eqx.field(static=True, default=128)

    @property
    def image_shape(self) -> tuple[int, int]:
        """Shape of the simulation grid: ``camera_shape * upsample + pad`` per axis, as ints."""
        return tuple(int(size) * self.upsample + self.pad for size in self.camera_shape)

    @eqx.filter_jit
    def __call__(self, key: Key | None = None) -> Array:
        """Simulate the PSF intensity image.

        Args:
            key: Optional PRNG key. When given, chromatix's ``shot_noise`` is
                applied to the intensity.

        Returns:
            The (optionally noisy) intensity after the far-field lens, simulated
            on an ``image_shape`` grid.
        """
        shape = self.image_shape
        # Sample spacing of the input field. Only shape[0] is used, so square
        # grids are presumably assumed. Assuming the usual far-field relation
        # dx_out = f * wavelength / (n * N * dx_in), this yields an output
        # sampling of camera_pixel_pitch / upsample.
        spacing = self.upsample * self.f * self.wavelength / (self.n * shape[0] * self.camera_pixel_pitch)
        # 0.0 is presumably the axial position z of the point source.
        field = cxf.objective_point_source(shape, spacing, self.wavelength, 0.0, self.f, self.n, self.NA, power=1e3)
        aberrations = zernike_aberrations(
            shape,
            spacing,
            self.wavelength,
            self.n,
            self.f,
            self.NA,
            self.ansi_indices,
            self.coefficients,
            normalize=False,
        )
        field = cxf.phase_change(field, aberrations)
        field = cxf.ff_lens(field, self.f, self.n)
        image = field.intensity
        if key is not None:
            image = shot_noise(key, image)
        return image
To train it in ptyrax, we only need to adapt it to the ImagePredictionModel interface and define how it converts to and from an ImageDataset (here via from_image_dataset and to_image_dataset, using SimpleImageDataset).
import chromatix.functional as cxf
import numpy as np
from chromatix.ops import shot_noise
from chromatix.utils import zernike_aberrations
from jaxtyping import Array, Float, Key
from tensorboardX import SummaryWriter
from ptyrax.dataset import ImageDataset, SimpleImageDataset
from ptyrax.models import ImagePredictionModel
class ZernikePSF(ImagePredictionModel):
    """ptyrax-trainable version of the chromatix Zernike PSF model.

    It uses the same optics as the plain Equinox version. It subclasses
    ``ImagePredictionModel`` so that ptyrax can do three things:
    - build it from a dataset (``from_image_dataset``);
    - wrap its predictions back into a dataset (``to_image_dataset``);
    - log its parameters during training (``__log_epoch__``).

    Unlike the plain version, it returns the field amplitude and never adds
    shot noise.
    """

    coefficients: Array  # This is what we want to optimize!
    # ANSI indices of the Zernike terms weighted by ``coefficients``. Lists and
    # numpy arrays are converted to a tuple (presumably so the static field is hashable).
    ansi_indices: tuple[int, ...] = eqx.field(
        static=True,
        converter=lambda x: tuple(x) if isinstance(x, (list, np.ndarray)) else x,
        default=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    )
    camera_shape: tuple[int, int] = eqx.field(static=True, default=(256, 256))
    camera_pixel_pitch: float = eqx.field(static=True, default=0.125)
    f: float = eqx.field(static=True, default=100.0)
    NA: float = eqx.field(static=True, default=0.8)
    n: float = eqx.field(static=True, default=1.33)
    wavelength: float = eqx.field(static=True, default=0.532)
    # Factor by which each camera axis is upsampled for the simulation grid.
    upsample: int = eqx.field(static=True, default=4)
    # Number of extra samples added to each (upsampled) simulation axis.
    pad: int = eqx.field(static=True, default=128)

    @eqx.filter_jit
    def __call__(self, key: Key | None = None) -> Array:
        """Simulate the noise-free PSF amplitude.

        Args:
            key: Accepted for signature compatibility with the plain version,
                but ignored: shot noise is deliberately not applied here.

        Returns:
            The field amplitude after the far-field lens, simulated on an
            ``image_shape`` grid.
        """
        shape = self.image_shape
        # Sample spacing of the input field. Only shape[0] is used, so square
        # grids are presumably assumed. Assuming the usual far-field relation
        # dx_out = f * wavelength / (n * N * dx_in), this yields an output
        # sampling of camera_pixel_pitch / upsample.
        spacing = self.upsample * self.f * self.wavelength / (self.n * shape[0] * self.camera_pixel_pitch)
        # 0.0 is presumably the axial position z of the point source.
        field = cxf.objective_point_source(shape, spacing, self.wavelength, 0.0, self.f, self.n, self.NA, power=1e3)
        aberrations = zernike_aberrations(
            shape,
            spacing,
            self.wavelength,
            self.n,
            self.f,
            self.NA,
            self.ansi_indices,
            self.coefficients,
            normalize=False,
        )
        field = cxf.phase_change(field, aberrations)
        field = cxf.ff_lens(field, self.f, self.n)
        # Return the amplitude instead of the intensity: better convergence.
        # Shot noise is intentionally not applied (the plain version calls
        # shot_noise(key, image) when a key is given), so `key` is unused.
        return field.amplitude

    @classmethod
    def from_image_dataset(cls, dataset: SimpleImageDataset, *args, **kwargs) -> "ZernikePSF":
        """Construct a model whose ``camera_shape`` matches ``dataset.image_shape``.

        Extra positional and keyword arguments are forwarded to the constructor.
        """
        return cls(*args, camera_shape=dataset.image_shape, **kwargs)

    def to_image_dataset(self, predicted_images: Float[Array, " n h w"]) -> ImageDataset:
        """Wrap a stack of predicted images in a ``SimpleImageDataset``."""
        return SimpleImageDataset(predicted_images)

    def __log_epoch__(self, writer: SummaryWriter, epoch: int, prefix: str = "", **kwargs) -> None:
        """Log a histogram of the Zernike coefficients under ``{prefix}/coefficients``."""
        writer.add_histogram(f"{prefix}/coefficients", self.coefficients, epoch)

    @property
    def image_shape(self) -> tuple[int, int]:
        """Shape of the simulation grid: ``camera_shape * upsample + pad`` per axis, as ints."""
        return tuple(int(size) * self.upsample + self.pad for size in self.camera_shape)
/home/ssenhorst1/workspace/ptyrax/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
For simplicity, this example does not include the experimental conditions in the dataset. In practice, an experimentalist would likely want to make \(z, f, \lambda, n\) and \(\text{NA}\) part of the dataset and adapt from_image_dataset accordingly. With this simple model, we can already simulate data:
from ptyrax.simulate import simulate_model
# Initial guess: all 10 Zernike coefficients set to the same small value.
initial_coefficients = jnp.ones(10) * 0.1 / (2 * jnp.pi)
model = ZernikePSF(initial_coefficients)

# Ground truth used to generate the synthetic measurement.
# NOTE(review): these values are scaled by wavelength / (2*pi), while the
# initial guess is only divided by 2*pi. Confirm the intended units.
true_coefficients = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0])
ground_truth_model = ZernikePSF(true_coefficients / (2 * jnp.pi / model.wavelength))

simulated_dataset = simulate_model(ground_truth_model)
/home/ssenhorst1/workspace/ptyrax/ptyrax/models/ptychography.py:392: UserWarning: The model does not seem to have any IndexSliceParametrization fields. The 'n_indices' property will return 1. If this is unintended, please ensure that the model uses IndexSliceParametrization for dataset-indexed parameters.
warnings.warn(
100%|██████████| 1/1 [00:01<00:00, 1.79s/it]
Since nothing in the model depends on the dataset index, this is a somewhat boring dataset consisting of a single image. It is nevertheless enough to train our model on:
from datetime import datetime
import optax
from tensorboardX import SummaryWriter
from ptyrax.reconstruct import train_model
from ptyrax.training import OptimizerSpecification, initialize_optimizer_and_state
# A single optimizer group. The ".*" pattern presumably matches every
# parameter, so all of them are updated with Adam at lr=1e-2.
optimizer_specs = [
    OptimizerSpecification(
        name="all",
        match_patterns=[".*"],
        optimizer=optax.adam(learning_rate=1e-2),
    )
]
optimizer_state, optimizer = initialize_optimizer_and_state(model, optimizers=optimizer_specs)

# One TensorBoard run directory per invocation, stamped with the local start time.
log_dir = f"logs/zernike_psf_example/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
writer = SummaryWriter(log_dir=log_dir)

simulated_dataset.to_gpu()
trained_model, trained_optimizer = train_model(
    model, simulated_dataset, optimizer, optimizer_state, num_epochs=100, writer=writer
)
Epoch: 100%|██████████| 100/100 [03:55<00:00, 2.36s/it]
import matplotlib.pyplot as plt
from ptyrax.utils import plot
# Seems like we've hit a decent(?) local minimum
# Side-by-side comparison: initial guess, trained result and ground truth,
# each drawn into one column of a 1x3 grid with gamma=.33.
fig = plt.figure(figsize=(12, 4))
gs = fig.add_gridspec(1, 3)
panels = (
    ("Initial model", model),
    ("Trained model", trained_model),
    ("Ground truth model", ground_truth_model),
)
for column, (panel_title, psf_model) in enumerate(panels):
    plotted = plot(psf_model(), title=panel_title, fig=fig, gs=gs[0, column], gamma=.33)
    if column == 0:
        # The first panel also rebinds `fig` to the figure returned by plot.
        fig, *_ = plotted
    else:
        _ = plotted
    plt.gca().set_title(panel_title)
Text(0.5, 1.0, 'Ground truth model')