Using ptyrax together with chromatix

Using ptyrax together with chromatix

Chromatix is a great toolbox for computational imaging that provides many optical elements through both a functional API and an Equinox-based API. Since both libraries are built on Equinox, Chromatix and Ptyrax work well together! We will demonstrate this by taking the model from the Zernike fitting tutorial and optimizing it in Ptyrax.

import equinox as eqx
import jax.numpy as jnp
from jax import Array

This is the `ZernikePSF` module from the Zernike fitting tutorial:

import chromatix.functional as cxf
import numpy as np
from chromatix.ops import shot_noise
from chromatix.utils import zernike_aberrations
from jaxtyping import Key


class ZernikePSF(eqx.Module):
    """Chromatix model of a point-spread function (PSF) with Zernike aberrations.

    Calling the module does the following:

    1. Builds the field of a point source behind an objective
       (``cxf.objective_point_source``).
    2. Applies a phase aberration assembled from the Zernike polynomials listed in
       ``ansi_indices``, weighted by ``coefficients`` (``zernike_aberrations``).
    3. Focuses the result with a far-field lens (``cxf.ff_lens``).
    4. Returns the intensity.

    ``coefficients`` is the only non-static field, so it is the only leaf an optimizer
    sees. Every other field is static configuration.
    """

    # This is what we want to optimize! Expected to hold one value per entry of ``ansi_indices``.
    coefficients: Array
    # Lists/arrays are converted to a tuple so the static field stays hashable.
    ansi_indices: tuple[int, ...] = eqx.field(
        static=True,
        converter=lambda x: tuple(x) if isinstance(x, (list, np.ndarray)) else x,
        default=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    )
    # Camera grid shape, in pixels.
    camera_shape: tuple[int, int] = eqx.field(static=True, default=(256, 256))
    # Length quantities below presumably share one unit (e.g. µm) — confirm.
    camera_pixel_pitch: float = eqx.field(static=True, default=0.125)
    f: float = eqx.field(static=True, default=100.0)  # objective focal length
    NA: float = eqx.field(static=True, default=0.8)  # numerical aperture
    n: float = eqx.field(static=True, default=1.33)  # refractive index
    wavelength: float = eqx.field(static=True, default=0.532)
    # Upsampling factor of the simulation grid relative to the camera grid.
    upsample: int = eqx.field(static=True, default=4)
    # Extra samples added to each axis of the simulation grid.
    pad: int = eqx.field(static=True, default=128)

    @eqx.filter_jit
    def __call__(self, key: Key | None = None) -> Array:
        """Simulate the PSF image.

        Args:
            key: Optional PRNG key. If given, shot noise is added with
                ``chromatix.ops.shot_noise``. If ``None``, the noiseless image is
                returned.

        Returns:
            The intensity of the focused field (``field.intensity``), sampled on the
            ``image_shape`` simulation grid. It is not downsampled to ``camera_shape``.
        """
        shape = self.image_shape
        # Pupil-plane sample spacing. Assuming ff_lens maps an input spacing dx to
        # f * wavelength / (n * N * dx), this gives an image-plane spacing of
        # camera_pixel_pitch / upsample. Only shape[0] is used, so a square grid is assumed.
        spacing = self.upsample * self.f * self.wavelength / (self.n * shape[0] * self.camera_pixel_pitch)
        # The 0.0 argument is presumably the source defocus z — confirm against the Chromatix API.
        field = cxf.objective_point_source(shape, spacing, self.wavelength, 0.0, self.f, self.n, self.NA, power=1e3)
        aberrations = zernike_aberrations(
            shape,
            spacing,
            self.wavelength,
            self.n,
            self.f,
            self.NA,
            self.ansi_indices,
            self.coefficients,
            normalize=False,
        )
        field = cxf.phase_change(field, aberrations)
        field = cxf.ff_lens(field, self.f, self.n)
        image = field.intensity
        if key is not None:
            image = shot_noise(key, image)
        return image

    @property
    def image_shape(self) -> tuple[int, int]:
        """Shape of the simulation grid: ``camera_shape * upsample + pad`` on each axis.

        ``__call__`` depends on this property; without it the module cannot be evaluated.
        """
        return tuple(int(size) * self.upsample + self.pad for size in self.camera_shape)

To train it in Ptyrax, we only need to adapt it to the `ImagePredictionModel` interface and provide a corresponding `ImageDataset`.

