ptyrax.utils#
Functions
| abs_sq | Compute the squared absolute value without taking a square root. |
| adjoint | Compute the conjugate transpose of an array. |
| center_crop | Crop an array symmetrically along a single axis. |
| center_pad | Symmetrically zero-pad an array along a single axis. |
| center_scan_pos | Center scanning positions by subtracting their mean. |
| compile_policy_patterns | Compile regex policies once and validate patterns up-front. |
| complex_to_rgb | Convert a complex-valued image to an RGB representation. |
| compute_center_of_mass_shift | Compute the center-of-mass shift required to center an image. |
| convert_to_ft_sampling | Compute the reciprocal-space (Fourier) sampling parameters. |
| count_parameters | Count the total number of parameters in a PyTree. |
| fft | Compute centered 2D FFT. |
| flatten_dict | Recursively flatten nested dictionaries into a single-level dict. |
| hsv_to_rgb | Convert HSV color values to RGB. |
| identity | Identity function that returns its arguments unchanged. |
| ifft | Compute centered 2D inverse FFT. |
| join_hdf5_paths | Join two HDF5-like paths while keeping a single separator. |
| load_hdf5 | Load the data from an HDF5 file at file_path and return it as a dictionary. |
| make_length_n | Convert a single parameter value or a list of parameter values into an array of length n. |
| make_or_reuse_axes | Create or reuse matplotlib Figure and Axes. |
| make_path_string | Convert a JAX key-path list to a dot-separated string. |
| median_pixel_value | Compute the median pixel of an image. |
| normalize | Normalize an array by its maximum value along specified axes. |
| normalize_hdf5_path | Normalize an HDF5-like path by collapsing duplicate separators. |
| normalize_power | Normalize an array so its total power equals 1. |
| orthogonal | Orthogonalize rows of a matrix using SVD. |
| orthogonalize | Orthogonalize probe modes in a multi-mode array. |
| parallel_orthogonal | Vectorized version of orthogonal. |
| phase_only_exp | Compute a unit-magnitude complex exponential from a real phase. |
| phase_only_exp_np | NumPy version of phase_only_exp(). |
| plot | Plot an image array, dispatching to complex or real plotting. |
| plot_complex | Plot a complex-valued image by converting it to RGB using a phase-magnitude colormap. |
| plot_real | Plot a real-valued image with optional gamma correction and log scaling. |
| real_to_rgb | Map a real-valued array to an RGB image via a matplotlib colormap. |
| reduced_grad | Create a gradient function with a scalar reduction applied first. |
| repeat_axis | Repeat elements along an axis and center-crop to target_len. |
| resize_to_match | Resize arr to target_shape: contracting axes are cropped symmetrically; expanding axes use a per-axis policy or default_policy. |
| save_hdf5 | Save the data (a dictionary) to an HDF5 file at file_path. |
| scaled_mean | Compute the mean of an array scaled by a constant factor. |
| set_probe_data_preserve_parametrization | Update model probe data while preserving the parametrization wrapper. |
| shift_image | Shift a 2D image using scipy.ndimage.map_coordinates with zero padding. |
| single_identity | Return a single-element tuple unchanged. |
| slice_at_center_to_shape | Extract a centered slice of x with the given target_shape around center. |
| soft_clip | Differentiable soft-clipping function. |
| sort_images_by_time | Sort image file paths by their filesystem modification time. |
| tile_axes | Tile an array to at least target_shape and center-crop to exact size. |
| tree_slice_first | Slice the first element from each leaf in a PyTree. |
| unstack_tree | Split a stacked PyTree (with leading batch axis) into a list of PyTrees. |
| vmap_nested | Apply nested jax.vmap() for multiple batch dimensions. |
| warn_if_duplicate_normalized_keys | Warn when key normalization collapses distinct checkpoint paths. |
| wrap_like_parametrization | Wrap a value in the same parametrization type as a reference. |
| zero_pad_to_shape | Symmetrically zero-pad an array to target_shape. |
- ptyrax.utils.abs_sq(x)[source]#
Compute the squared absolute value without taking a square root.
More efficient than jnp.abs(x)**2 for complex inputs.
- Parameters:
x (Shaped[Array, '...']) – Input array (real or complex).
- Returns:
Real-valued squared magnitude.
- Return type:
Float[Array, '...']
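A minimal NumPy sketch of the idea behind abs_sq, not the ptyrax implementation: for complex input the squared magnitude is the sum of the squared real and imaginary parts, so no square root is needed. The name abs_sq_sketch is hypothetical.

```python
import numpy as np

def abs_sq_sketch(x):
    """Squared magnitude computed without the square root inside abs()."""
    x = np.asarray(x)
    if np.iscomplexobj(x):
        # |x|^2 = Re(x)^2 + Im(x)^2, so no sqrt is involved
        return x.real ** 2 + x.imag ** 2
    return x ** 2

z = np.array([3 + 4j, 1 - 1j, 0j])
squared = abs_sq_sketch(z)  # real-valued, same shape as z
```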
- ptyrax.utils.adjoint(f)[source]#
Compute the conjugate transpose of an array.
- Parameters:
f (Shaped[Array, '... m n']) – Input array.
- Returns:
Conjugate transpose (swapping last two axes).
- Return type:
Shaped[Array, '... m n']
- ptyrax.utils.center_crop(arr, target_len, axis)[source]#
Crop an array symmetrically along a single axis.
- Parameters:
arr (ndarray) – Input array.
target_len (int) – Desired length along axis.
axis (int) – Axis to crop.
- Returns:
Center-cropped array.
- Return type:
ndarray
- ptyrax.utils.center_pad(arr, target_len, axis)[source]#
Symmetrically zero-pad an array along a single axis.
- Parameters:
arr (ndarray) – Input array.
target_len (int) – Desired length along axis.
axis (int) – Axis to pad.
- Returns:
Zero-padded array.
- Return type:
ndarray
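Symmetric cropping and padding along one axis can be sketched in NumPy as follows. The function names are hypothetical, and the handling of an odd size difference (the extra element goes to the trailing side) is an assumption; ptyrax may round the other way.

```python
import numpy as np

def center_crop_sketch(arr, target_len, axis):
    """Keep target_len elements around the middle of the given axis."""
    # Assumption: with an odd difference, one more element is dropped at the end.
    start = (arr.shape[axis] - target_len) // 2
    index = [slice(None)] * arr.ndim
    index[axis] = slice(start, start + target_len)
    return arr[tuple(index)]

def center_pad_sketch(arr, target_len, axis):
    """Zero-pad both sides of the given axis up to target_len."""
    total = target_len - arr.shape[axis]
    before = total // 2  # assumption: any extra zero goes after the data
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (before, total - before)
    return np.pad(arr, pad_width)  # np.pad defaults to constant zeros

a = np.arange(10).reshape(2, 5)
padded = center_pad_sketch(a, 9, axis=1)
restored = center_crop_sketch(padded, 5, axis=1)
```

With these rounding choices, cropping a padded array back to its original length recovers the input.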
- ptyrax.utils.center_scan_pos(scan_pos)[source]#
Center scanning positions by subtracting their mean.
- Parameters:
scan_pos (Float[Array, 'N d']) – Array of scanning positions.
- Returns:
Mean-subtracted positions.
- Return type:
Float[Array, 'N d']
- ptyrax.utils.compile_policy_patterns(policy_map)[source]#
Compile regex policies once and validate patterns up-front.
- Parameters:
policy_map (dict[str, dict[str, Any]] | None)
- Return type:
list[tuple[Pattern[str], dict[str, Any]]]
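One plausible reading of "compile once and validate up-front" is sketched here with the standard re module. The name compile_policy_patterns_sketch is hypothetical, and two behaviors are assumptions: None yields an empty list, and an invalid pattern is reported as ValueError.

```python
import re

def compile_policy_patterns_sketch(policy_map):
    """Compile each regex key once so bad patterns fail before any matching."""
    if policy_map is None:
        return []  # assumption: no policies means nothing to compile
    compiled = []
    for pattern, policy in policy_map.items():
        try:
            compiled.append((re.compile(pattern), policy))
        except re.error as exc:
            # assumption: invalid patterns surface as ValueError
            raise ValueError(f"invalid policy pattern {pattern!r}: {exc}") from exc
    return compiled

policies = compile_policy_patterns_sketch({r"probe\..*": {"policy": "pad"}})
# Look up the first policy whose pattern matches a key.
match = next((p for rx, p in policies if rx.fullmatch("probe.data")), None)
```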
- ptyrax.utils.complex_to_rgb(im, log10=False, cmap='hsv', gamma=1.0, clim=None, scale_min=False, max=1.0)[source]#
Convert a complex-valued image to an RGB representation.
Encodes magnitude as brightness and phase as hue using an HSV colormap.
- Parameters:
im (Complex[Array, '... w h']) – Complex input image.
log10 (bool) – Apply log10 to magnitude before mapping.
cmap (Literal['hsv', 'lab']) – Colormap style (currently only "hsv" supported).
gamma (float) – Gamma exponent applied to magnitude.
clim (tuple[float, float]) – Optional (min, max) clipping range for magnitude.
scale_min (bool) – Reserved (not used).
max (float) – Maximum RGB output value.
- Returns:
RGB float array with values in [0, max].
- Return type:
Float[Array, '... w h 3']
- ptyrax.utils.compute_center_of_mass_shift(img, order=1)[source]#
Compute the center-of-mass shift required to center an image.
- Parameters:
img (Float[Array, '... n m']) – Input image (can be batched); non-negative values are assumed.
order (int) – Power to raise pixel values when computing the weighted center.
- Returns:
Shift vector (dy, dx) to translate the image to the center.
- Return type:
Float[Array, '2']
- ptyrax.utils.convert_to_ft_sampling(pixel_number, pixel_size, scaling_factor=1.0, prop_dist=None, wavelength=None)[source]#
Compute the reciprocal-space (Fourier) sampling parameters.
Given real-space pixel count and size, returns the corresponding Fourier pixel count and size. Optionally computes the scaling from wavelength and propagation distance.
- Parameters:
pixel_number (Integer[Array, '... d']) – Number of pixels in each dimension.
pixel_size (Integer[Array, '... d']) – Pixel size in each dimension.
scaling_factor (float) – Manual scaling factor (mutually exclusive with wavelength/prop_dist).
prop_dist (float) – Propagation distance (requires wavelength).
wavelength (float) – Photon wavelength (requires prop_dist).
- Returns:
Tuple of (ft_pixel_number, ft_pixel_size).
- Raises:
ValueError – If arguments are inconsistent.
- Return type:
tuple[Integer[Array, '... d'], Float[Array, '... d']]
- ptyrax.utils.count_parameters(pytree)[source]#
Count the total number of parameters in a PyTree.
- Parameters:
pytree (PyTree) – A PyTree containing parameters.
- Returns:
Total number of parameters.
- Return type:
int
- ptyrax.utils.fft(x, axes=(-2, -1), norm='ortho', fftshift=True)[source]#
Compute centered 2D FFT.
Applies ifftshift before and fftshift after the FFT so that the zero-frequency component is in the center.
- Parameters:
x (Inexact[Array, '...']) – Input array.
axes (tuple[int, int]) – Axes over which to compute the FFT.
norm (Literal['backward', 'ortho', 'forward']) – Normalization mode.
fftshift (bool) – If True, center the transform.
- Returns:
Complex FFT of the input.
- Return type:
Complex[Array, '...']
- ptyrax.utils.flatten_dict(d, parent_key='', sep='.')[source]#
Recursively flatten nested dictionaries into a single-level dict.
- Parameters:
d (dict) – Dictionary to flatten.
parent_key (str) – Prefix for keys at this level (used in recursion).
sep (str) – Separator joining nested key parts.
- Returns:
Flat dictionary with concatenated keys.
- Return type:
dict
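The recursion described above can be sketched in plain Python. This is an illustration with the same parameter names, not the ptyrax source; the name flatten_dict_sketch is hypothetical, and an empty nested dict simply contributes no keys in this version.

```python
def flatten_dict_sketch(d, parent_key="", sep="."):
    """Flatten nested dicts into one level, joining key parts with sep."""
    flat = {}
    for key, value in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse, passing the accumulated key as the new prefix.
            flat.update(flatten_dict_sketch(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat

nested = {"probe": {"data": 1, "mask": {"radius": 2}}, "step": 3}
flat = flatten_dict_sketch(nested)
# flat == {"probe.data": 1, "probe.mask.radius": 2, "step": 3}
```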
- ptyrax.utils.hsv_to_rgb(h, s, v)[source]#
Convert HSV color values to RGB.
- Parameters:
h (Float[Array, '...']) – Hue channel in [0, 1].
s (Float[Array, '...']) – Saturation channel in [0, 1].
v (Float[Array, '...']) – Value channel in [0, 1].
- Returns:
RGB array with last dimension of size 3.
- Return type:
Float[Array, '... 3']
- ptyrax.utils.identity(*args)[source]#
Identity function that returns its arguments unchanged.
Useful as a default no-op callback in gin-configurable pipelines.
- Parameters:
*args (Any) – Any arguments.
- Returns:
The input arguments as a tuple.
- Return type:
Any
- ptyrax.utils.ifft(x, axes=(-2, -1), norm='ortho', fftshift=True)[source]#
Compute centered 2D inverse FFT.
Applies ifftshift before and fftshift after the IFFT so that the zero-frequency component is handled correctly.
- Parameters:
x (Inexact[Array, '...']) – Input array (frequency domain).
axes (tuple) – Axes over which to compute the IFFT.
norm (Literal['backward', 'ortho', 'forward']) – Normalization mode.
fftshift (bool) – If True, center the transform.
- Returns:
Complex inverse FFT of the input.
- Return type:
Complex[Array, '...']
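The shift-transform-shift pattern shared by fft and ifft can be sketched with numpy.fft. The parameter names mirror the signatures above, but the function names are hypothetical and the behavior with fftshift=False (a plain, uncentered transform) is an assumption.

```python
import numpy as np

def fft_sketch(x, axes=(-2, -1), norm="ortho", fftshift=True):
    """Centered 2D FFT: ifftshift, transform, fftshift."""
    if not fftshift:
        return np.fft.fft2(x, axes=axes, norm=norm)  # assumption: plain FFT
    shifted = np.fft.ifftshift(x, axes=axes)  # move the array center to index 0
    return np.fft.fftshift(np.fft.fft2(shifted, axes=axes, norm=norm), axes=axes)

def ifft_sketch(x, axes=(-2, -1), norm="ortho", fftshift=True):
    """Centered 2D inverse FFT using the same shift convention."""
    if not fftshift:
        return np.fft.ifft2(x, axes=axes, norm=norm)
    shifted = np.fft.ifftshift(x, axes=axes)
    return np.fft.fftshift(np.fft.ifft2(shifted, axes=axes, norm=norm), axes=axes)

# A delta at the array center transforms to a flat, purely real spectrum.
delta = np.zeros((8, 8))
delta[4, 4] = 1.0
spectrum = fft_sketch(delta)
roundtrip = ifft_sketch(spectrum)
```

With norm="ortho" the pair is unitary, so the round trip returns the input without extra scaling.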
- ptyrax.utils.join_hdf5_paths(prefix, suffix)[source]#
Join two HDF5-like paths while keeping a single separator.
- Parameters:
prefix (str | None)
suffix (str)
- Return type:
str
- ptyrax.utils.load_hdf5(file_path, key_translation=None)[source]#
Load the data from an HDF5 file at file_path and return it as a dictionary.
The data_type = all, params, ptychogram
- Parameters:
file_path (str)
key_translation (dict)
- Return type:
dict
- ptyrax.utils.make_length_n(parameter, n=2)[source]#
Convert a single parameter value or a list of parameter values into an array of length n.
- Parameters:
parameter (float | int | ndarray | list | tuple)
n (int)
- Return type:
array
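A NumPy sketch of the scalar-or-sequence broadcast described above. The name make_length_n_sketch is hypothetical, and raising ValueError for a sequence whose length is neither 1 nor n is an assumption about behavior the documentation does not specify.

```python
import numpy as np

def make_length_n_sketch(parameter, n=2):
    """Return a length-n array from a scalar or a length-n sequence."""
    values = np.atleast_1d(np.asarray(parameter))
    if values.size == 1:
        return np.repeat(values, n)  # repeat the single value n times
    if values.shape != (n,):
        # assumption: any other length is rejected
        raise ValueError(f"expected a scalar or {n} values, got shape {values.shape}")
    return values

pixel_size = make_length_n_sketch(0.5)       # one value for both dimensions
pixel_number = make_length_n_sketch([256, 128])
```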
- ptyrax.utils.make_or_reuse_axes(fig=None, gs=None)[source]#
Create or reuse matplotlib Figure and Axes.
- Parameters:
fig (Figure) – Existing figure (required if gs is provided).
gs (SubplotSpec) – GridSpec subplot to add an axes to.
- Returns:
Tuple of (fig, ax).
- Raises:
ValueError – If gs is provided without fig.
- Return type:
tuple[Figure, Axes]
- ptyrax.utils.make_path_string(path)[source]#
Convert a JAX key-path list to a dot-separated string.
- Parameters:
path (list[str]) – List of JAX tree path elements.
- Returns:
Dot-joined path string.
- Return type:
str
- ptyrax.utils.median_pixel_value(img)[source]#
Compute the median pixel of an image.
- Parameters:
img (Float[Array, '...'])
- Return type:
float
- ptyrax.utils.normalize(a, axes=None)[source]#
Normalize an array by its maximum value along specified axes.
- Parameters:
a (Shaped[Array, '...']) – Input array.
axes (tuple[int, ...]) – Axes over which to take the maximum.
- Returns:
Array divided by its maximum.
- Return type:
Shaped[Array, '...']
- ptyrax.utils.normalize_hdf5_path(path)[source]#
Normalize an HDF5-like path by collapsing duplicate separators.
- Parameters:
path (str | None)
- Return type:
str
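The two HDF5 path helpers, join_hdf5_paths and normalize_hdf5_path, might behave roughly like this standard-library sketch. The function names are hypothetical, and several details are assumptions: "/" is the separator, None or an empty string normalizes to an empty path, and a missing prefix returns the normalized suffix.

```python
import re

def normalize_hdf5_path_sketch(path):
    """Collapse runs of '/' into a single separator."""
    if not path:
        return ""  # assumption: None and "" both map to an empty path
    return re.sub(r"/+", "/", path)

def join_hdf5_paths_sketch(prefix, suffix):
    """Join prefix and suffix so exactly one '/' separates them."""
    if not prefix:
        return normalize_hdf5_path_sketch(suffix)
    # Concatenate with a separator, then let normalization remove duplicates.
    return normalize_hdf5_path_sketch(f"{prefix}/{suffix}")

joined = join_hdf5_paths_sketch("/model/", "/probe/data")
# joined == "/model/probe/data"
```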
- ptyrax.utils.normalize_power(a, axis=(0, 1))[source]#
Normalize an array so its total power equals 1.
The normalization enforces \(\sum |a|^2 = 1\) over axis.
- Parameters:
a (Shaped[Array, '...']) – Input array.
axis (tuple[int, ...]) – Axes over which to compute the power.
- Returns:
Power-normalized array.
- Return type:
Shaped[Array, '...']
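The constraint \(\sum |a|^2 = 1\) amounts to dividing by the square root of the summed power. A NumPy sketch with the hypothetical name normalize_power_sketch, assuming the power is nonzero (an all-zero slice would divide by zero here):

```python
import numpy as np

def normalize_power_sketch(a, axis=(0, 1)):
    """Scale a so that the sum of |a|^2 over axis equals 1."""
    power = np.sum(np.abs(a) ** 2, axis=axis, keepdims=True)
    return a / np.sqrt(power)  # assumption: power is nonzero

rng = np.random.default_rng(0)
probe = rng.normal(size=(4, 5, 3)) + 1j * rng.normal(size=(4, 5, 3))
unit_power = normalize_power_sketch(probe)
```

Because keepdims=True is used, any axes not listed in axis are normalized independently.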
- ptyrax.utils.orthogonal(f)[source]#
Orthogonalize rows of a matrix using SVD.
- Parameters:
f (Shaped[Array, '... d mn']) – Input matrix or batch of matrices.
- Returns:
Matrix with orthogonalized rows.
- Return type:
Shaped[Array, '... d mn']
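One common SVD-based construction for orthogonal rows is sketched here in NumPy. Whether ptyrax uses exactly this form (rows scaled by the singular values, which preserves total power) is an assumption, as is the requirement that the number of rows d does not exceed mn. The name orthogonal_sketch is hypothetical.

```python
import numpy as np

def orthogonal_sketch(f):
    """Return rows that are mutually orthogonal and span the same space as f."""
    # f = U @ diag(s) @ Vh, so U^H @ f = diag(s) @ Vh has orthogonal rows.
    _, s, vh = np.linalg.svd(f, full_matrices=False)
    return s[..., :, None] * vh

rng = np.random.default_rng(0)
modes = rng.normal(size=(3, 16)) + 1j * rng.normal(size=(3, 16))
ortho = orthogonal_sketch(modes)
gram = ortho @ ortho.conj().T  # diagonal when the rows are orthogonal
```

In this form the rows come out ordered by decreasing norm, since their norms are the singular values.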
- ptyrax.utils.orthogonalize(f)[source]#
Orthogonalize probe modes in a multi-mode array.
- Parameters:
f (Shaped[Array, '... m n']) – Array of shape (modes, m, n) representing probe modes.
- Returns:
Orthogonalized probe modes with same shape.
- Raises:
NotImplementedError – If input has additional wavelength dimensions.
ValueError – If no mode dimension is present.
- Return type:
Shaped[Array, '... m n']
- ptyrax.utils.parallel_orthogonal(f)#
Vectorized version of orthogonal. Takes similar arguments as orthogonal but with additional array axes over which orthogonal is mapped.
Original documentation:
Orthogonalize rows of a matrix using SVD.
- Parameters:
f (Shaped[Array, '... d mn']) – Input matrix or batch of matrices.
- Returns:
Matrix with orthogonalized rows.
- Return type:
Shaped[Array, '... d mn']
- ptyrax.utils.phase_only_exp(phase)[source]#
Compute a unit-magnitude complex exponential from a real phase.
Equivalent to exp(1j * phase) but avoids complex input.
- Parameters:
phase (Float[Array, '... m n']) – Real-valued phase array in radians.
- Returns:
Complex array with unit magnitude and given phase.
- Raises:
ValueError – If phase is already complex.
- Return type:
Complex[Array, '... m n']
- ptyrax.utils.phase_only_exp_np(phase)[source]#
NumPy version of phase_only_exp().
- Parameters:
phase (Float[Array, '... m n']) – Real-valued phase array in radians.
- Returns:
Complex NumPy array with unit magnitude.
- Raises:
ValueError – If phase is already complex.
- Return type:
Complex[Array, '... m n']
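A NumPy sketch of the documented contract: a real phase goes in, a unit-magnitude complex array comes out, and complex input raises ValueError. Building the result from cos and sin is one way to avoid a complex exponential; whether ptyrax does it this way is an assumption, and the name phase_only_exp_sketch is hypothetical.

```python
import numpy as np

def phase_only_exp_sketch(phase):
    """Return cos(phase) + 1j*sin(phase) for a real-valued phase."""
    phase = np.asarray(phase)
    if np.iscomplexobj(phase):
        raise ValueError("phase must be real-valued")
    # Euler's formula, evaluated with real arithmetic only.
    return np.cos(phase) + 1j * np.sin(phase)

phase = np.linspace(-np.pi, np.pi, 7)
unit = phase_only_exp_sketch(phase)
```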
- ptyrax.utils.plot(im, show=False, gs=None, fig=None, dpi=150, plot_text=False, **kwargs)[source]#
Plot an image array, dispatching to complex or real plotting.
Supports complex-valued images (phase-magnitude colormap), real-valued images, and objects with a __plot__ method.
- Parameters:
im (Array | ndarray | bool | number | bool | int | float | complex | LiteralArray) – Image array to plot.
show (bool) – If True, call plt.show().
gs (GridSpec | None) – Optional GridSpec for subplot placement.
fig (Figure | None) – Optional existing Figure.
dpi (int | None) – Figure DPI.
plot_text (bool) – If True, annotate the figure with parameter info.
**kwargs – Extra arguments forwarded to the plotting backend.
- Returns:
Tuple of (fig, gs, image_artists).
- Return type:
tuple[Figure, SubplotSpec, list[AxesImage]]
- ptyrax.utils.plot_complex(im, fig, gs, log10=False, cmap='hsv', gamma=1, title=None, clim=None, cbar=None, **kwargs)[source]#
Plot a complex-valued image by converting it to RGB using a phase-magnitude colormap.
- Parameters:
im (Complex[Array, '...']) – Complex image.
fig (Figure) – Matplotlib Figure to draw on.
gs (SubplotSpec) – GridSpec/SubplotSpec to place the image.
log10 (bool) – Whether to take log10 of magnitude.
cmap (str)
gamma (float)
title (str | None)
clim (tuple[float, float] | None)
cbar (bool | None)
- Returns:
Tuple of (fig, gs, image artists).
- Return type:
tuple[Figure, SubplotSpec, list[AxesImage]]
- ptyrax.utils.plot_real(im, fig, gs, title=None, cmap=None, gamma=1.0, cbar=False, log10=False, epsilon=1e-10, **kwargs)[source]#
Plot a real-valued image with optional gamma correction and log scaling.
- Parameters:
im (Float[Array, '...']) – Real-valued image.
fig (Figure) – Matplotlib Figure to draw on.
gs (SubplotSpec) – GridSpec/SubplotSpec to place the image.
gamma (float) – Gamma exponent to apply to the image.
log10 (bool) – Whether to apply log10 scaling (uses epsilon to avoid log(0)).
title (str)
cmap (str | None)
cbar (bool)
epsilon (float)
- Returns:
Tuple of (fig, gs, image artists).
- Return type:
tuple[Figure, SubplotSpec, list[AxesImage]]
- ptyrax.utils.real_to_rgb(tensor, log10=False, cmap='magma', gamma=1.0, **kwargs)[source]#
Map a real-valued array to an RGB image via a matplotlib colormap.
- Parameters:
tensor (Float[Array, '...']) – Real input array.
log10 (bool) – Apply log10 scaling.
cmap (str) – Matplotlib colormap name.
gamma (float) – Gamma exponent applied to normalized amplitude.
- Returns:
RGB float array with shape (*tensor.shape, 3).
- Return type:
Float[Array, '... 3']
- ptyrax.utils.reduced_grad(fn, reduction_fn=<function <lambda>>)[source]#
Create a gradient function with a scalar reduction applied first.
- Parameters:
fn (Callable) – Function whose output will be reduced then differentiated.
reduction_fn (Callable) – Scalar reduction (default: sum of absolute values).
- Returns:
Gradient function of the reduced output.
- Return type:
Callable
- ptyrax.utils.repeat_axis(arr, target_len, axis)[source]#
Repeat elements along an axis and center-crop to target_len.
- Parameters:
arr (ndarray) – Input array.
target_len (int) – Desired length along axis.
axis (int) – Axis to expand.
- Returns:
Repeated and center-cropped array.
- Return type:
ndarray
- ptyrax.utils.resize_to_match(arr, target_shape, axis_policies=None, default_policy='pad')[source]#
Resize arr to target_shape:
- Contracting axes: crop symmetrically
- Expanding axes: per-axis policy or default_policy
- Parameters:
arr (ndarray)
target_shape (Tuple[int, ...])
axis_policies (Dict[int, str])
default_policy (str)
- Return type:
ndarray
- ptyrax.utils.save_hdf5(file_path, data)[source]#
Save the data (a dictionary) to an HDF5 file at file_path.
- Parameters:
file_path (str)
data (dict)
- Return type:
None
- ptyrax.utils.scaled_mean(a, scale=1.0, **kwargs)[source]#
Compute the mean of an array scaled by a constant factor.
Typically configured via gin as a loss reduction function.
- Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex | LiteralArray) – Input array.
scale (float) – Multiplicative scaling applied after mean reduction.
**kwargs – Extra arguments forwarded to jax.numpy.mean().
- Returns:
Scalar mean value times scale.
- Return type:
Float[Array, '']
- ptyrax.utils.set_probe_data_preserve_parametrization(model, new_probe_data)[source]#
Update model probe data while preserving the parametrization wrapper.
Uses eqx.tree_at() to replace model.illumination.probe.data with new_probe_data wrapped in the original parametrization type.
- Parameters:
model (Any) – Model whose probe data should be updated.
new_probe_data (Any) – New data to set.
- Returns:
Updated model with new probe data.
- Return type:
Any
- ptyrax.utils.shift_image(img, shift, order=1)[source]#
Shift a 2D image using scipy.ndimage.map_coordinates with zero padding.
- Parameters:
img (Array) – 2D NumPy array
shift (Array) – Tuple or array (dy, dx), the shift to apply
order (int) – Interpolation order (1 = bilinear, 3 = cubic, etc.)
- Returns:
Shifted image with same shape as input, zero-padded at edges.
- Return type:
Array
- ptyrax.utils.single_identity(args)[source]#
Return a single-element tuple unchanged.
Convenience identity function for callbacks that receive a tuple.
- Parameters:
args (tuple[Any, ...]) – Input tuple.
- Returns:
The same tuple.
- Return type:
Any
- ptyrax.utils.slice_at_center_to_shape(x, center, target_shape)[source]#
Extract a centered slice of x with the given target_shape around center.
- Parameters:
x (Float[Array, '... n m']) – Array from which to slice. Last two dimensions are spatial.
center (Float[Array, '2']) – Center coordinate (y, x) in the same coordinate system as x.
target_shape (tuple[int, int]) – Desired output spatial shape (height, width).
- Returns:
Dynamically sliced view of x with spatial dims equal to target_shape.
- Return type:
Float[Array, '... n m']
- ptyrax.utils.soft_clip(a, min_value, max_value, relu_like=<PjitFunction of <function identity>>, scale=1.0)[source]#
Differentiable soft-clipping function.
Applies a smooth clamping operation using a ReLU-like activation to keep values within [min_value, max_value].
- Parameters:
a (Shaped[Array, '...']) – Input array.
min_value (float | int) – Lower clipping bound.
max_value (float | int) – Upper clipping bound.
relu_like (Callable) – Activation function for soft clamping.
scale (float) – Scaling factor applied before and after clipping.
- Returns:
Clipped array.
- Return type:
Shaped[Array, '...']
- ptyrax.utils.sort_images_by_time(image_paths)[source]#
Sort image file paths by their filesystem modification time.
- Parameters:
image_paths (list[str]) – List of file paths to sort.
- Returns:
Paths sorted by ascending modification time.
- Return type:
list[str]
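Sorting by modification time can be sketched with os.path.getmtime as the sort key. This is an illustration, not the ptyrax source; the name sort_images_by_time_sketch is hypothetical, and paths that do not exist raise OSError in this version.

```python
import os

def sort_images_by_time_sketch(image_paths):
    """Return the paths ordered from oldest to newest modification time."""
    # os.path.getmtime returns seconds since the epoch for each file.
    return sorted(image_paths, key=os.path.getmtime)
```

Because sorted is stable, files with identical timestamps keep their input order.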
- ptyrax.utils.tile_axes(arr, target_shape)[source]#
Tile an array to at least target_shape and center-crop to exact size.
- Parameters:
arr (ndarray) – Input array to tile.
target_shape (Tuple[int, ...]) – Desired output shape.
- Returns:
Array tiled and cropped to target_shape.
- Return type:
ndarray
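The tile-then-crop behavior described above can be sketched with np.tile. The name tile_axes_sketch is hypothetical, and two details are assumptions: target_shape has one entry per array axis, and with an odd surplus the extra element is trimmed from the trailing side.

```python
import numpy as np

def tile_axes_sketch(arr, target_shape):
    """Tile arr until it covers target_shape, then center-crop to that shape."""
    # Ceiling division gives the number of repeats needed along each axis.
    reps = [-(-t // s) for t, s in zip(target_shape, arr.shape)]
    tiled = np.tile(arr, reps)
    index = tuple(
        slice((n - t) // 2, (n - t) // 2 + t)
        for n, t in zip(tiled.shape, target_shape)
    )
    return tiled[index]

small = np.arange(6).reshape(2, 3)
tiled = tile_axes_sketch(small, (5, 4))
```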
- ptyrax.utils.tree_slice_first(tree)[source]#
Slice the first element from each leaf in a PyTree.
- Parameters:
tree (PyTree) – A PyTree containing arrays.
- Returns:
A new PyTree with the first element sliced from each leaf.
- Return type:
PyTree
- ptyrax.utils.unstack_tree(stacked)[source]#
Split a stacked PyTree (with leading batch axis) into a list of PyTrees.
- Parameters:
stacked (PyTree) – PyTree where every leaf has a leading batch dimension.
- Returns:
List of PyTrees, one per element along axis 0.
- Return type:
list[PyTree]
- ptyrax.utils.vmap_nested(fn, in_axes, *args, **kwargs)[source]#
Apply nested jax.vmap() for multiple batch dimensions.
- Parameters:
fn (Callable) – Function to vectorize.
in_axes (tuple[int, ...]) – Tuple of axes—one per nesting level.
- Returns:
Multi-level vmapped function.
- Return type:
Callable
- ptyrax.utils.warn_if_duplicate_normalized_keys(state_keys, normalize_keys)[source]#
Warn when key normalization collapses distinct checkpoint paths.
If normalize_keys is enabled and stripping leading underscores from path segments causes two different keys to map to the same normalized key, a warning is emitted.
- Parameters:
state_keys (list[str]) – List of HDF5 state keys.
normalize_keys (bool) – Whether normalization is active.
- Return type:
None
- ptyrax.utils.wrap_like_parametrization(reference, value)[source]#
Wrap a value in the same parametrization type as a reference.
If reference is an ArrayParametrization, wraps value in the same subclass; otherwise returns value unchanged.
- Parameters:
reference (Any) – Object whose type is inspected.
value (Any) – Value to potentially wrap.
- Returns:
Wrapped or unwrapped value.
- Return type:
Any
- ptyrax.utils.zero_pad_to_shape(x, target_shape)[source]#
Symmetrically zero-pad an array to target_shape.
- Parameters:
x (Float[Array, '... n m']) – Input array (last two dims are spatial).
target_shape (tuple[int, int]) – Desired spatial dimensions.
- Returns:
Zero-padded array with spatial dimensions equal to target_shape.
- Return type:
Float[Array, '... n m']