ptyrax.parametrizations
Functions
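as_direct_array_parametrization | Wrap an array in a DirectArrayParametrization if needed.
mean_l2_norm | Compute the mean L2 norm of an array along the specified axes.
phase_only_exp | Convert a real-valued phase array to a complex unit-magnitude array.
phase_only_exp_np | Convert a real-valued phase array to a complex unit-magnitude array using NumPy.
resolve_array_parametrizations | Recursively resolve all ArrayParametrization instances in a module.
resolve_index_dependent_parameters | Recursively resolve all index-dependent parameters at a specific index.
resolve_index_dependent_parameters_all | Recursively resolve all index-dependent parameters, returning all indices.
resolve_parametrizations | Recursively resolve all parametrizations in a module to their underlying arrays.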
Classes
ArrayParametrization | Abstract base class for array parametrizations.
BinaryArrayParametrization | Parametrization that produces a binary array via thresholding.
BoundSliceParameter | A resolved index-dependent parameter bound to a single index.
DirectArrayParametrization | Identity parametrization that wraps an array without any transformation.
IndexDependentParameter | Abstract base class for parameters that vary with the dataset index.
IndexSliceParameter | Index-dependent parameter that selects a slice along a given dimension.
NormalizedArrayParametrization | Parametrization that normalizes an array by a fixed scale factor.
NormalizedReferencedArrayParametrization | Parametrization that stores a trainable offset relative to a fixed reference.
OuterProductArrayParametrization | Parametrization that represents a 2D array as a sum of outer products.
PhaseOnlyArrayParametrization | Parametrization that constrains an array to unit magnitude (phase-only).
WaveletArrayParametrization | Parametrization that stores an array in the wavelet domain.
- class ptyrax.parametrizations.ArrayParametrization(output_shape)
Bases: Module
Abstract base class for array parametrizations.
An array parametrization wraps trainable parameters and produces an output array via its __call__ method. Subclasses implement specific constraints (e.g., phase-only, normalized, outer-product decomposition) while exposing a uniform interface that can be resolved transparently using resolve_array_parametrizations().
- Variables:
output_shape – The shape of the array produced by __call__.
- Parameters:
output_shape (tuple)
Example
>>> parametrized_model = build_model(config)
>>> resolved = resolve_parametrizations(parametrized_model, index=0)
>>> output = resolved(inputs)
- abstractmethod __call__(*args, **kwargs)
Compute and return the parametrized output array.
- Return type:
Array
- output_shape: tuple
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
- class ptyrax.parametrizations.BinaryArrayParametrization(data_or_initializer, threshold=0.5)
Bases: ArrayParametrization
Parametrization that produces a binary array via thresholding.
- Parameters:
data_or_initializer (Callable[[tuple], Shaped])
threshold (float)
- __call__()
Return a binary array obtained by thresholding the stored data.
- Returns:
Boolean array that is True where data > threshold and False otherwise.
- Return type:
Bool[Array, '...']
- output_shape: tuple
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
- threshold: float
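The thresholding rule documented for __call__() can be sketched with NumPy as below. This is a minimal illustration of the comparison only, not ptyrax's implementation. The function name binary_threshold is hypothetical.

```python
import numpy as np


def binary_threshold(data: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    # Hypothetical helper. The comparison is strict, so values exactly equal
    # to the threshold map to False, matching "data > threshold".
    return data > threshold


data = np.array([0.2, 0.5, 0.7, 1.0])
mask = binary_threshold(data)  # boolean array, same shape as data
```

A hard threshold like this has zero gradient almost everywhere. How the stored data is updated during optimization therefore depends on implementation details that this page does not describe.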
- class ptyrax.parametrizations.BoundSliceParameter(parameter, n)
Bases: Module
A resolved index-dependent parameter bound to a single index.
Created by resolve_index_dependent_parameters() when an IndexDependentParameter is resolved at a specific index. Use at_current_index() to access the bound value.
- Variables:
parameter – The resolved parameter value for the bound index.
n – The total number of indices the original parameter supported.
- Parameters:
parameter (PyTree)
n (int)
- property all: T
- n: int
- parameter: PyTree
- class ptyrax.parametrizations.DirectArrayParametrization(data_or_initializer)
Bases: ArrayParametrization
Identity parametrization that wraps an array without any transformation.
This is a no-op parametrization: calling it returns the stored data unchanged. It provides a uniform ArrayParametrization interface around arrays that require no constraints.
- Variables:
_data – The stored array.
- Parameters:
data_or_initializer (Callable[[tuple], Shaped])
- Return type:
Shaped[Array, '...']
Example
>>> param = DirectArrayParametrization(jnp.ones((64, 64)))
>>> assert jnp.array_equal(param(), jnp.ones((64, 64)))
- __call__()
Return the stored array unchanged.
- Returns:
The wrapped array.
- Return type:
Float[Array, '...']
- property output_shape: tuple[int, ...]
The shape of the stored array.
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
- class ptyrax.parametrizations.IndexDependentParameter(_index=None)
Bases: Module, Generic[T]
Abstract base class for parameters that vary with the dataset index.
In ptychography, certain model parameters (e.g., scan positions, per-frame aberrations) differ for each diffraction pattern in the dataset. This class provides a uniform interface to access index-specific values, either explicitly via at_index() or implicitly after resolution via at_current_index().
Subclasses must implement at_index(), n, and all.
- Variables:
_index – The currently bound index, or None if not yet resolved.
- Parameters:
_index (int | None)
- abstract property all: T
All parameter values across all indices.
- at(index)
Alias for at_index().
- Parameters:
index (int)
- Return type:
T
- abstractmethod at_index(index)
Return the parameter value at a specific dataset index.
- Parameters:
index (int) – The dataset index to retrieve.
- Returns:
The parameter value for the given index.
- Return type:
T
- abstract property n: int
The number of distinct index values this parameter supports.
- class ptyrax.parametrizations.IndexSliceParameter(parameters, dim=0)
Bases: IndexDependentParameter[T]
Index-dependent parameter that selects a slice along a given dimension.
Stores a pytree of arrays in which one dimension corresponds to the dataset index. Calling at_index() slices along that dimension to extract the parameter for a single index.
- Variables:
parameters – The full pytree of parameters (with an index dimension).
slice_dim – Static pytree matching parameters, indicating which dimension to slice for each leaf.
- Parameters:
parameters (PyTree)
dim (int)
Example
>>> positions = jnp.zeros((100, 2))  # 100 scan positions, 2D
>>> param = IndexSliceParameter(positions, dim=0)
>>> param.at_index(5)  # shape (2,)
- property all: Shaped[Array, 'indices ...']
The full parameter pytree including the index dimension.
- at(index)
Alias for at_index().
- Parameters:
index (int)
- Return type:
T
- at_current_index()
- Return type:
T
- at_index(index)
Slice the parameters at the given dataset index.
- Parameters:
index (int) – Dataset index to select.
- Returns:
The parameter pytree with the index dimension removed.
- Return type:
Shaped[Array, '...']
- property n: int
The number of indices (size of the slicing dimension).
- parameters: PyTree
- slice_dim: PyTree
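The per-leaf slicing that at_index() and n describe can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions: a flat dict stands in for a general pytree, and the hypothetical slice_dims dict plays the role of the slice_dim pytree. The helper names at_index and n_indices are illustrative and are not the library's code.

```python
import numpy as np


def at_index(parameters: dict, slice_dims: dict, index: int) -> dict:
    # np.take with a scalar index drops the sliced axis, so each leaf loses
    # its index dimension, as the documented at_index() return value does.
    return {
        name: np.take(leaf, index, axis=slice_dims[name])
        for name, leaf in parameters.items()
    }


def n_indices(parameters: dict, slice_dims: dict) -> int:
    # Every leaf must agree on the size of its index dimension.
    sizes = {leaf.shape[slice_dims[name]] for name, leaf in parameters.items()}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent index sizes: {sorted(sizes)}")
    return sizes.pop()


parameters = {
    "positions": np.zeros((100, 2)),  # 100 scan positions, 2D
    "defocus": np.arange(100.0),      # hypothetical per-frame scalar
}
slice_dims = {"positions": 0, "defocus": 0}

frame = at_index(parameters, slice_dims, 5)
```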
- class ptyrax.parametrizations.NormalizedArrayParametrization(initial_data, scale=1.0, scaling_function=None)
Bases: ArrayParametrization
Parametrization that normalizes an array by a fixed scale factor.
Stores the data divided by a scale factor and multiplies by that factor on evaluation. This keeps the internal parameters at unit scale during optimization, which can improve gradient conditioning.
The scale factor is computed either from a provided scaling_function applied to the initial data, or from the explicit scale argument.
- Variables:
output_shape – Shape of the output array.
_data – Internal normalized data (initial_data / scale).
_scale – The normalization scale factor.
- Parameters:
initial_data (Float[Array, '...'])
scale (float)
scaling_function (Callable[[Shaped[Array, '...']], Shaped[Array, '...']])
Example
>>> param = NormalizedArrayParametrization(probe_array, scale=1e3)
>>> resolved_array = param()  # returns probe_array
- __call__()
Return the reconstructed array (data multiplied by scale).
- Returns:
The denormalized array.
- Return type:
Shaped[Array, '...']
- output_shape: tuple
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
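The normalize-on-store, rescale-on-call scheme can be sketched as a small NumPy class. NormalizedArraySketch is a hypothetical stand-in, not ptyrax's class. It assumes that a provided scaling_function takes precedence over scale, which the description above implies but does not spell out.

```python
import numpy as np


class NormalizedArraySketch:
    """Hypothetical stand-in illustrating the documented normalization scheme."""

    def __init__(self, initial_data, scale=1.0, scaling_function=None):
        # Assumption: a provided scaling_function takes precedence over scale.
        if scaling_function is not None:
            scale = scaling_function(initial_data)
        self._scale = scale
        # The stored (trainable) values are O(1) regardless of the data's units.
        self._data = initial_data / scale

    def __call__(self):
        # Undo the normalization on evaluation.
        return self._data * self._scale


probe_array = np.full((4, 4), 2.0e3)
param = NormalizedArraySketch(probe_array, scale=1e3)
peak_normalized = NormalizedArraySketch(
    probe_array, scaling_function=lambda a: np.max(np.abs(a))
)
```

The output is _data * scale, so by the chain rule the gradient with respect to _data is scale times the gradient with respect to the output. The chosen scale therefore also rescales the gradients the optimizer sees.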
- class ptyrax.parametrizations.NormalizedReferencedArrayParametrization(initial_data, scale=1.0, scaling_function=None)
Bases: ArrayParametrization
Parametrization that stores a trainable offset relative to a fixed reference.
The output is computed as data * scale + reference_value, where reference_value is the initial data (with stopped gradients) and data starts at zero. This allows optimization to learn a correction on top of an initial estimate without modifying the reference.
- Variables:
output_shape – Shape of the output array.
_data – Trainable offset initialized to zeros.
_scale – Scale factor applied to the trainable offset.
_reference_value – Fixed reference array (gradient-stopped).
- Parameters:
initial_data (Shaped[Array, '...'])
scale (float)
scaling_function (Callable[[Shaped[Array, '...']], Shaped[Array, '...']])
Example
>>> param = NormalizedReferencedArrayParametrization(initial_probe)
>>> # Initially returns initial_probe (since _data is zero)
>>> assert jnp.allclose(param(), initial_probe)
- __call__()
Return the parametrized array as data * scale + reference.
The reference value has its gradient stopped to prevent training of the initial estimate.
- Returns:
The sum of the scaled trainable offset and the fixed reference.
- Return type:
Shaped[Array, '...']
- output_shape: tuple
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
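The offset-from-reference arithmetic can be sketched in NumPy. ReferencedArraySketch is a hypothetical stand-in. NumPy has no automatic differentiation, so the gradient stopping is only marked in a comment, and the scaling_function option is omitted.

```python
import numpy as np


class ReferencedArraySketch:
    """Hypothetical stand-in: output = data * scale + reference_value."""

    def __init__(self, initial_data, scale=1.0):
        # Fixed reference. In a JAX implementation, jax.lax.stop_gradient
        # would be applied to it so it never receives updates.
        self._reference_value = np.array(initial_data, copy=True)
        # Trainable offset starts at zero, so the first output equals the reference.
        self._data = np.zeros_like(self._reference_value)
        self._scale = scale

    def __call__(self):
        return self._data * self._scale + self._reference_value


initial_probe = np.linspace(0.0, 1.0, 6).reshape(2, 3)
param = ReferencedArraySketch(initial_probe, scale=0.1)
before = param()
param._data = param._data + 1.0  # stand-in for one optimizer update
after = param()
```

Only _data changes during the simulated update. The learned correction is therefore always after - initial_probe, and the reference stays intact.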
- class ptyrax.parametrizations.OuterProductArrayParametrization(output_shape, n_outer, initializer=<function ones>)
Bases: ArrayParametrization
Parametrization that represents a 2D array as a sum of outer products.
Decomposes an array of shape (M, N, d), i.e. a stack of d two-dimensional M × N arrays, into n_outer rank-1 components per channel. These are stored as column vectors of shape (M, n_outer, d) and row vectors of shape (n_outer, N, d). The output is reconstructed via Einstein summation: \(A_{mnk} = \sum_{s=1}^{n_{outer}} c_{msk} \, r_{snk}\), where m, n, and k index the M, N, and d axes.
This changes the number of free parameters per channel from \(M \times N\) to \(n_{outer} \times (M + N)\). That is a reduction whenever \(n_{outer} (M + N) < M N\), and the low-rank constraint can act as a regularizer.
- Variables:
output_shape – Shape of the reconstructed output array.
column_vector – Column factors of shape (..., M, n_outer, d).
row_vector – Row factors of shape (..., n_outer, N, d).
- Parameters:
output_shape (tuple)
n_outer (int)
initializer (Callable[[tuple], Array] | None)
Example
>>> param = OuterProductArrayParametrization(output_shape=(64, 64, 1), n_outer=4)
>>> array = param()  # shape (64, 64, 1)
- __call__()
Reconstruct the full array from the outer product of column and row vectors.
- Returns:
The reconstructed array of shape output_shape.
- Return type:
Shaped[Array, '... M N d']
- column_vector: Shaped[Array, '... M s d']
- output_shape: tuple
- row_vector: Shaped[Array, '... s N d']
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
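The Einstein summation above can be checked numerically with np.einsum, using random factors shaped like column_vector and row_vector. This is an illustrative sketch of the decomposition, not ptyrax's code. The sizes and random initialization are arbitrary choices, and the documented default initializer is not reproduced.

```python
import numpy as np

M, N, d, n_outer = 64, 48, 1, 4
rng = np.random.default_rng(0)

# Factor shapes follow the documented column_vector / row_vector layouts.
column_vector = rng.standard_normal((M, n_outer, d))
row_vector = rng.standard_normal((n_outer, N, d))

# A[m, n, k] = sum_s c[m, s, k] * r[s, n, k]; '...' also admits leading batch axes.
A = np.einsum("...msk,...snk->...mnk", column_vector, row_vector)

# Same result for channel 0, written as an explicit sum of rank-1 outer products.
A0 = sum(
    np.outer(column_vector[:, s, 0], row_vector[s, :, 0]) for s in range(n_outer)
)

full_params = M * N * d                  # 3072 free values in a dense array
factored_params = n_outer * (M + N) * d  # 448 free values in the factors
```

Each channel of A has rank at most n_outer, which is the low-rank constraint that acts as the regularizer.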
- class ptyrax.parametrizations.PhaseOnlyArrayParametrization(output_shape, phase_initializer=<function zeros>)
Bases: ArrayParametrization
Parametrization that constrains an array to unit magnitude (phase-only).
Stores a real-valued phase array and produces a complex array with unit magnitude via \(e^{i \phi}\). This is useful for representing phase screens or phase-only optical elements.
- Variables:
output_shape – Shape of the output complex array.
phase – Trainable real-valued phase array in radians.
- Parameters:
output_shape (tuple)
phase_initializer (Callable[[tuple], Float])
Example
>>> param = PhaseOnlyArrayParametrization(output_shape=(128, 128))
>>> field = param()  # complex array with |field| == 1 everywhere
- __call__(**kwargs)
Compute the unit-magnitude complex array from the stored phase.
- Returns:
Complex array \(e^{i \phi}\) with unit magnitude.
- Return type:
Complex[Array, '']
- output_shape: tuple
- phase: Float[Array, 'M N']
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
- class ptyrax.parametrizations.WaveletArrayParametrization(data_or_initializer)
Bases: ArrayParametrization
Parametrization that stores an array in the wavelet domain.
Applies a 2D Haar wavelet decomposition to the initial data and stores the wavelet coefficients as the trainable parameters. On evaluation, the inverse wavelet transform reconstructs the spatial-domain array.
Operating in the wavelet domain gives a multi-scale representation, in which sparsity in the wavelet basis can be encouraged directly on the trainable coefficients.
- Variables:
wavelet_coefficients – Trainable wavelet coefficients (Haar, level 1).
- Parameters:
data_or_initializer (Callable[[tuple], Shaped])
Example
>>> param = WaveletArrayParametrization(initial_array)
>>> reconstructed = param()  # spatial-domain array
- __call__(*args, **kwargs)
Reconstruct the spatial-domain array from wavelet coefficients.
- Returns:
The inverse-wavelet-transformed array.
- Return type:
Shaped[Array, '...']
- output_shape: tuple
- property shape: tuple[int, ...]
The shape of the output array produced by this parametrization.
- wavelet_coefficients: Shaped[Array, '...']
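The forward and inverse steps can be illustrated with a single-level 2D Haar transform written directly in NumPy. This sketch makes several assumptions: orthonormal Haar filters, even-sized last two axes, and the four coefficient bands returned as separate arrays. The normalization, boundary handling, and coefficient layout that WaveletArrayParametrization actually uses are not stated on this page. The helpers haar2d and ihaar2d are hypothetical.

```python
import numpy as np


def haar2d(x):
    """Single-level orthonormal 2D Haar transform over the last two axes (even sizes assumed)."""
    a = x[..., 0::2, 0::2]  # top-left of each 2x2 block
    b = x[..., 0::2, 1::2]  # top-right
    c = x[..., 1::2, 0::2]  # bottom-left
    e = x[..., 1::2, 1::2]  # bottom-right
    approx = (a + b + c + e) / 2
    col_detail = (a - b + c - e) / 2  # left column minus right column
    row_detail = (a + b - c - e) / 2  # top row minus bottom row
    diag_detail = (a - b - c + e) / 2
    return approx, (col_detail, row_detail, diag_detail)


def ihaar2d(approx, details):
    """Exact inverse of haar2d."""
    col_detail, row_detail, diag_detail = details
    out_shape = approx.shape[:-2] + (2 * approx.shape[-2], 2 * approx.shape[-1])
    out = np.empty(out_shape, dtype=np.result_type(approx, *details))
    out[..., 0::2, 0::2] = (approx + col_detail + row_detail + diag_detail) / 2
    out[..., 0::2, 1::2] = (approx - col_detail + row_detail - diag_detail) / 2
    out[..., 1::2, 0::2] = (approx + col_detail - row_detail - diag_detail) / 2
    out[..., 1::2, 1::2] = (approx - col_detail - row_detail + diag_detail) / 2
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6))
approx, details = haar2d(x)
coefficients = [approx, *details]
```

The transform is orthonormal, so it preserves the sum of squares. Smooth regions produce near-zero detail coefficients, which is where a sparsity preference in the wavelet basis has its effect.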
- ptyrax.parametrizations.as_direct_array_parametrization(data)
Wrap an array in a DirectArrayParametrization if needed.
If data is already an ArrayParametrization, it is returned unchanged. Otherwise, it is wrapped in a DirectArrayParametrization.
- Parameters:
data (Shaped[Array, '...']) – An array or existing parametrization.
- Returns:
An ArrayParametrization wrapping the input data.
- Return type:
ArrayParametrization
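The wrap-if-needed behaviour is an idempotent isinstance check, sketched below in plain Python with NumPy. The classes ArrayParametrizationSketch and DirectSketch and the function as_direct_sketch are hypothetical stand-ins for the documented names, not ptyrax's code.

```python
import numpy as np


class ArrayParametrizationSketch:
    """Hypothetical stand-in for the ArrayParametrization base class."""

    def __call__(self):
        raise NotImplementedError


class DirectSketch(ArrayParametrizationSketch):
    """Hypothetical stand-in for DirectArrayParametrization."""

    def __init__(self, data):
        self._data = data

    def __call__(self):
        return self._data  # identity: no transformation


def as_direct_sketch(data):
    # Existing parametrizations pass through untouched; raw arrays get wrapped.
    if isinstance(data, ArrayParametrizationSketch):
        return data
    return DirectSketch(data)


arr = np.ones((2, 2))
wrapped = as_direct_sketch(arr)
```

Calling the helper twice is safe: the second call sees a parametrization and returns it as-is. This lets downstream code accept either raw arrays or parametrizations.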
- ptyrax.parametrizations.mean_l2_norm(a, axes=(-3, -2))
Compute the mean L2 norm of an array along the specified axes.
This is typically used as a scaling_function for NormalizedArrayParametrization to normalize arrays to a target mean L2 norm.
- Parameters:
a (Float[Array, '* n m d']) – Input array to compute the norm of.
axes (tuple[int, ...]) – Axes along which to compute the L2 norm before averaging.
- Returns:
The mean L2 norm, with the specified axes reduced.
- Return type:
Float[Array, '* d']
- ptyrax.parametrizations.phase_only_exp(phase)
Convert a real-valued phase array to a complex unit-magnitude array.
Computes \(e^{i \phi}\) as \(\cos(\phi) + i \sin(\phi)\) using JAX.
- Parameters:
phase (Float[Array, '... m n']) – Real-valued array of phases in radians.
- Returns:
Complex array with unit magnitude and the given phases.
- Raises:
ValueError – If the input array is complex-valued.
- Return type:
Complex[Array, '... m n']
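The documented behaviour, including the ValueError for complex input, can be sketched with NumPy. This is closer to phase_only_exp_np than to the JAX version, and jax.numpy offers the same cos, sin, and iscomplexobj functions. phase_only_exp_sketch is a hypothetical name and the error message is illustrative.

```python
import numpy as np


def phase_only_exp_sketch(phase):
    phase = np.asarray(phase)
    # Reject complex input, as the documented functions do.
    if np.iscomplexobj(phase):
        raise ValueError("phase must be real-valued")
    # Euler's formula: e^{i*phi} = cos(phi) + i*sin(phi), magnitude exactly 1.
    return np.cos(phase) + 1j * np.sin(phase)


phi = np.linspace(-np.pi, np.pi, 7)
z = phase_only_exp_sketch(phi)
```

Every entry of z lies on the unit circle, and np.angle(z) recovers phi wherever phi lies strictly inside (-pi, pi).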
- ptyrax.parametrizations.phase_only_exp_np(phase)
Convert a real-valued phase array to a complex unit-magnitude array using NumPy.
Computes \(e^{i \phi}\) as \(\cos(\phi) + i \sin(\phi)\) using NumPy (host-side, non-JIT-compatible).
- Parameters:
phase (Float[Array, '... m n']) – Real-valued NumPy array of phases in radians.
- Returns:
Complex NumPy array with unit magnitude and the given phases.
- Raises:
ValueError – If the input array is complex-valued.
- Return type:
Complex[Array, '... m n']
- ptyrax.parametrizations.resolve_array_parametrizations(module)
Recursively resolve all ArrayParametrization instances in a module.
Traverses the pytree and replaces each ArrayParametrization leaf with the array returned by calling it. Nested parametrizations are resolved recursively.
- Parameters:
module (T) – The module (or pytree) containing array parametrizations.
- Returns:
A copy of the module with all array parametrizations replaced by their output arrays.
- Return type:
T
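The traversal can be sketched as a recursive walk over plain Python containers. This is an assumption-laden sketch: dicts, lists, and tuples stand in for JAX pytrees and modules, and Param is a hypothetical stand-in for ArrayParametrization. With JAX, a comparable walk could be written with jax.tree_util.tree_map and its is_leaf argument, but this page does not say how ptyrax does it.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for an ArrayParametrization."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self):
        return self.fn()


def resolve_sketch(tree):
    if isinstance(tree, Param):
        # Resolve the output too, in case it still contains parametrizations.
        return resolve_sketch(tree())
    if isinstance(tree, dict):
        return {k: resolve_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(resolve_sketch(v) for v in tree)
    return tree  # plain leaves (arrays, scalars, strings) pass through


model = {
    "probe": Param(lambda: np.ones((2, 2))),
    "object": np.zeros(3),
    "extras": [Param(lambda: Param(lambda: 2.0)), "label"],
}
resolved = resolve_sketch(model)
```

The sketch builds new containers instead of mutating the input, mirroring the documented "returns a copy" behaviour.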
- ptyrax.parametrizations.resolve_index_dependent_parameters(module, index)
Recursively resolve all index-dependent parameters at a specific index.
Traverses the pytree and replaces each IndexDependentParameter with a BoundSliceParameter containing the value at the specified index.
- Parameters:
module (T) – The module (or pytree) containing index-dependent parameters.
index (int) – The dataset index to resolve at.
- Returns:
A copy of the module with all index-dependent parameters bound to the given index.
- Return type:
T
- ptyrax.parametrizations.resolve_index_dependent_parameters_all(module)
Recursively resolve all index-dependent parameters, returning all indices.
Traverses the pytree and replaces each IndexDependentParameter with its all property, which includes all index values stacked along the leading dimension.
- Parameters:
module (T) – The module (or pytree) containing index-dependent parameters.
- Returns:
A copy of the module with all index-dependent parameters replaced by their full (all-index) values.
- Return type:
T
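The two resolution modes can be contrasted in a small sketch: binding one index versus keeping all indices. SliceParam and BoundSketch are hypothetical stand-ins for IndexSliceParameter (restricted to dim=0) and BoundSliceParameter. The walk handles only dicts, which stand in for a general pytree. This illustrates the documented behaviour and is not the library's implementation.

```python
import numpy as np


class SliceParam:
    """Hypothetical stand-in for IndexSliceParameter with dim=0."""

    def __init__(self, values):
        self.values = np.asarray(values)

    @property
    def n(self):
        return self.values.shape[0]

    @property
    def all(self):
        return self.values

    def at_index(self, index):
        return self.values[index]


class BoundSketch:
    """Hypothetical stand-in for BoundSliceParameter."""

    def __init__(self, parameter, n):
        self.parameter = parameter
        self.n = n


def resolve_at(tree, index):
    # Per-index mode: bind the value at `index`, remember the total count.
    if isinstance(tree, SliceParam):
        return BoundSketch(tree.at_index(index), tree.n)
    if isinstance(tree, dict):
        return {k: resolve_at(v, index) for k, v in tree.items()}
    return tree


def resolve_all(tree):
    # All-index mode: replace each parameter by its full stacked values.
    if isinstance(tree, SliceParam):
        return tree.all
    if isinstance(tree, dict):
        return {k: resolve_all(v) for k, v in tree.items()}
    return tree


model = {"positions": SliceParam(np.arange(20.0).reshape(10, 2)), "wavelength": 1.0}
bound = resolve_at(model, 3)
stacked = resolve_all(model)
```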
- ptyrax.parametrizations.resolve_parametrizations(module, index=None)
Recursively resolve all parametrizations in a module to their underlying arrays.
Resolves both ArrayParametrization and IndexDependentParameter instances in the module’s pytree. This should be called before using a model inside a JAX jit-compiled function.
- Parameters:
module (T) – The module (or pytree) containing parametrizations to resolve.
index (int | None) – If provided, resolves index-dependent parameters at this specific dataset index. If None, resolves all indices at once (stacking along the leading dimension).
- Returns:
A copy of the module with all parametrizations replaced by their resolved array values.
- Return type:
T
Example
>>> resolved_model = resolve_parametrizations(model, index=0)
>>> output = resolved_model(inputs)