import chromatix.functional as cxf
import numpy as np
from chromatix.ops import shot_noise
from chromatix.utils import zernike_aberrations
from jaxtyping import Array, Float, Key
from tensorboardX import SummaryWriter

from ptyrax.dataset import ImageDataset, SimpleImageDataset
from ptyrax.models import ImagePredictionModel


class ZernikePSF(ImagePredictionModel):
    """Ptyrax ``ImagePredictionModel`` version of the Chromatix Zernike PSF model.

    The optics match the plain Equinox version. The field of a point source behind an
    objective (``cxf.objective_point_source``) receives a Zernike phase aberration
    (``zernike_aberrations`` weighted by ``coefficients``) and is focused by a far-field
    lens (``cxf.ff_lens``).

    Unlike the Equinox version, this model:

    * predicts the field *amplitude* rather than the intensity, and
    * never applies shot noise.

    ``coefficients`` is the only non-static field and therefore the only trained leaf.
    """

    # This is what we want to optimize! Expected to hold one value per entry of ``ansi_indices``.
    coefficients: Array
    # Lists/arrays are converted to a tuple so the static field stays hashable.
    ansi_indices: tuple[int, ...] = eqx.field(
        static=True,
        converter=lambda x: tuple(x) if isinstance(x, (list, np.ndarray)) else x,
        default=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    )
    # Camera grid shape, in pixels.
    camera_shape: tuple[int, int] = eqx.field(static=True, default=(256, 256))
    # Length quantities below presumably share one unit (e.g. µm) — confirm.
    camera_pixel_pitch: float = eqx.field(static=True, default=0.125)
    f: float = eqx.field(static=True, default=100.0)  # objective focal length
    NA: float = eqx.field(static=True, default=0.8)  # numerical aperture
    n: float = eqx.field(static=True, default=1.33)  # refractive index
    wavelength: float = eqx.field(static=True, default=0.532)
    # Upsampling factor of the simulation grid relative to the camera grid.
    upsample: int = eqx.field(static=True, default=4)
    # Extra samples added to each axis of the simulation grid.
    pad: int = eqx.field(static=True, default=128)

    @eqx.filter_jit
    def __call__(self, key: Key | None = None) -> Array:
        """Predict the PSF amplitude.

        Args:
            key: Accepted for signature compatibility but ignored. Shot noise is
                deliberately not applied, so predictions are deterministic.

        Returns:
            The amplitude of the focused field (``field.amplitude``), sampled on the
            ``image_shape`` simulation grid.
        """
        shape = self.image_shape
        # Pupil-plane sample spacing. Assuming ff_lens maps an input spacing dx to
        # f * wavelength / (n * N * dx), this gives an image-plane spacing of
        # camera_pixel_pitch / upsample. Only shape[0] is used, so a square grid is assumed.
        spacing = self.upsample * self.f * self.wavelength / (self.n * shape[0] * self.camera_pixel_pitch)
        # The 0.0 argument is presumably the source defocus z — confirm against the Chromatix API.
        field = cxf.objective_point_source(shape, spacing, self.wavelength, 0.0, self.f, self.n, self.NA, power=1e3)
        aberrations = zernike_aberrations(
            shape,
            spacing,
            self.wavelength,
            self.n,
            self.f,
            self.NA,
            self.ansi_indices,
            self.coefficients,
            normalize=False,
        )
        field = cxf.phase_change(field, aberrations)
        field = cxf.ff_lens(field, self.f, self.n)
        # Return the amplitude instead of the intensity: it converges better in training.
        return field.amplitude

    @classmethod
    def from_image_dataset(cls, dataset: SimpleImageDataset, *args, **kwargs) -> "ZernikePSF":
        """Build a model whose ``camera_shape`` is taken from ``dataset.image_shape``.

        Extra positional/keyword arguments (e.g. ``coefficients``) are forwarded to the
        constructor.

        NOTE(review): ``__call__`` predicts images on the ``image_shape`` grid
        (``camera_shape * upsample + pad``). If ``dataset`` was produced by this
        model, ``dataset.image_shape`` is presumably that upsampled grid, which would
        make the new model's ``camera_shape`` inconsistent. Confirm the intended
        relationship.
        """
        return cls(*args, camera_shape=dataset.image_shape, **kwargs)

    def to_image_dataset(self, predicted_images: Float[Array, " n h w"]) -> ImageDataset:
        """Wrap a stack of predicted images in a ``SimpleImageDataset``."""
        return SimpleImageDataset(predicted_images)

    def __log_epoch__(self, writer: SummaryWriter, epoch: int, prefix: str = "", **kwargs) -> None:
        """Log a histogram of the current Zernike coefficients under ``{prefix}/coefficients``."""
        writer.add_histogram(f"{prefix}/coefficients", self.coefficients, epoch)

    @property
    def image_shape(self) -> tuple[int, int]:
        """Shape of the simulation grid: ``camera_shape * upsample + pad`` on each axis.

        The sizes are returned as plain Python ints rather than NumPy scalars, to match
        the ``tuple[int, int]`` annotation.
        """
        return tuple(int(size) * self.upsample + self.pad for size in self.camera_shape)
