ptyrax.parametrizations

Functions

as_direct_array_parametrization(data)

Wrap an array in a DirectArrayParametrization if needed.

mean_l2_norm(a[, axes])

Compute the mean L2 norm of an array along the specified axes.

phase_only_exp(phase)

Convert a real-valued phase array to a complex unit-magnitude array.

phase_only_exp_np(phase)

Convert a real-valued phase array to a complex unit-magnitude array using NumPy.

resolve_array_parametrizations(module)

Recursively resolve all ArrayParametrization instances in a module.

resolve_index_dependent_parameters(module, index)

Recursively resolve all index-dependent parameters at a specific index.

resolve_index_dependent_parameters_all(module)

Recursively resolve all index-dependent parameters, returning all indices.

resolve_parametrizations(module[, index])

Recursively resolve all parametrizations in a module to their underlying arrays.

Classes

ArrayParametrization(output_shape)

Abstract base class for array parametrizations.

BinaryArrayParametrization(data_or_initializer)

Parametrization that produces a binary array via thresholding.

BoundSliceParameter(parameter, n)

A resolved index-dependent parameter bound to a single index.

DirectArrayParametrization(data_or_initializer)

Identity parametrization that wraps an array without any transformation.

IndexDependentParameter([_index])

Abstract base class for parameters that vary with the dataset index.

IndexSliceParameter(parameters[, dim])

Index-dependent parameter that selects a slice along a given dimension.

NormalizedArrayParametrization(initial_data)

Parametrization that normalizes an array by a fixed scale factor.

NormalizedReferencedArrayParametrization(...)

Parametrization that stores a trainable offset relative to a fixed reference.

OuterProductArrayParametrization(...[, ...])

Parametrization that represents a 2D array as a sum of outer products.

PhaseOnlyArrayParametrization(output_shape)

Parametrization that constrains an array to unit magnitude (phase-only).

WaveletArrayParametrization(data_or_initializer)

Parametrization that stores an array in the wavelet domain.

class ptyrax.parametrizations.ArrayParametrization(output_shape)

Bases: Module

Abstract base class for array parametrizations.

An array parametrization wraps trainable parameters and produces an output array via its __call__ method. Subclasses implement specific constraints (e.g., phase-only, normalized, outer-product decomposition) while exposing a uniform interface that can be resolved transparently using resolve_array_parametrizations().

Variables:

output_shape – The shape of the array produced by __call__.

Parameters:

output_shape (tuple)

Example

>>> parametrized_model = build_model(config)
>>> resolved = resolve_parametrizations(parametrized_model, index=0)
>>> output = resolved(inputs)

abstractmethod __call__(*args, **kwargs)

Compute and return the parametrized output array.

Return type:

Array

output_shape: tuple

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.
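The uniform interface described above can be sketched in plain Python with NumPy. This is a hedged illustration, not the ptyrax implementation: the real class derives from Module and works with JAX arrays, and the names ArrayParametrizationSketch and ScaledOnes are hypothetical.

```python
from abc import ABC, abstractmethod

import numpy as np


class ArrayParametrizationSketch(ABC):
    """Hypothetical stand-in for ArrayParametrization (not the ptyrax class)."""

    def __init__(self, output_shape: tuple):
        self.output_shape = tuple(output_shape)

    @property
    def shape(self) -> tuple:
        # Mirrors the documented `shape` property: the shape of __call__'s output.
        return self.output_shape

    @abstractmethod
    def __call__(self) -> np.ndarray:
        """Produce the constrained output array from the stored parameters."""


class ScaledOnes(ArrayParametrizationSketch):
    """Toy subclass: one trainable scalar broadcast to `output_shape`."""

    def __init__(self, output_shape: tuple, value: float = 1.0):
        super().__init__(output_shape)
        self.value = value

    def __call__(self) -> np.ndarray:
        return np.full(self.output_shape, self.value)


p = ScaledOnes((4, 3), value=2.0)
out = p()
```

Callers only ever invoke the object and read `shape`, which is what lets a resolver swap every parametrization for its output array without knowing the concrete subclass.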

class ptyrax.parametrizations.BinaryArrayParametrization(data_or_initializer, threshold=0.5)

Bases: ArrayParametrization

Parametrization that produces a binary array via thresholding.

Parameters:
  • data_or_initializer (Callable[[tuple], Shaped])

  • threshold (float)

__call__()

Return a binary array obtained by thresholding the stored data.

Returns:

Boolean array that is True where data > threshold and False otherwise.

Return type:

Bool[Array, '...']
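The thresholding rule can be illustrated with NumPy. This is a minimal sketch of the documented comparison only; binary_threshold is a hypothetical helper, and the library applies the rule to JAX arrays inside __call__.

```python
import numpy as np


def binary_threshold(data: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Hypothetical helper: True where data > threshold, False elsewhere."""
    return data > threshold


stored = np.array([[0.1, 0.5], [0.7, 0.9]])
mask = binary_threshold(stored)
```

Because the comparison is strict, a value exactly equal to the threshold (0.5 here) maps to False.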

output_shape: tuple

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

threshold: float

class ptyrax.parametrizations.BoundSliceParameter(parameter, n)

Bases: Module

A resolved index-dependent parameter bound to a single index.

Created by resolve_index_dependent_parameters() when an IndexDependentParameter is resolved at a specific index. Use at_current_index() to access the bound value.

Variables:
  • parameter – The resolved parameter value for the bound index.

  • n – The total number of indices the original parameter supported.

Parameters:
  • parameter (PyTree)

  • n (int)
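A bound parameter is essentially a value that has already been selected for one index, plus the size of the index range it came from. The sketch below assumes exactly that; BoundSliceSketch is a hypothetical stand-in, not the ptyrax class, and uses a frozen dataclass instead of a Module.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class BoundSliceSketch(Generic[T]):
    """Hypothetical stand-in for BoundSliceParameter."""

    parameter: T  # value already selected for the bound index
    n: int        # number of indices the original parameter supported

    def at_current_index(self) -> T:
        # Resolution already picked the index, so just hand the value back.
        return self.parameter


positions = [(0.0, 0.0), (1.0, 0.5), (2.0, 1.0)]
bound = BoundSliceSketch(parameter=positions[1], n=len(positions))
```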

property all: T

at_current_index()

Return the parameter value for the bound index.

Return type:

T

at_index(*args)

Return type:

T

n: int

parameter: PyTree

class ptyrax.parametrizations.DirectArrayParametrization(data_or_initializer)

Bases: ArrayParametrization

Identity parametrization that wraps an array without any transformation.

This is a no-op parametrization: calling it simply returns the stored data unchanged. It is useful for providing a uniform ArrayParametrization interface around arrays that require no constraints.

Variables:

_data – The stored array.

Parameters:

data_or_initializer (Callable[[tuple], Shaped])

Return type:

Shaped[Array, '...']

Example

>>> param = DirectArrayParametrization(jnp.ones((64, 64)))
>>> assert jnp.array_equal(param(), jnp.ones((64, 64)))

__call__()

Return the stored array unchanged.

Returns:

The wrapped array.

Return type:

Float[Array, '...']

property output_shape: tuple[int, ...]

The shape of the stored array.

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

class ptyrax.parametrizations.IndexDependentParameter(_index=None)

Bases: Module, Generic[T]

Abstract base class for parameters that vary with the dataset index.

In ptychography, certain model parameters (e.g., scan positions, per-frame aberrations) differ for each diffraction pattern in the dataset. This class provides a uniform interface to access index-specific values either explicitly via at_index() or implicitly after resolution via at_current_index().

Subclasses must implement at_index(), n, and all.

Variables:

_index – The currently bound index, or None if not yet resolved.

Parameters:

_index (int | None)
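The contract above (subclasses supply at_index(), n and all; at() is an alias) can be sketched with Python's abc module. This is a hedged illustration: IndexDependentSketch and ListParameter are hypothetical, and the way at_current_index() reads the bound `_index` here is an assumption, not the documented ptyrax implementation.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class IndexDependentSketch(ABC, Generic[T]):
    """Hypothetical stand-in for IndexDependentParameter."""

    def __init__(self, _index: Optional[int] = None):
        self._index = _index  # None until bound to a dataset index

    @abstractmethod
    def at_index(self, index: int) -> T: ...

    @property
    @abstractmethod
    def n(self) -> int: ...

    @property
    @abstractmethod
    def all(self) -> List[T]: ...

    def at(self, index: int) -> T:
        # Documented as an alias for at_index().
        return self.at_index(index)

    def at_current_index(self) -> T:
        # Assumption: look up the value at the bound index.
        if self._index is None:
            raise RuntimeError("parameter has not been bound to an index")
        return self.at_index(self._index)


class ListParameter(IndexDependentSketch[float]):
    """Toy subclass: one float per dataset index (e.g. a per-frame defocus)."""

    def __init__(self, values, _index=None):
        super().__init__(_index)
        self.values = list(values)

    def at_index(self, index: int) -> float:
        return self.values[index]

    @property
    def n(self) -> int:
        return len(self.values)

    @property
    def all(self) -> List[float]:
        return list(self.values)


defocus = ListParameter([10.0, 12.0, 11.5])
bound = ListParameter(defocus.values, _index=2)
```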

abstract property all: T

Return all parameter values across all indices.

at(index)

Alias for at_index().

Parameters:

index (int)

Return type:

T

at_current_index()

Return type:

T

abstractmethod at_index(index)

Return the parameter value at a specific dataset index.

Parameters:

index (int) – The dataset index to retrieve.

Returns:

The parameter value for the given index.

Return type:

T

abstract property n: int

The number of distinct index values this parameter supports.

class ptyrax.parametrizations.IndexSliceParameter(parameters, dim=0)

Bases: IndexDependentParameter[T]

Index-dependent parameter that selects a slice along a given dimension.

Stores a pytree of arrays where one dimension corresponds to the dataset index. Calling at_index() slices along that dimension to extract the parameter for a single index.

Variables:
  • parameters – The full pytree of parameters (with an index dimension).

  • slice_dim – Static pytree matching parameters indicating which dimension to slice for each leaf.

Parameters:
  • parameters (PyTree)

  • dim (int)

Example

>>> positions = jnp.zeros((100, 2))  # 100 scan positions, 2D
>>> param = IndexSliceParameter(positions, dim=0)
>>> param.at_index(5)  # shape (2,)

property all: Shaped[Array, 'indices ...']

The full parameter pytree including the index dimension.

at(index)

Alias for at_index().

Parameters:

index (int)

Return type:

T

at_current_index()

Return type:

T

at_index(index)

Slice the parameters at the given dataset index.

Parameters:

index (int) – Dataset index to select.

Returns:

The parameter pytree with the index dimension removed.

Return type:

Shaped[Array, '...']
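Slicing a pytree at one index can be sketched with `np.take`, which removes the sliced axis when given an integer index. This is a simplified, hypothetical helper (slice_pytree): it walks only dicts, lists and tuples and uses a single `dim` for every leaf, whereas IndexSliceParameter keeps a per-leaf slice_dim and operates on JAX pytrees.

```python
import numpy as np


def slice_pytree(tree, index, dim=0):
    """Hypothetical helper: take `index` along axis `dim` of every array leaf."""
    if isinstance(tree, dict):
        return {k: slice_pytree(v, index, dim) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(slice_pytree(v, index, dim) for v in tree)
    # An integer index drops the sliced axis, e.g. (100, 2) -> (2,).
    return np.take(tree, index, axis=dim)


params = {
    "positions": np.arange(200.0).reshape(100, 2),  # 100 scan positions, 2D
    "defocus": np.linspace(0.0, 1.0, 100),          # one scalar per position
}
frame5 = slice_pytree(params, 5)
```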

property n: int

The number of indices (size of the slicing dimension).

parameters: PyTree

slice_dim: PyTree

class ptyrax.parametrizations.NormalizedArrayParametrization(initial_data, scale=1.0, scaling_function=None)

Bases: ArrayParametrization

Parametrization that normalizes an array by a fixed scale factor.

Stores the data divided by a scale factor and multiplies it by that factor on evaluation. When the scale is close to the data's typical magnitude, this keeps the internal parameters near unit scale during optimization, which can improve gradient conditioning.

The scale factor is computed either from a provided scaling_function applied to the initial data, or from the explicit scale argument.
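The store-divided, return-multiplied round trip can be sketched in NumPy. NormalizedSketch is hypothetical, and the rule that a scaling_function takes precedence over an explicit scale is an assumption made for this illustration; the text above does not state which wins when both are given.

```python
import numpy as np


class NormalizedSketch:
    """Hypothetical stand-in for NormalizedArrayParametrization."""

    def __init__(self, initial_data, scale=1.0, scaling_function=None):
        # Assumption: a scaling_function, when given, overrides `scale`.
        if scaling_function is not None:
            scale = scaling_function(initial_data)
        self._scale = scale
        self._data = initial_data / scale  # what an optimizer would update

    def __call__(self):
        return self._data * self._scale


probe = np.array([2000.0, -500.0, 1000.0])
param = NormalizedSketch(probe, scale=1e3)
by_max = NormalizedSketch(probe, scaling_function=lambda a: np.max(np.abs(a)))
```

Since the output is `_data * _scale`, a change of size δ in `_data` moves the output by δ·scale, so an optimizer working on `_data` sees values of order one instead of order 10³.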

Variables:
  • output_shape – Shape of the output array.

  • _data – Internal normalized data (initial_data / scale).

  • _scale – The normalization scale factor.

Parameters:
  • initial_data (Float[Array, '...'])

  • scale (float)

  • scaling_function (Callable[[Shaped[Array, '...']], Shaped[Array, '...']])

Example

>>> param = NormalizedArrayParametrization(probe_array, scale=1e3)
>>> resolved_array = param()  # equals probe_array up to floating-point rounding

__call__()

Return the reconstructed array (data multiplied by scale).

Returns:

The denormalized array.

Return type:

Shaped[Array, '...']

output_shape: tuple

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

class ptyrax.parametrizations.NormalizedReferencedArrayParametrization(initial_data, scale=1.0, scaling_function=None)

Bases: ArrayParametrization

Parametrization that stores a trainable offset relative to a fixed reference.

The output is computed as data * scale + reference_value, where reference_value is the initial data (with stopped gradients) and data starts at zero. This allows optimization to learn a correction on top of an initial estimate without modifying the reference.

Variables:
  • output_shape – Shape of the output array.

  • _data – Trainable offset initialized to zeros.

  • _scale – Scale factor applied to the trainable offset.

  • _reference_value – Fixed reference array (gradient-stopped).

Parameters:
  • initial_data (Shaped[Array, '...'])

  • scale (float)

  • scaling_function (Callable[[Shaped[Array, '...']], Shaped[Array, '...']])

Example

>>> param = NormalizedReferencedArrayParametrization(initial_probe)
>>> # Initially returns initial_probe (since _data is zero)
>>> assert jnp.allclose(param(), initial_probe)

__call__()

Return the parametrized array as data * scale + reference.

The reference value has its gradient stopped to prevent training of the initial estimate.

Returns:

The sum of the scaled trainable offset and the fixed reference.

Return type:

Shaped[Array, '...']
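The output rule `data * scale + reference_value` can be sketched in NumPy. ReferencedSketch is hypothetical and omits scaling_function. NumPy has no autodiff, so the gradient stopping is only noted in a comment; in JAX that role would be played by `jax.lax.stop_gradient`.

```python
import numpy as np


class ReferencedSketch:
    """Hypothetical stand-in for NormalizedReferencedArrayParametrization."""

    def __init__(self, initial_data, scale=1.0):
        self._reference_value = np.array(initial_data, copy=True)  # fixed
        self._scale = scale
        self._data = np.zeros_like(initial_data)  # trainable offset, starts at 0

    def __call__(self):
        # In JAX the reference term would be wrapped in jax.lax.stop_gradient.
        return self._data * self._scale + self._reference_value


initial_probe = np.array([1.0, 2.0, 3.0])
param = ReferencedSketch(initial_probe, scale=0.1)
before = param()
param._data = param._data + np.array([1.0, 0.0, -2.0])  # mimic an optimizer step
after = param()
```

Before any update the output equals the reference; afterwards only the scaled offset has changed it, and the reference itself is untouched.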

output_shape: tuple

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

class ptyrax.parametrizations.OuterProductArrayParametrization(output_shape, n_outer, initializer=<function ones>)

Bases: ArrayParametrization

Parametrization that represents a 2D array as a sum of outer products.

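Decomposes an array of shape (M, N, d), that is, d stacked 2D arrays of size \(M \times N\), into n_outer rank-1 components per channel. The components are stored as column vectors of shape (M, n_outer, d) and row vectors of shape (n_outer, N, d), optionally with leading batch dimensions. The output is reconstructed via Einstein summation over the component index s: \(A_{mnd} = \sum_{s=1}^{n_{\text{outer}}} c_{msd} \, r_{snd}\).

This reduces the number of free parameters from \(M N d\) to \(n_{\text{outer}} (M + N) d\); the resulting low-rank constraint can act as a regularizer.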
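The reconstruction formula maps directly onto `numpy.einsum`; the ellipsis carries any leading batch dimensions. This sketch uses random factors and is independent of the library's initializer; M ≠ N is chosen only to make the shapes easy to tell apart.

```python
import numpy as np

M, N, d, n_outer = 64, 48, 1, 4
rng = np.random.default_rng(0)
column_vector = rng.standard_normal((M, n_outer, d))  # c[m, s, d]
row_vector = rng.standard_normal((n_outer, N, d))     # r[s, n, d]

# A[m, n, d] = sum_s c[m, s, d] * r[s, n, d]
A = np.einsum("...msd,...snd->...mnd", column_vector, row_vector)

full_params = M * N * d                  # 3072 free values for a dense array
factored_params = n_outer * (M + N) * d  # 448 free values for the factors
```

Each channel of `A` is a sum of n_outer outer products, so its matrix rank is at most n_outer.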

Variables:
  • output_shape – Shape of the reconstructed output array.

  • column_vector – Column factors of shape (..., M, n_outer, d).

  • row_vector – Row factors of shape (..., n_outer, N, d).

Parameters:
  • output_shape (tuple)

  • n_outer (int)

  • initializer (Callable[[tuple], Array] | None)

Example

>>> param = OuterProductArrayParametrization(output_shape=(64, 64, 1), n_outer=4)
>>> array = param()  # shape (64, 64, 1)

__call__()

Reconstruct the full array from the outer product of column and row vectors.

Returns:

The reconstructed array of shape output_shape.

Return type:

Shaped[Array, '... M N d']

column_vector: Shaped[Array, '... M s d']

output_shape: tuple

row_vector: Shaped[Array, '... s N d']

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

class ptyrax.parametrizations.PhaseOnlyArrayParametrization(output_shape, phase_initializer=<function zeros>)

Bases: ArrayParametrization

Parametrization that constrains an array to unit magnitude (phase-only).

Stores a real-valued phase array and produces a complex array with unit magnitude via \(e^{i \phi}\). This is useful for representing phase screens or phase-only optical elements.
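The phase-only constraint holds by construction: only a real phase is stored, and \(|e^{i \phi}| = 1\) for every real \(\phi\). PhaseOnlySketch is a hypothetical NumPy stand-in for the class, using `np.zeros` as the default initializer to mirror the documented default.

```python
import numpy as np


class PhaseOnlySketch:
    """Hypothetical stand-in for PhaseOnlyArrayParametrization."""

    def __init__(self, output_shape, phase_initializer=np.zeros):
        self.output_shape = tuple(output_shape)
        self.phase = phase_initializer(self.output_shape)  # radians, real-valued

    def __call__(self):
        # |exp(i*phi)| == 1 for any real phi, so the magnitude cannot drift.
        return np.exp(1j * self.phase)


param = PhaseOnlySketch((128, 128))
param.phase = np.linspace(-np.pi, np.pi, 128 * 128).reshape(128, 128)
field = param()
```

With the zero initializer the output starts as an all-ones complex array, i.e. a flat phase screen.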

Variables:
  • output_shape – Shape of the output complex array.

  • phase – Trainable real-valued phase array in radians.

Parameters:
  • output_shape (tuple)

  • phase_initializer (Callable[[tuple], Float])

Example

>>> param = PhaseOnlyArrayParametrization(output_shape=(128, 128))
>>> field = param()  # complex array with |field| == 1 everywhere

__call__(**kwargs)

Compute the unit-magnitude complex array from the stored phase.

Returns:

Complex array \(e^{i \phi}\) with unit magnitude.

Return type:

Complex[Array, '']

output_shape: tuple

phase: Float[Array, 'M N']

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

class ptyrax.parametrizations.WaveletArrayParametrization(data_or_initializer)

Bases: ArrayParametrization

Parametrization that stores an array in the wavelet domain.

Applies a 2D Haar wavelet decomposition to the initial data and stores the wavelet coefficients as the trainable parameters. On evaluation, the inverse wavelet transform reconstructs the spatial-domain array.

Operating in the wavelet domain provides a multi-scale representation in which smooth content is captured by relatively few coefficients, so sparsity can be encouraged directly on the stored coefficients.
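A level-1 2D Haar transform and its inverse can be written in a few lines of NumPy by operating on 2×2 blocks. This is a sketch of the transform named above, not the library's code: the orthonormal 1/2 normalization and the sub-band names (ll, lh, hl, hh) are choices made here, and wavelet libraries differ in ordering and sign conventions. Both spatial dimensions are assumed to be even.

```python
import numpy as np


def haar2d(x):
    """Level-1 orthonormal 2D Haar transform over the last two axes (even sizes)."""
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]
    ll = (a + b + c + d) / 2  # coarse approximation
    lh = (a - b + c - d) / 2  # differences between neighbouring columns
    hl = (a + b - c - d) / 2  # differences between neighbouring rows
    hh = (a - b - c + d) / 2  # diagonal differences
    return ll, lh, hl, hh


def ihaar2d(ll, lh, hl, hh):
    """Inverse of haar2d: rebuild the spatial array from the four sub-bands."""
    h, w = ll.shape[-2:]
    x = np.empty(ll.shape[:-2] + (2 * h, 2 * w), dtype=np.result_type(ll, lh, hl, hh))
    x[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[..., 0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[..., 1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x


image = np.arange(16.0).reshape(4, 4)
coeffs = haar2d(image)  # in the parametrization, these would be trainable
reconstructed = ihaar2d(*coeffs)
```

For this linear ramp the diagonal band `hh` is exactly zero, a small illustration of smooth content concentrating in fewer coefficients. Because the transform is orthonormal, the total energy of the coefficients equals that of the image.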

Variables:

wavelet_coefficients – Trainable wavelet coefficients (Haar, level 1).

Parameters:

data_or_initializer (Callable[[tuple], Shaped])

Example

>>> param = WaveletArrayParametrization(initial_array)
>>> reconstructed = param()  # spatial-domain array

__call__(*args, **kwargs)

Reconstruct the spatial-domain array from wavelet coefficients.

Returns:

The inverse-wavelet-transformed array.

Return type:

Shaped[Array, '...']

output_shape: tuple

property shape: tuple[int, ...]

The shape of the output array produced by this parametrization.

wavelet_coefficients: Shaped[Array, '...']

ptyrax.parametrizations.as_direct_array_parametrization(data)

Wrap an array in a DirectArrayParametrization if needed.

If data is already an ArrayParametrization, it is returned unchanged. Otherwise, it is wrapped in a DirectArrayParametrization.

Parameters:

data (Shaped[Array, '...']) – An array or existing parametrization.

Returns:

The input unchanged if it is already an ArrayParametrization; otherwise, a DirectArrayParametrization wrapping it.

Return type:

DirectArrayParametrization
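The pass-through-or-wrap logic amounts to one `isinstance` check. In this sketch ArrayParam, DirectParam and as_direct are hypothetical stand-ins for the ptyrax classes and function.

```python
import numpy as np


class ArrayParam:
    """Hypothetical base class standing in for ArrayParametrization."""


class DirectParam(ArrayParam):
    """Hypothetical stand-in for DirectArrayParametrization (identity)."""

    def __init__(self, data):
        self._data = data

    def __call__(self):
        return self._data


def as_direct(data):
    # Existing parametrizations pass through; anything else gets wrapped.
    if isinstance(data, ArrayParam):
        return data
    return DirectParam(data)


raw = np.ones((2, 2))
wrapped = as_direct(raw)
```

Wrapping is idempotent: calling `as_direct(wrapped)` returns the same object rather than nesting a second wrapper.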

ptyrax.parametrizations.mean_l2_norm(a, axes=(-3, -2))

Compute the mean L2 norm of an array along the specified axes.

This is typically used as a scaling_function for NormalizedArrayParametrization, so that the stored (divided) data has unit mean L2 norm.

Parameters:
  • a (Float[Array, '* n m d']) – Input array to compute the norm of.

  • axes (tuple[int, ...]) – Axes along which to compute the L2 norm before averaging.

Returns:

The mean L2 norm, with the specified axes reduced.

Return type:

Float[Array, '* d']
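The sketch below shows how a norm-like scaling function produces one scale per leading index and per channel, and how dividing by it normalizes the data. The exact definition of "mean L2 norm" is an assumption here: mean_l2_norm_sketch takes the root-mean-square magnitude over `axes`, which may differ from the ptyrax implementation. The normalization property holds for any such function that scales linearly with its input.

```python
import numpy as np


def mean_l2_norm_sketch(a, axes=(-3, -2)):
    """Assumed definition: root-mean-square magnitude over `axes`.

    This is a guess at the documented behaviour, not the ptyrax source.
    """
    return np.sqrt(np.mean(np.abs(a) ** 2, axis=axes))


rng = np.random.default_rng(1)
probe = 50.0 * rng.standard_normal((3, 16, 16, 2))  # (*, n, m, d)
scale = mean_l2_norm_sketch(probe)                  # shape (3, 2)
normalized = probe / scale[..., None, None, :]      # re-insert the reduced axes
```

With no leading batch axes, a (d,)-shaped scale broadcasts directly against (n, m, d); with batch axes, the reduced axes must be reinserted, as done explicitly here.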

ptyrax.parametrizations.phase_only_exp(phase)

Convert a real-valued phase array to a complex unit-magnitude array.

Computes \(e^{i \phi}\) as \(\cos(\phi) + i \sin(\phi)\) using JAX.

Parameters:

phase (Float[Array, '... m n']) – Real-valued array of phases in radians.

Returns:

Complex array with unit magnitude and the given phases.

Raises:

ValueError – If the input array is complex-valued.

Return type:

Complex[Array, '... m n']

ptyrax.parametrizations.phase_only_exp_np(phase)

Convert a real-valued phase array to a complex unit-magnitude array using NumPy.

Computes \(e^{i \phi}\) as \(\cos(\phi) + i \sin(\phi)\) using NumPy (host-side, non-JIT-compatible).

Parameters:

phase (Float[Array, '... m n']) – Real-valued NumPy array of phases in radians.

Returns:

Complex NumPy array with unit magnitude and the given phases.

Raises:

ValueError – If the input array is complex-valued.

Return type:

Complex[Array, '... m n']
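The documented behaviour (reject complex input, then form \(\cos(\phi) + i \sin(\phi)\)) can be sketched in NumPy as follows. phase_only_exp_np_sketch is a hypothetical re-creation, not the library function, and the error message is illustrative.

```python
import numpy as np


def phase_only_exp_np_sketch(phase):
    """Sketch of the documented behaviour: e^{i*phase} for real phase only."""
    phase = np.asarray(phase)
    if np.iscomplexobj(phase):
        raise ValueError("phase must be real-valued")
    return np.cos(phase) + 1j * np.sin(phase)


phi = np.array([[0.0, np.pi / 2], [np.pi, -np.pi / 2]])
z = phase_only_exp_np_sketch(phi)
```

The explicit check matters because `np.cos` and `np.sin` accept complex arguments without complaint, and a complex phase would silently produce values whose magnitude is not 1.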

ptyrax.parametrizations.resolve_array_parametrizations(module)

Recursively resolve all ArrayParametrization instances in a module.

Traverses the pytree and replaces each ArrayParametrization leaf with the array returned by calling it. Nested parametrizations are resolved recursively.

Parameters:

module (T) – The module (or pytree) containing array parametrizations.

Returns:

A copy of the module with all array parametrizations replaced by their output arrays.

Return type:

T
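The traversal can be sketched as a recursive walk that calls every parametrization leaf and rebuilds the containers around it. This is a simplified illustration: Param and resolve_arrays are hypothetical, only dicts, lists and tuples are walked, and ptyrax itself operates on JAX pytrees.

```python
import numpy as np


class Param:
    """Hypothetical array parametrization: calling it returns an array."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self):
        return self.fn()


def resolve_arrays(tree):
    """Replace every Param leaf with its output, at any container depth."""
    if isinstance(tree, Param):
        return tree()
    if isinstance(tree, dict):
        return {k: resolve_arrays(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(resolve_arrays(v) for v in tree)
    return tree  # plain arrays and other leaves pass through unchanged


model = {
    "probe": Param(lambda: np.ones((2, 2))),
    "layers": [Param(lambda: np.zeros(3)), np.arange(3.0)],
    "deep": {"x": (Param(lambda: np.full(2, 7.0)),)},
}
resolved = resolve_arrays(model)
```

New containers are built rather than mutated, so the original `model` still holds its parametrizations, consistent with the function returning a copy.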

ptyrax.parametrizations.resolve_index_dependent_parameters(module, index)

Recursively resolve all index-dependent parameters at a specific index.

Traverses the pytree and replaces each IndexDependentParameter with a BoundSliceParameter containing the value at the specified index.

Parameters:
  • module (T) – The module (or pytree) containing index-dependent parameters.

  • index (int) – The dataset index to resolve at.

Returns:

A copy of the module with all index-dependent parameters bound to the given index.

Return type:

T

ptyrax.parametrizations.resolve_index_dependent_parameters_all(module)

Recursively resolve all index-dependent parameters, returning all indices.

Traverses the pytree and replaces each IndexDependentParameter with its all property, which includes all index values stacked along the leading dimension.

Parameters:

module (T) – The module (or pytree) containing index-dependent parameters.

Returns:

A copy of the module with all index-dependent parameters replaced by their full (all-index) values.

Return type:

T

ptyrax.parametrizations.resolve_parametrizations(module, index=None)

Recursively resolve all parametrizations in a module to their underlying arrays.

Resolves both ArrayParametrization and IndexDependentParameter instances in the module’s pytree. This should be called before using a model inside a JAX jit-compiled function.

Parameters:
  • module (T) – The module (or pytree) containing parametrizations to resolve.

  • index (int | None) – If provided, resolves index-dependent parameters at this specific dataset index. If None, resolves all indices at once (stacking along the leading dimension).

Returns:

A copy of the module with all parametrizations replaced by their resolved array values.

Return type:

T

Example

>>> resolved_model = resolve_parametrizations(model, index=0)
>>> output = resolved_model(inputs)
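The combined behaviour (array parametrizations become arrays; index-dependent parameters become either one index or all indices) can be sketched as a single recursive walk. Everything here is hypothetical: Sliced, Param and resolve are illustrative stand-ins, and for brevity the sketch returns the selected value directly, whereas resolve_index_dependent_parameters() wraps it in a BoundSliceParameter.

```python
import numpy as np


class Sliced:
    """Hypothetical index-dependent parameter: one row per dataset index."""

    def __init__(self, values):
        self.values = np.asarray(values)

    def at_index(self, index):
        return self.values[index]

    @property
    def all(self):
        return self.values


class Param:
    """Hypothetical array parametrization: calling it returns an array."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self):
        return self.fn()


def resolve(tree, index=None):
    """Resolve Param leaves to arrays and Sliced leaves to one index or all."""
    if isinstance(tree, Param):
        return tree()
    if isinstance(tree, Sliced):
        # index=None keeps every index; otherwise select a single one.
        return tree.all if index is None else tree.at_index(index)
    if isinstance(tree, dict):
        return {k: resolve(v, index) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(resolve(v, index) for v in tree)
    return tree


model = {
    "probe": Param(lambda: np.ones((4, 4))),
    "positions": Sliced(np.arange(10.0).reshape(5, 2)),  # 5 frames, 2D
}
one_frame = resolve(model, index=3)
all_frames = resolve(model)
```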