ptyrax.utils

Functions

abs_sq(x)

Compute the squared absolute value without taking a square root.

adjoint(f)

Compute the conjugate transpose of an array.

center_crop(arr, target_len, axis)

Crop an array symmetrically along a single axis.

center_pad(arr, target_len, axis)

Symmetrically zero-pad an array along a single axis.

center_scan_pos(scan_pos)

Center scanning positions by subtracting their mean.

compile_policy_patterns(policy_map)

Compile regex policies once and validate patterns up-front.

complex_to_rgb(im[, log10, cmap, gamma, ...])

Convert a complex-valued image to an RGB representation.

compute_center_of_mass_shift(img[, order])

Compute the center-of-mass shift required to center an image.

convert_to_ft_sampling(pixel_number, pixel_size)

Compute the reciprocal-space (Fourier) sampling parameters.

count_parameters(pytree)

Count the total number of parameters in a PyTree.

fft(x[, axes, norm, fftshift])

Compute a centered 2D FFT.

flatten_dict(d[, parent_key, sep])

Recursively flatten nested dictionaries into a single-level dict.

hsv_to_rgb(h, s, v)

Convert HSV color values to RGB.

identity(*args)

Identity function that returns its arguments unchanged.

ifft(x[, axes, norm, fftshift])

Compute a centered 2D inverse FFT.

join_hdf5_paths(prefix, suffix)

Join two HDF5-like paths while keeping a single separator.

load_hdf5(file_path[, key_translation])

Load the data from an HDF5 file at file_path and return it as a dictionary.

make_length_n(parameter[, n])

Convert a single parameter value or a list of parameter values into an array of length n.

make_or_reuse_axes([fig, gs])

Create or reuse matplotlib Figure and Axes.

make_path_string(path)

Convert a JAX key-path list to a dot-separated string.

median_pixel_value(img)

Compute the median pixel value of an image.

normalize(a[, axes])

Normalize an array by its maximum value along specified axes.

normalize_hdf5_path(path)

Normalize an HDF5-like path by collapsing duplicate separators.

normalize_power(a[, axis])

Normalize an array so its total power equals 1.

orthogonal(f)

Orthogonalize rows of a matrix using SVD.

orthogonalize(f)

Orthogonalize probe modes in a multi-mode array.

parallel_orthogonal(f)

Vectorized version of orthogonal.

phase_only_exp(phase)

Compute a unit-magnitude complex exponential from a real phase.

phase_only_exp_np(phase)

NumPy version of phase_only_exp().

plot(im[, show, gs, fig, dpi, plot_text])

Plot an image array, dispatching to complex or real plotting.

plot_complex(im, fig, gs[, log10, cmap, ...])

Plot a complex-valued image by converting it to RGB using a phase-magnitude colormap.

plot_real(im, fig, gs[, title, cmap, gamma, ...])

Plot a real-valued image with optional gamma correction and log scaling.

real_to_rgb(tensor[, log10, cmap, gamma])

Map a real-valued array to an RGB image via a matplotlib colormap.

reduced_grad(fn[, reduction_fn])

Create a gradient function with a scalar reduction applied first.

repeat_axis(arr, target_len, axis)

Repeat elements along an axis and center-crop to target_len.

resize_to_match(arr, target_shape[, ...])

Resize arr to target_shape, cropping contracting axes symmetrically and expanding other axes with a per-axis policy or default_policy.

save_hdf5(file_path, data)

Save the data (a dictionary) to an HDF5 file at file_path.

scaled_mean(a[, scale])

Compute the mean of an array scaled by a constant factor.

set_probe_data_preserve_parametrization(...)

Update model probe data while preserving the parametrization wrapper.

shift_image(img, shift[, order])

Shift a 2D image using scipy.ndimage.map_coordinates with zero padding.

single_identity(args)

Return a single-element tuple unchanged.

slice_at_center_to_shape(x, center, target_shape)

Extract a centered slice of x with the given target_shape around center.

soft_clip(a, min_value, max_value[, ...])

Differentiable soft-clipping function.

sort_images_by_time(image_paths)

Sort image file paths by their filesystem modification time.

tile_axes(arr, target_shape)

Tile an array to at least target_shape and center-crop to exact size.

tree_slice_first(tree)

Slice the first element from each leaf in a PyTree.

unstack_tree(stacked)

Split a stacked PyTree (with leading batch axis) into a list of PyTrees.

vmap_nested(fn, in_axes, *args, **kwargs)

Apply nested jax.vmap() for multiple batch dimensions.

warn_if_duplicate_normalized_keys(...)

Warn when key normalization collapses distinct checkpoint paths.

wrap_like_parametrization(reference, value)

Wrap a value in the same parametrization type as a reference.

zero_pad_to_shape(x, target_shape)

Symmetrically zero-pad an array to target_shape.

ptyrax.utils.abs_sq(x)

Compute the squared absolute value without taking a square root.

More efficient than jnp.abs(x)**2 for complex inputs.

Parameters:

x (Shaped[Array, '...']) – Input array (real or complex).

Returns:

Real-valued squared magnitude.

Return type:

Float[Array, '...']
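
The idea behind avoiding the square root can be sketched in plain NumPy. This is an illustration only, not ptyrax's implementation; the name abs_sq_sketch is hypothetical.

```python
import numpy as np

def abs_sq_sketch(x):
    """Squared magnitude without a square root (hypothetical helper, not ptyrax code)."""
    x = np.asarray(x)
    if np.iscomplexobj(x):
        # |x|^2 = Re(x)^2 + Im(x)^2, which skips the sqrt hidden inside abs().
        return x.real**2 + x.imag**2
    # Real input: the squared magnitude is simply the square.
    return x * x

z = np.array([3 + 4j, 1 - 1j, 2.0])
result = abs_sq_sketch(z)  # real-valued array, same shape as z
```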

ptyrax.utils.adjoint(f)

Compute the conjugate transpose of an array.

Parameters:

f (Shaped[Array, '... m n']) – Input array.

Returns:

Conjugate transpose (swapping last two axes).

Return type:

Shaped[Array, '... m n']

ptyrax.utils.center_crop(arr, target_len, axis)

Crop an array symmetrically along a single axis.

Parameters:
  • arr (ndarray) – Input array.

  • target_len (int) – Desired length along axis.

  • axis (int) – Axis to crop.

Returns:

Center-cropped array.

Return type:

ndarray

ptyrax.utils.center_pad(arr, target_len, axis)

Symmetrically zero-pad an array along a single axis.

Parameters:
  • arr (ndarray) – Input array.

  • target_len (int) – Desired length along axis.

  • axis (int) – Axis to pad.

Returns:

Zero-padded array.

Return type:

ndarray
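
Symmetric cropping and zero-padding along one axis, as described for center_crop and center_pad, might look like the NumPy sketch below. The helper names are hypothetical, and where the extra element goes when the size difference is odd is an assumption; the source does not state ptyrax's convention.

```python
import numpy as np

def center_crop_sketch(arr, target_len, axis):
    """Keep target_len entries around the middle of `axis` (hypothetical helper)."""
    # Assumed convention: for an odd difference the extra element is dropped at the end.
    start = (arr.shape[axis] - target_len) // 2
    index = [slice(None)] * arr.ndim
    index[axis] = slice(start, start + target_len)
    return arr[tuple(index)]

def center_pad_sketch(arr, target_len, axis):
    """Zero-pad `axis` to target_len, splitting the padding across both sides."""
    total = target_len - arr.shape[axis]
    before = total // 2  # assumed convention: any extra zero goes at the end
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (before, total - before)
    return np.pad(arr, pad_width, mode="constant")

a = np.arange(6).reshape(1, 6)
cropped = center_crop_sketch(a, 4, axis=1)
padded = center_pad_sketch(cropped, 6, axis=1)
```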

ptyrax.utils.center_scan_pos(scan_pos)

Center scanning positions by subtracting their mean.

Parameters:

scan_pos (Float[Array, 'N d']) – Array of scanning positions.

Returns:

Mean-subtracted positions.

Return type:

Float[Array, 'N d']

ptyrax.utils.compile_policy_patterns(policy_map)

Compile regex policies once and validate patterns up-front.

Parameters:

policy_map (dict[str, dict[str, Any]] | None)

Return type:

list[tuple[Pattern[str], dict[str, Any]]]
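
Compiling patterns once and validating them up-front can be sketched with the standard re module. Everything here is illustrative: the helper name, the policy contents ({"mode": "pad"}), the use of fullmatch, and the choice of ValueError are assumptions, since the source does not document them.

```python
import re

def compile_policies_sketch(policy_map):
    """Compile each regex key once and fail early on a bad pattern (hypothetical helper)."""
    compiled = []
    for pattern, policy in (policy_map or {}).items():
        try:
            compiled.append((re.compile(pattern), policy))
        except re.error as exc:
            # Assumed error type; the source does not say what ptyrax raises.
            raise ValueError(f"invalid policy pattern {pattern!r}: {exc}") from exc
    return compiled

# The policy dict below is made up purely for illustration.
policies = compile_policies_sketch({r"probe\..*": {"mode": "pad"}})
matched = [policy for regex, policy in policies if regex.fullmatch("probe.data")]
```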

ptyrax.utils.complex_to_rgb(im, log10=False, cmap='hsv', gamma=1.0, clim=None, scale_min=False, max=1.0)

Convert a complex-valued image to an RGB representation.

Encodes magnitude as brightness and phase as hue using an HSV colormap.

Parameters:
  • im (Complex[Array, '... w h']) – Complex input image.

  • log10 (bool) – Apply log10 to magnitude before mapping.

  • cmap (Literal['hsv', 'lab']) – Colormap style (currently only "hsv" is supported).

  • gamma (float) – Gamma exponent applied to magnitude.

  • clim (tuple[float, float]) – Optional (min, max) clipping range for magnitude.

  • scale_min (bool) – Reserved (not used).

  • max (float) – Maximum RGB output value.

Returns:

RGB float array with values in [0, max].

Return type:

Float[Array, '... w h 3']

ptyrax.utils.compute_center_of_mass_shift(img, order=1)

Compute the center-of-mass shift required to center an image.

Parameters:
  • img (Float[Array, '... n m']) – Input image (can be batched); non-negative values are assumed.

  • order (int) – Power to which pixel values are raised when computing the weighted center.

Returns:

Shift vector (dy, dx) to translate the image to the center.

Return type:

Float[Array, '2']
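
A weighted center-of-mass shift for a single 2D image can be sketched as follows. The helper name is hypothetical, and taking the geometric center as (n - 1) / 2 is an assumption; ptyrax may use a different center convention and also supports batched input.

```python
import numpy as np

def com_shift_sketch(img, order=1):
    """Return (dy, dx) moving the weighted centroid to the image center (hypothetical)."""
    weights = img.astype(float) ** order
    ys, xs = np.indices(img.shape)
    total = weights.sum()
    # Intensity-weighted mean row and column.
    com = np.array([(ys * weights).sum(), (xs * weights).sum()]) / total
    center = (np.array(img.shape) - 1) / 2.0  # assumed center convention
    return center - com

img = np.zeros((5, 5))
img[1, 3] = 1.0  # single bright pixel at row 1, column 3
shift = com_shift_sketch(img)
```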

ptyrax.utils.convert_to_ft_sampling(pixel_number, pixel_size, scaling_factor=1.0, prop_dist=None, wavelength=None)

Compute the reciprocal-space (Fourier) sampling parameters.

Given real-space pixel count and size, returns the corresponding Fourier pixel count and size. Optionally computes the scaling from wavelength and propagation distance.

Parameters:
  • pixel_number (Integer[Array, '... d']) – Number of pixels in each dimension.

  • pixel_size (Integer[Array, '... d']) – Pixel size in each dimension.

  • scaling_factor (float) – Manual scaling factor (mutually exclusive with wavelength/prop_dist).

  • prop_dist (float) – Propagation distance (requires wavelength).

  • wavelength (float) – Photon wavelength (requires prop_dist).

Returns:

Tuple of (ft_pixel_number, ft_pixel_size).

Raises:

ValueError – If arguments are inconsistent.

Return type:

tuple[Integer[Array, '... d'], Float[Array, '... d']]
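
The sketch below uses the standard discrete Fourier transform sampling relation, where N pixels of size Δx give a Fourier pixel size of scaling / (N · Δx), with scaling = wavelength · prop_dist for far-field propagation. Whether ptyrax uses exactly this formula, keeps the pixel count unchanged, or validates arguments this way is an assumption; the helper name is hypothetical.

```python
import numpy as np

def ft_sampling_sketch(pixel_number, pixel_size, scaling_factor=1.0,
                       prop_dist=None, wavelength=None):
    """Fourier-plane sampling from the DFT relation (hypothetical, not ptyrax code)."""
    if (prop_dist is None) != (wavelength is None):
        raise ValueError("prop_dist and wavelength must be given together")
    if prop_dist is not None:
        if scaling_factor != 1.0:
            raise ValueError("scaling_factor cannot be combined with wavelength/prop_dist")
        scaling_factor = wavelength * prop_dist  # assumed far-field scaling, lambda * z
    pixel_number = np.asarray(pixel_number)
    # Assumed relation: ft_pixel_size = scaling / (N * pixel_size).
    ft_pixel_size = scaling_factor / (pixel_number * np.asarray(pixel_size, dtype=float))
    return pixel_number, ft_pixel_size

n, d = ft_sampling_sketch([256, 256], [1e-6, 1e-6], prop_dist=0.1, wavelength=1e-9)
```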

ptyrax.utils.count_parameters(pytree)

Count the total number of parameters in a PyTree.

Parameters:

pytree (PyTree) – A PyTree containing parameters.

Returns:

Total number of parameters.

Return type:

int

ptyrax.utils.fft(x, axes=(-2, -1), norm='ortho', fftshift=True)

Compute a centered 2D FFT.

Applies ifftshift before and fftshift after the FFT so that the zero-frequency component is in the center.

Parameters:
  • x (Inexact[Array, '...']) – Input array.

  • axes (tuple[int, int]) – Axes over which to compute the FFT.

  • norm (Literal['backward', 'ortho', 'forward']) – Normalization mode.

  • fftshift (bool) – If True, center the transform.

Returns:

Complex FFT of the input.

Return type:

Complex[Array, '...']

ptyrax.utils.flatten_dict(d, parent_key='', sep='.')

Recursively flatten nested dictionaries into a single-level dict.

Parameters:
  • d (dict) – Dictionary to flatten.

  • parent_key (str) – Prefix for keys at this level (used in recursion).

  • sep (str) – Separator joining nested key parts.

Returns:

Flat dictionary with concatenated keys.

Return type:

dict
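
A recursive flattening with the documented parent_key and sep parameters can be sketched in pure Python. The helper name is hypothetical and details such as handling of non-string keys are assumptions.

```python
def flatten_dict_sketch(d, parent_key="", sep="."):
    """Flatten nested dicts into one level with joined keys (hypothetical helper)."""
    flat = {}
    for key, value in d.items():
        # Prefix the key with the path accumulated so far, if any.
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            flat.update(flatten_dict_sketch(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat

nested = {"probe": {"data": 1, "modes": {"count": 3}}, "lr": 0.1}
flat = flatten_dict_sketch(nested)
```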

ptyrax.utils.hsv_to_rgb(h, s, v)

Convert HSV color values to RGB.

Parameters:
  • h (Float[Array, '...']) – Hue channel in [0, 1].

  • s (Float[Array, '...']) – Saturation channel in [0, 1].

  • v (Float[Array, '...']) – Value channel in [0, 1].

Returns:

RGB array with last dimension of size 3.

Return type:

Float[Array, '... 3']
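
A vectorized HSV-to-RGB conversion for channels in [0, 1] can be written in closed form and checked against the standard library's colorsys. This is a sketch with a hypothetical helper name, not the ptyrax implementation.

```python
import colorsys
import numpy as np

def hsv_to_rgb_sketch(h, s, v):
    """Vectorized HSV -> RGB for channels in [0, 1] (hypothetical helper)."""
    h, s, v = np.broadcast_arrays(
        np.asarray(h, dtype=float), np.asarray(s, dtype=float), np.asarray(v, dtype=float)
    )

    def channel(n):
        # Piecewise-linear ramp per channel: f(n) = v - v*s*clip(min(k, 4 - k), 0, 1).
        k = (n + 6.0 * h) % 6.0
        return v - v * s * np.clip(np.minimum(k, 4.0 - k), 0.0, 1.0)

    # Red, green and blue use offsets 5, 3 and 1 respectively.
    return np.stack([channel(5.0), channel(3.0), channel(1.0)], axis=-1)

h = np.array([0.0, 1.0 / 3.0, 0.6])
rgb = hsv_to_rgb_sketch(h, 0.5, 0.8)
reference = np.array([colorsys.hsv_to_rgb(hi, 0.5, 0.8) for hi in h])
```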

ptyrax.utils.identity(*args)

Identity function that returns its arguments unchanged.

Useful as a default no-op callback in gin-configurable pipelines.

Parameters:

*args (Any) – Any arguments.

Returns:

The input arguments as a tuple.

Return type:

Any

ptyrax.utils.ifft(x, axes=(-2, -1), norm='ortho', fftshift=True)

Compute a centered 2D inverse FFT.

Applies ifftshift before and fftshift after the IFFT, so that a frequency-domain input with its zero-frequency component at the center maps to an output that is also centered.

Parameters:
  • x (Inexact[Array, '...']) – Input array (frequency domain).

  • axes (tuple) – Axes over which to compute the IFFT.

  • norm (Literal['backward', 'ortho', 'forward']) – Normalization mode.

  • fftshift (bool) – If True, center the transform.

Returns:

Complex inverse FFT of the input.

Return type:

Complex[Array, '...']
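
The centering convention described for fft and ifft (ifftshift before, fftshift after) can be sketched with numpy.fft for the fftshift=True case. The helper names are hypothetical and the sketch fixes the axes to the last two dimensions.

```python
import numpy as np

AXES = (-2, -1)

def centered_fft2_sketch(x, norm="ortho"):
    """2D FFT with the origin at the array center (hypothetical helper)."""
    shifted = np.fft.ifftshift(x, axes=AXES)       # move the centered origin to index 0
    spectrum = np.fft.fft2(shifted, axes=AXES, norm=norm)
    return np.fft.fftshift(spectrum, axes=AXES)    # put zero frequency back in the center

def centered_ifft2_sketch(x, norm="ortho"):
    """Inverse of centered_fft2_sketch, using the same shift pattern."""
    shifted = np.fft.ifftshift(x, axes=AXES)
    field = np.fft.ifft2(shifted, axes=AXES, norm=norm)
    return np.fft.fftshift(field, axes=AXES)

delta = np.zeros((8, 8))
delta[4, 4] = 1.0  # impulse at the centered origin (index n // 2)
spectrum = centered_fft2_sketch(delta)    # flat spectrum with zero phase
roundtrip = centered_ifft2_sketch(spectrum)
```

With the impulse placed at the centered origin, the spectrum is constant and real; without the initial ifftshift it would carry an alternating-sign phase pattern.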

ptyrax.utils.join_hdf5_paths(prefix, suffix)

Join two HDF5-like paths while keeping a single separator.

Parameters:
  • prefix (str | None)

  • suffix (str)

Return type:

str

ptyrax.utils.load_hdf5(file_path, key_translation=None)

Load the data from an HDF5 file at file_path and return it as a dictionary.

The data_type = all, params, ptychogram

Parameters:
  • file_path (str)

  • key_translation (dict)

Return type:

dict

ptyrax.utils.make_length_n(parameter, n=2)

Convert a single parameter value or a list of parameter values into an array of length n.

Parameters:
  • parameter (float | int | ndarray | list | tuple)

  • n (int)

Return type:

array
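
One plausible reading of this behavior is sketched below in NumPy: a scalar is broadcast to n entries and a sequence must already have length n. The helper name and the ValueError on a length mismatch are assumptions; the source does not say how mismatches are handled.

```python
import numpy as np

def make_length_n_sketch(parameter, n=2):
    """Broadcast a scalar, or validate a sequence, to a 1D array of length n (hypothetical)."""
    values = np.atleast_1d(np.asarray(parameter))
    if values.size == 1:
        # A single value is repeated n times.
        return np.full(n, values.item())
    if values.shape != (n,):
        # Assumed behavior: reject sequences of the wrong length.
        raise ValueError(f"expected a scalar or {n} values, got shape {values.shape}")
    return values

pixel_size = make_length_n_sketch(0.5)
shape = make_length_n_sketch([128, 256])
```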

ptyrax.utils.make_or_reuse_axes(fig=None, gs=None)

Create or reuse matplotlib Figure and Axes.

Parameters:
  • fig (Figure) – Existing figure (required if gs is provided).

  • gs (SubplotSpec) – GridSpec subplot to add an axes to.

Returns:

Tuple of (fig, ax).

Raises:

ValueError – If gs is provided without fig.

Return type:

tuple[Figure, Axes]

ptyrax.utils.make_path_string(path)

Convert a JAX key-path list to a dot-separated string.

Parameters:

path (list[str]) – List of JAX tree path elements.

Returns:

Dot-joined path string.

Return type:

str

ptyrax.utils.median_pixel_value(img)

Compute the median pixel value of an image.

Parameters:

img (Float[Array, '...'])

Return type:

float

ptyrax.utils.normalize(a, axes=None)

Normalize an array by its maximum value along specified axes.

Parameters:
  • a (Shaped[Array, '...']) – Input array.

  • axes (tuple[int, ...]) – Axes over which to take the maximum.

Returns:

Array divided by its maximum.

Return type:

Shaped[Array, '...']

ptyrax.utils.normalize_hdf5_path(path)

Normalize an HDF5-like path by collapsing duplicate separators.

Parameters:

path (str | None)

Return type:

str
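
Collapsing duplicate separators, and joining two path pieces with a single separator as described for join_hdf5_paths, can be sketched with the standard library. The helper names are hypothetical; mapping None to an empty string and preserving a leading "/" are assumptions about edge cases the source does not specify.

```python
import re

def normalize_hdf5_path_sketch(path):
    """Collapse runs of '/' into one (hypothetical helper; None is assumed to map to '')."""
    if not path:
        return ""
    return re.sub(r"/+", "/", path)

def join_hdf5_paths_sketch(prefix, suffix):
    """Join two path pieces with exactly one '/' between them (hypothetical helper)."""
    if not prefix:
        return normalize_hdf5_path_sketch(suffix)
    # Concatenate with a separator, then let normalization remove any duplicates.
    return normalize_hdf5_path_sketch(f"{prefix}/{suffix}")

joined = join_hdf5_paths_sketch("/state/", "/probe//data")
```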

ptyrax.utils.normalize_power(a, axis=(0, 1))

Normalize an array so its total power equals 1.

The normalization enforces \(\sum |a|^2 = 1\) over axis.

Parameters:
  • a (Shaped[Array, '...']) – Input array.

  • axis (tuple[int, ...]) – Axes over which to compute the power.

Returns:

Power-normalized array.

Return type:

Shaped[Array, '...']
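
The constraint \(\sum |a|^2 = 1\) follows from dividing by the square root of the total power. A NumPy sketch, with a hypothetical helper name and no handling of all-zero input, is:

```python
import numpy as np

def normalize_power_sketch(a, axis=(0, 1)):
    """Scale `a` so that sum(|a|^2) over `axis` equals 1 (hypothetical helper)."""
    power = np.sum(np.abs(a) ** 2, axis=axis, keepdims=True)
    # Dividing the amplitude by sqrt(power) divides the power by `power`.
    return a / np.sqrt(power)

rng = np.random.default_rng(0)
probe = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
unit = normalize_power_sketch(probe)
total_power = np.sum(np.abs(unit) ** 2)
```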

ptyrax.utils.orthogonal(f)

Orthogonalize rows of a matrix using SVD.

Parameters:

f (Shaped[Array, '... d mn']) – Input matrix or batch of matrices.

Returns:

Matrix with orthogonalized rows.

Return type:

Shaped[Array, '... d mn']
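
One common SVD-based way to orthogonalize rows is sketched below: with f = U S Vᴴ, the matrix S Vᴴ has mutually orthogonal rows and the same total power as f. Whether ptyrax keeps the singular values as row weights, or orders the rows this way, is an assumption; the helper name is hypothetical.

```python
import numpy as np

def orthogonal_rows_sketch(f):
    """Replace the rows of f with mutually orthogonal rows via SVD (hypothetical helper)."""
    # f = U @ diag(s) @ Vh, so U^H @ f = diag(s) @ Vh has orthogonal rows.
    _, s, vh = np.linalg.svd(f, full_matrices=False)
    return s[..., :, None] * vh

rng = np.random.default_rng(1)
modes = rng.normal(size=(3, 16)) + 1j * rng.normal(size=(3, 16))
ortho = orthogonal_rows_sketch(modes)
gram = ortho @ ortho.conj().T  # diagonal when the rows are orthogonal
```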

ptyrax.utils.orthogonalize(f)

Orthogonalize probe modes in a multi-mode array.

Parameters:

f (Shaped[Array, '... m n']) – Array of shape (modes, m, n) representing probe modes.

Returns:

Orthogonalized probe modes with same shape.

Raises:
  • NotImplementedError – If input has additional wavelength dimensions.

  • ValueError – If no mode dimension is present.

Return type:

Shaped[Array, '... m n']

ptyrax.utils.parallel_orthogonal(f)

Vectorized version of orthogonal. Takes similar arguments as orthogonal but with additional array axes over which orthogonal is mapped.

Original documentation:

Orthogonalize rows of a matrix using SVD.

Parameters:

f (Shaped[Array, '... d mn']) – Input matrix or batch of matrices.

Returns:

Matrix with orthogonalized rows.

Return type:

Shaped[Array, '... d mn']

ptyrax.utils.phase_only_exp(phase)

Compute a unit-magnitude complex exponential from a real phase.

Equivalent to exp(1j * phase) but avoids complex input.

Parameters:

phase (Float[Array, '... m n']) – Real-valued phase array in radians.

Returns:

Complex array with unit magnitude and given phase.

Raises:

ValueError – If phase is already complex.

Return type:

Complex[Array, '... m n']

ptyrax.utils.phase_only_exp_np(phase)

NumPy version of phase_only_exp().

Parameters:

phase (Float[Array, '... m n']) – Real-valued phase array in radians.

Returns:

Complex NumPy array with unit magnitude.

Raises:

ValueError – If phase is already complex.

Return type:

Complex[Array, '... m n']
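
A unit-magnitude exponential can be built from a real phase with Euler's formula, which keeps the input real throughout. The NumPy sketch below uses a hypothetical helper name; the exact error message is an assumption.

```python
import numpy as np

def phase_only_exp_sketch(phase):
    """exp(1j * phase) assembled from cos and sin (hypothetical helper)."""
    phase = np.asarray(phase)
    if np.iscomplexobj(phase):
        raise ValueError("phase must be real-valued")
    # Euler's formula: exp(i * phase) = cos(phase) + i * sin(phase).
    return np.cos(phase) + 1j * np.sin(phase)

phase = np.linspace(-np.pi, np.pi, 5)
field = phase_only_exp_sketch(phase)
```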

ptyrax.utils.plot(im, show=False, gs=None, fig=None, dpi=150, plot_text=False, **kwargs)

Plot an image array, dispatching to complex or real plotting.

Supports complex-valued images (phase-magnitude colormap), real-valued images, and objects with a __plot__ method.

Parameters:
  • im (Array | ndarray | bool | number | bool | int | float | complex | LiteralArray) – Image array to plot.

  • show (bool) – If True, call plt.show().

  • gs (GridSpec | None) – Optional GridSpec for subplot placement.

  • fig (Figure | None) – Optional existing Figure.

  • dpi (int | None) – Figure DPI.

  • plot_text (bool) – If True, annotate the figure with parameter info.

  • **kwargs – Extra arguments forwarded to the plotting backend.

Returns:

Tuple of (fig, gs, image_artists).

Return type:

tuple[Figure, SubplotSpec, list[AxesImage]]

ptyrax.utils.plot_complex(im, fig, gs, log10=False, cmap='hsv', gamma=1, title=None, clim=None, cbar=None, **kwargs)

Plot a complex-valued image by converting it to RGB using a phase-magnitude colormap.

Parameters:
  • im (Complex[Array, '...']) – Complex image.

  • fig (Figure) – Matplotlib Figure to draw on.

  • gs (SubplotSpec) – GridSpec/SubplotSpec to place the image.

  • log10 (bool) – Whether to take log10 of magnitude.

  • cmap (str)

  • gamma (float)

  • title (str | None)

  • clim (tuple[float, float] | None)

  • cbar (bool | None)

Returns:

Tuple of (fig, gs, image artists).

Return type:

tuple[Figure, SubplotSpec, list[AxesImage]]

ptyrax.utils.plot_real(im, fig, gs, title=None, cmap=None, gamma=1.0, cbar=False, log10=False, epsilon=1e-10, **kwargs)

Plot a real-valued image with optional gamma correction and log scaling.

Parameters:
  • im (Float[Array, '...']) – Real-valued image.

  • fig (Figure) – Matplotlib Figure to draw on.

  • gs (SubplotSpec) – GridSpec/SubplotSpec to place the image.

  • gamma (float) – Gamma exponent to apply to the image.

  • log10 (bool) – Whether to apply log10 scaling (uses epsilon to avoid log(0)).

  • title (str)

  • cmap (str | None)

  • cbar (bool)

  • epsilon (float)

Returns:

Tuple of (fig, gs, image artists).

Return type:

tuple[Figure, SubplotSpec, list[AxesImage]]

ptyrax.utils.real_to_rgb(tensor, log10=False, cmap='magma', gamma=1.0, **kwargs)

Map a real-valued array to an RGB image via a matplotlib colormap.

Parameters:
  • tensor (Float[Array, '...']) – Real input array.

  • log10 (bool) – Apply log10 scaling.

  • cmap (str) – Matplotlib colormap name.

  • gamma (float) – Gamma exponent applied to normalized amplitude.

Returns:

RGB float array with shape (*tensor.shape, 3).

Return type:

Float[Array, '... 3']

ptyrax.utils.reduced_grad(fn, reduction_fn=<function <lambda>>)

Create a gradient function with a scalar reduction applied first.

Parameters:
  • fn (Callable) – Function whose output will be reduced then differentiated.

  • reduction_fn (Callable) – Scalar reduction (default: sum of absolute values).

Returns:

Gradient function of the reduced output.

Return type:

Callable

ptyrax.utils.repeat_axis(arr, target_len, axis)

Repeat elements along an axis and center-crop to target_len.

Parameters:
  • arr (ndarray) – Input array.

  • target_len (int) – Desired length along axis.

  • axis (int) – Axis to expand.

Returns:

Repeated and center-cropped array.

Return type:

ndarray
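
Repeating elements and then center-cropping might be implemented as in the NumPy sketch below. The helper name is hypothetical, and both the ceiling-division repeat count and the crop offset for odd differences are assumptions.

```python
import numpy as np

def repeat_axis_sketch(arr, target_len, axis):
    """Repeat each element along `axis`, then center-crop to target_len (hypothetical)."""
    current = arr.shape[axis]
    repeats = -(-target_len // current)  # ceiling division: smallest count that reaches target_len
    expanded = np.repeat(arr, repeats, axis=axis)
    start = (expanded.shape[axis] - target_len) // 2
    index = [slice(None)] * arr.ndim
    index[axis] = slice(start, start + target_len)
    return expanded[tuple(index)]

a = np.array([1, 2, 3])
stretched = repeat_axis_sketch(a, 7, axis=0)
```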

ptyrax.utils.resize_to_match(arr, target_shape, axis_policies=None, default_policy='pad')

Resize arr to target_shape:

  • Contracting axes: crop symmetrically.

  • Expanding axes: apply the per-axis policy or default_policy.

Parameters:
  • arr (ndarray)

  • target_shape (Tuple[int, ...])

  • axis_policies (Dict[int, str])

  • default_policy (str)

Return type:

ndarray

ptyrax.utils.save_hdf5(file_path, data)

Save the data (a dictionary) to an HDF5 file at file_path.

Parameters:
  • file_path (str)

  • data (dict)

Return type:

None

ptyrax.utils.scaled_mean(a, scale=1.0, **kwargs)

Compute the mean of an array scaled by a constant factor.

Typically configured via gin as a loss reduction function.

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | LiteralArray) – Input array.

  • scale (float) – Multiplicative scaling applied after mean reduction.

  • **kwargs – Extra arguments forwarded to jax.numpy.mean().

Returns:

Scalar mean value times scale.

Return type:

Float[Array, '']

ptyrax.utils.set_probe_data_preserve_parametrization(model, new_probe_data)

Update model probe data while preserving the parametrization wrapper.

Uses eqx.tree_at() to replace model.illumination.probe.data with new_probe_data wrapped in the original parametrization type.

Parameters:
  • model (Any) – Model whose probe data should be updated.

  • new_probe_data (Any) – New data to set.

Returns:

Updated model with new probe data.

Return type:

Any

ptyrax.utils.shift_image(img, shift, order=1)

Shift a 2D image using scipy.ndimage.map_coordinates with zero padding.

Parameters:
  • img (Array) – 2D NumPy array.

  • shift (Array) – Tuple or array (dy, dx) giving the shift to apply.

  • order (int) – Interpolation order (1 = bilinear, 3 = cubic, etc.).

Returns:

Shifted image with same shape as input, zero-padded at edges.

Return type:

Array

ptyrax.utils.single_identity(args)

Return a single-element tuple unchanged.

Convenience identity function for callbacks that receive a tuple.

Parameters:

args (tuple[Any, ...]) – Input tuple.

Returns:

The same tuple.

Return type:

Any

ptyrax.utils.slice_at_center_to_shape(x, center, target_shape)

Extract a centered slice of x with the given target_shape around center.

Parameters:
  • x (Float[Array, '... n m']) – Array from which to slice. Last two dimensions are spatial.

  • center (Float[Array, '2']) – Center coordinate (y, x) in the same coordinate system as x.

  • target_shape (tuple[int, int]) – Desired output spatial shape (height, width).

Returns:

Dynamically sliced view of x with spatial dims equal to target_shape.

Return type:

Float[Array, '... n m']

ptyrax.utils.soft_clip(a, min_value, max_value, relu_like=<PjitFunction of <function identity>>, scale=1.0)

Differentiable soft-clipping function.

Applies a smooth clamping operation using a ReLU-like activation to keep values within [min_value, max_value].

Parameters:
  • a (Shaped[Array, '...']) – Input array.

  • min_value (float | int) – Lower clipping bound.

  • max_value (float | int) – Upper clipping bound.

  • relu_like (Callable) – Activation function for soft clamping.

  • scale (float) – Scaling factor applied before and after clipping.

Returns:

Clipped array.

Return type:

Shaped[Array, '...']

ptyrax.utils.sort_images_by_time(image_paths)

Sort image file paths by their filesystem modification time.

Parameters:

image_paths (list[str]) – List of file paths to sort.

Returns:

Paths sorted by ascending modification time.

Return type:

list[str]
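
Sorting by modification time can be sketched with os.path.getmtime as the sort key. The helper name is hypothetical, and the example sets modification times explicitly so the order is deterministic.

```python
import os
import tempfile

def sort_images_by_time_sketch(image_paths):
    """Return paths ordered by ascending modification time (hypothetical helper)."""
    return sorted(image_paths, key=os.path.getmtime)

with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for name, mtime in [("b.png", 300), ("a.png", 100), ("c.png", 200)]:
        path = os.path.join(tmp, name)
        open(path, "wb").close()         # create an empty placeholder file
        os.utime(path, (mtime, mtime))   # set access and modification times explicitly
        paths.append(path)
    ordered = [os.path.basename(p) for p in sort_images_by_time_sketch(paths)]
```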

ptyrax.utils.tile_axes(arr, target_shape)

Tile an array to at least target_shape and center-crop to exact size.

Parameters:
  • arr (ndarray) – Input array to tile.

  • target_shape (Tuple[int, ...]) – Desired output shape.

Returns:

Array tiled and cropped to target_shape.

Return type:

ndarray

ptyrax.utils.tree_slice_first(tree)

Slice the first element from each leaf in a PyTree.

Parameters:

tree (PyTree) – A PyTree containing arrays.

Returns:

A new PyTree with the first element sliced from each leaf.

Return type:

PyTree

ptyrax.utils.unstack_tree(stacked)

Split a stacked PyTree (with leading batch axis) into a list of PyTrees.

Parameters:

stacked (PyTree) – PyTree where every leaf has a leading batch dimension.

Returns:

List of PyTrees, one per element along axis 0.

Return type:

list[PyTree]
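
The unstacking idea can be illustrated without JAX by letting a flat dict of NumPy arrays stand in for a general PyTree. This simplification, and the helper name, are assumptions made for the sketch; the real function operates on arbitrary PyTrees.

```python
import numpy as np

def unstack_dict_sketch(stacked):
    """Split a flat dict of stacked arrays into a list of per-item dicts (hypothetical)."""
    # Every leaf is assumed to share the same leading batch length.
    batch = len(next(iter(stacked.values())))
    return [{name: leaf[i] for name, leaf in stacked.items()} for i in range(batch)]

stacked = {"probe": np.arange(6).reshape(3, 2), "loss": np.array([0.3, 0.2, 0.1])}
items = unstack_dict_sketch(stacked)
first = items[0]  # same result as slicing the first element from each leaf
```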

ptyrax.utils.vmap_nested(fn, in_axes, *args, **kwargs)

Apply nested jax.vmap() for multiple batch dimensions.

Parameters:
  • fn (Callable) – Function to vectorize.

  • in_axes (tuple[int, ...]) – Tuple of axes, one per nesting level.

Returns:

Multi-level vmapped function.

Return type:

Callable

ptyrax.utils.warn_if_duplicate_normalized_keys(state_keys, normalize_keys)

Warn when key normalization collapses distinct checkpoint paths.

If normalize_keys is enabled and stripping leading underscores from path segments causes two different keys to map to the same normalized key, a warning is emitted.

Parameters:
  • state_keys (list[str]) – List of HDF5 state keys.

  • normalize_keys (bool) – Whether normalization is active.

Return type:

None
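
The collision check described above can be sketched with the standard library: normalize every key, group the originals by normalized form, and warn for any group with more than one distinct key. The helper names, the "/" separator, and the warning text are assumptions.

```python
import warnings
from collections import defaultdict

def normalize_key_sketch(key):
    """Strip leading underscores from every '/'-separated segment (assumed separator)."""
    return "/".join(segment.lstrip("_") for segment in key.split("/"))

def warn_if_duplicates_sketch(state_keys, normalize_keys):
    """Warn when distinct keys collapse to the same normalized key (hypothetical helper)."""
    if not normalize_keys:
        return
    groups = defaultdict(list)
    for key in state_keys:
        groups[normalize_key_sketch(key)].append(key)
    for normalized, originals in groups.items():
        if len(set(originals)) > 1:
            warnings.warn(f"keys {sorted(set(originals))} all normalize to {normalized!r}")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_duplicates_sketch(["model/_probe/data", "model/probe/data", "model/object"], True)
```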

ptyrax.utils.wrap_like_parametrization(reference, value)

Wrap a value in the same parametrization type as a reference.

If reference is an ArrayParametrization, wraps value in the same subclass; otherwise returns value unchanged.

Parameters:
  • reference (Any) – Object whose type is inspected.

  • value (Any) – Value to potentially wrap.

Returns:

Wrapped or unwrapped value.

Return type:

Any

ptyrax.utils.zero_pad_to_shape(x, target_shape)

Symmetrically zero-pad an array to target_shape.

Parameters:
  • x (Float[Array, '... n m']) – Input array (last two dims are spatial).

  • target_shape (tuple[int, int]) – Desired spatial dimensions.

Returns:

Zero-padded array with spatial dimensions equal to target_shape.

Return type:

Float[Array, '... n m']