/home/ssenhorst1/workspace/ptyrax/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

For simplicity, this example does not store any experimental conditions in the dataset. In practice, an experimentalist would likely want to make \(z, f, \lambda, n\) and \(\text{NA}\) part of the dataset and adapt `from_image_dataset` accordingly. With the model defined, simulating a dataset from a ground-truth model is straightforward:

from ptyrax.simulate import simulate_model

# Initial guess for the optimizer: every Zernike coefficient set to 0.1 / (2π).
initial_coefficients = jnp.ones(10) * 0.1 / (2 * jnp.pi)
model = ZernikePSF(initial_coefficients)

# Ground-truth coefficients, divided by the wavenumber 2π/λ (i.e. scaled by λ/2π).
# NOTE(review): the initial guess above is not scaled by λ — confirm both use the same units.
wavenumber = 2 * jnp.pi / model.wavelength
true_coefficients = jnp.array([2.0, 5.0, 3.0, 0, 1, 0, 1, 0, 1, 0]) / wavenumber
ground_truth_model = ZernikePSF(true_coefficients)

# Render the ground-truth model into a dataset to train against.
simulated_dataset = simulate_model(ground_truth_model)
/home/ssenhorst1/workspace/ptyrax/ptyrax/models/ptychography.py:392: UserWarning: The model does not seem to have any IndexSliceParametrization fields. The 'n_indices' property will return 1. If this is unintended, please ensure that the model uses IndexSliceParametrization for dataset-indexed parameters.
  warnings.warn(
100%|██████████| 1/1 [00:01<00:00,  1.79s/it]

Since the model does not depend on a dataset index, the simulated dataset contains just a single image. That is still enough to train our model on:

from datetime import datetime

import optax
from tensorboardX import SummaryWriter

from ptyrax.reconstruct import train_model
from ptyrax.training import OptimizerSpecification, initialize_optimizer_and_state

# One Adam optimizer for every trainable leaf. The ".*" pattern presumably matches all
# parameter paths.
optimizer_state, optimizer = initialize_optimizer_and_state(
    model,
    optimizers=[
        OptimizerSpecification(
            name="all",
            match_patterns=[".*"],
            optimizer=optax.adam(learning_rate=1e-2),
        )
    ],
)

# A fresh TensorBoard run directory per execution, stamped with the local start time.
log_dir = f"logs/zernike_psf_example/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
simulated_dataset.to_gpu()
# The context manager closes the writer after training, so buffered events are flushed to disk.
with SummaryWriter(log_dir=log_dir) as writer:
    trained_model, trained_optimizer = train_model(
        model,
        simulated_dataset,
        optimizer,
        optimizer_state,
        num_epochs=100,
        writer=writer,
    )
Epoch: 100%|██████████| 100/100 [03:55<00:00,  2.36s/it]
import matplotlib.pyplot as plt

from ptyrax.utils import plot

# Side-by-side comparison of the initial guess, the trained model and the ground truth.
# Compare the trained PSF with the ground truth visually: training may have stopped in a
# local minimum.
panels = [
    (model, "Initial model"),
    (trained_model, "Trained model"),
    (ground_truth_model, "Ground truth model"),
]
fig = plt.figure(figsize=(12, 4))
gs = fig.add_gridspec(1, len(panels))
for column, (psf_model, title) in enumerate(panels):
    # gamma=0.33 is presumably a display gamma that emphasizes faint PSF features.
    plot(psf_model(), title=title, fig=fig, gs=gs[0, column], gamma=0.33)
    # Also set the title on the current axes, as the original cell did.
    plt.gca().set_title(title)
# Render the figure explicitly; this also stops the cell from echoing a stray Text(...) repr.
plt.show()
Text(0.5, 1.0, 'Ground truth model')
../_images/b85917ab2dfc84e53ba4f8d380b06962afa062b41fe39854aa73656de82e7e88.png