from abc import abstractmethod
from typing import Any, Callable, Generic, Optional, TypeVar, Union
import equinox as eqx
import gin
import jax
import jax.numpy as jnp
import jaxwt as jwt
import numpy as np
from chromatix.ops import binarize
from jaxtyping import Array, Bool, Complex, Float, PyTree, Shaped
[docs]
def phase_only_exp(phase: Float[Array, "... m n"]) -> Complex[Array, "... m n"]:
    r"""Map a real phase array onto the complex unit circle (JAX version).

    Evaluates :math:`e^{i \phi}` explicitly as :math:`\cos(\phi) + i \sin(\phi)`.

    Args:
        phase: Real-valued array of phases, in radians.

    Returns:
        Complex array of the same shape as ``phase`` with unit magnitude (up to
        floating-point rounding) and the given phases.

    Raises:
        ValueError: If ``phase`` already has a complex dtype.
    """
    is_complex = jnp.iscomplexobj(phase)
    if not is_complex:
        real_part = jnp.cos(phase)
        imag_part = jnp.sin(phase)
        return real_part + 1j * imag_part
    raise ValueError("Input to phase_only_exp should be real-valued phase array.")
[docs]
def phase_only_exp_np(phase: Float[np.ndarray, "... m n"]) -> Complex[np.ndarray, "... m n"]:
    r"""Convert a real-valued phase array to a complex unit-magnitude array
    using NumPy.

    Computes :math:`e^{i \phi}` as :math:`\cos(\phi) + i \sin(\phi)` with NumPy
    (host-side, not traceable by JAX).

    Args:
        phase: Real-valued NumPy array of phases in radians.

    Returns:
        Complex NumPy array with unit magnitude and the given phases.

    Raises:
        ValueError: If the input array is complex-valued.
    """
    if np.iscomplexobj(phase):
        # Name this function (not the JAX variant) so the error points at the right call.
        raise ValueError("Input to phase_only_exp_np should be real-valued phase array.")
    return np.cos(phase) + 1j * np.sin(phase)
# Type variable for the resolve_* helpers, which return an object of the same
# type as the module they are given.
T = TypeVar("T")
[docs]
def resolve_parametrizations(module: T, index: int | None = None) -> T:
    """Replace every parametrization in ``module`` by the value it stands for.

    Two passes are applied. First, every
    :py:class:`~ptyrax.parametrizations.ArrayParametrization` is evaluated to
    its array. Then every
    :py:class:`~ptyrax.parametrizations.IndexDependentParameter` is resolved,
    either bound to a single dataset index or expanded to all indices at once.
    Call this before using a model inside a JAX jit-compiled function.

    Args:
        module: Module or arbitrary pytree that may contain parametrizations.
        index: Dataset index at which index-dependent parameters are bound.
            When ``None``, each index-dependent parameter is replaced by its
            ``all`` value (every index at once) instead.

    Returns:
        A copy of ``module`` with the parametrizations resolved.

    Example:
        >>> resolved_model = resolve_parametrizations(model, index=0)
        >>> output = resolved_model(inputs)
    """
    with_arrays = resolve_array_parametrizations(module)
    if index is None:
        return resolve_index_dependent_parameters_all(with_arrays)
    return resolve_index_dependent_parameters(with_arrays, index)
[docs]
def resolve_array_parametrizations(module: T) -> T:
    """Evaluate every ArrayParametrization found in ``module``.

    Walks the pytree, treating each
    :py:class:`~ptyrax.parametrizations.ArrayParametrization` as a leaf, and
    substitutes it with the result of calling it. Because that result may
    itself contain parametrizations, it is resolved again recursively.

    Args:
        module: Module or pytree that may contain array parametrizations.

    Returns:
        A copy of ``module`` in which every array parametrization has been
        replaced by its output.
    """

    def _is_parametrization(node: Any) -> bool:  # noqa: ANN401
        return isinstance(node, ArrayParametrization)

    def _evaluate(node: Any) -> Any:  # noqa: ANN401
        if not _is_parametrization(node):
            return node
        # The evaluated output may hold further parametrizations: recurse.
        return resolve_array_parametrizations(node())

    return jax.tree.map(_evaluate, module, is_leaf=_is_parametrization)  # type: ignore
[docs]
class ArrayParametrization(eqx.Module):
    """Abstract base for objects that build an array from trainable parameters.

    Concrete subclasses hold their parameters as module fields and construct the
    output array in ``__call__`` (e.g. a phase-only, normalized or
    outer-product representation). Because every subclass shares this
    interface, a parametrization can be swapped for its array with
    :py:func:`~ptyrax.parametrizations.resolve_array_parametrizations`.

    Attributes:
        output_shape: Shape of the array returned by ``__call__`` (declared as
            a static equinox field).

    Example:
        >>> parametrized_model = build_model(config)
        >>> resolved = resolve_parametrizations(parametrized_model, index=0)
        >>> output = resolved(inputs)
    """

    output_shape: tuple = eqx.field(static=True)

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Array:
        """Build and return the array described by this parametrization."""

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the array this parametrization produces (``output_shape``)."""
        return self.output_shape

    def __getitem__(self, name: str) -> Any:  # noqa: ANN401
        # Indexing an unresolved parametrization is usually a mistake (the caller
        # expected an array). Look the key up as an attribute and turn any
        # failure into an AttributeError explaining how to resolve first.
        try:
            value = super().__getattribute__(name)
        except (AttributeError, TypeError) as err:
            message = f"""'ArrayParametrization' object has no item '{name}'.
It is likely that the module you are trying to access has not resolved parametrizations yet.
Please use the 'resolve_parametrizations' function to resolve all parametrizations before
calling the module,
i.e. 'resolved_module = resolve_parametrizations(module, index); resolved_module(...)'.
Alternatively, when operating outside the jit-boundary, you can access all parametrizations
via the 'all' property, i.e. 'all_parametrizations = module.foo.all'.
"""
            raise AttributeError(message) from err
        return value

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        # Python falls back to __getattr__ only after normal lookup has failed,
        # so the retry below is normally expected to fail as well. Its purpose is
        # to replace the bare AttributeError with guidance on resolution.
        try:
            value = super().__getattribute__(name)
        except (AttributeError, TypeError) as err:
            message = f"""'ArrayParametrization' object has no attribute '{name}'.
It is likely that the module you are trying to access has not resolved parametrizations yet.
Inside of the jit-boundary, you should use the 'resolve_parametrizations' function to resolve all
parametrizations before calling the module,
i.e. 'resolved_module = resolve_parametrizations(module, index); resolved_module(...)'.
Alternatively, when operating outside the jit-boundary, you can access all parametrizations
via the 'all' property, i.e. 'all_parametrizations = module.foo.all'.
"""
            raise AttributeError(message) from err
        return value
[docs]
@gin.register
def mean_l2_norm(a: Float[Array, "* n m d"], axes: tuple[int, ...] = (-3, -2)) -> Float[Array, ""]:
    r"""Compute the mean L2 norm of an array over the given axes.

    The L2 norm (the Frobenius norm when two axes are given) is taken over
    ``axes``. The resulting norms are then averaged over all remaining axes,
    so the result is a scalar.

    This is typically used as a ``scaling_function`` for
    :py:class:`~ptyrax.parametrizations.NormalizedArrayParametrization` to
    normalize arrays to a target mean L2 norm.

    Args:
        a: Input array to compute the norm of.
        axes: Axes along which each L2 norm is computed. Defaults to the two
            axes preceding the last one.

    Returns:
        A 0-d array holding the mean of the L2 norms.
    """
    # ``axes`` must be passed by keyword: the second positional parameter of
    # ``jnp.linalg.norm`` is ``ord`` (the norm order), not the reduction axes.
    norms = jnp.linalg.norm(a, axis=axes)
    return jnp.mean(norms)
[docs]
@gin.register
class NormalizedArrayParametrization(ArrayParametrization):
    r"""Parametrization that normalizes an array by a fixed scale factor.

    Stores the data divided by a scale factor and multiplies by that factor
    on evaluation. This keeps the internal parameters at unit scale during
    optimization, which can improve gradient conditioning.

    The scale factor is computed either from a provided ``scaling_function``
    applied to the initial data, or from the explicit ``scale`` argument.

    Attributes:
        output_shape: Shape of the output array.
        _data: Internal normalized data (``initial_data / scale``).
        _scale: The normalization scale factor.

    Example:
        >>> param = NormalizedArrayParametrization(probe_array, scale=1e3)
        >>> resolved_array = param()  # returns probe_array
    """

    output_shape: tuple = eqx.field(static=True)
    _data: Shaped[Array, "..."]
    _scale: float

    def __init__(
        self,
        initial_data: Float[Array, "..."],
        scale: float = 1.0,
        scaling_function: Optional[Callable[[Shaped[Array, "..."]], Shaped[Array, "..."]]] = None,
    ) -> None:
        """Initialize the normalized parametrization.

        Args:
            initial_data: The array to parametrize.
            scale: Fixed scale factor to normalize by. Ignored if
                ``scaling_function`` is provided.
            scaling_function: Optional callable that computes the scale factor
                from ``initial_data``. Takes precedence over ``scale``.
        """
        if scaling_function is not None:
            self._scale = scaling_function(initial_data)
        else:
            self._scale = scale
        self._data = initial_data / self._scale
        self.output_shape = initial_data.shape

    def __call__(self) -> Shaped[Array, "..."]:
        """Return the reconstructed array (data multiplied by scale).

        Returns:
            The denormalized array.
        """
        return self._data * self._scale

    def __array__(self, dtype: Any = None, copy: Optional[bool] = None) -> np.ndarray:  # noqa: ANN401
        """Support conversion via ``np.asarray(param)``.

        NumPy requires ``__array__`` to return a ``numpy.ndarray``. Returning the
        JAX array directly makes the conversion fail, so the denormalized
        output is converted here. ``copy`` follows the NumPy 2 protocol.

        Args:
            dtype: Optional dtype for the returned array.
            copy: If true, always return a fresh copy.

        Returns:
            The denormalized array as a NumPy array.
        """
        output = self._data * self._scale
        if copy:
            return np.array(output, dtype=dtype, copy=True)
        return np.asarray(output, dtype=dtype)
[docs]
@gin.configurable
class NormalizedReferencedArrayParametrization(ArrayParametrization):
    r"""Parametrization that stores a trainable offset relative to a fixed
    reference.

    The output is computed as ``data * scale + reference_value``, where
    ``reference_value`` is the initial data (with stopped gradients) and
    ``data`` starts at zero. This allows optimization to learn a correction
    on top of an initial estimate without modifying the reference.

    Attributes:
        output_shape: Shape of the output array.
        _data: Trainable offset initialized to zeros.
        _scale: Scale factor applied to the trainable offset (its gradient is
            stopped on evaluation).
        _reference_value: Fixed reference array (its gradient is stopped on
            evaluation).

    Example:
        >>> param = NormalizedReferencedArrayParametrization(initial_probe)
        >>> # Initially returns initial_probe (since _data is zero)
        >>> assert jnp.allclose(param(), initial_probe)
    """

    output_shape: tuple = eqx.field(static=True)
    _data: Shaped[Array, "..."]
    _scale: float
    _reference_value: Shaped[Array, "..."]

    def __init__(
        self,
        initial_data: Shaped[Array, "..."],
        scale: float = 1.0,
        scaling_function: Optional[Callable[[Shaped[Array, "..."]], Shaped[Array, "..."]]] = None,
    ) -> None:
        """Initialize the normalized referenced parametrization.

        Args:
            initial_data: The initial reference array. Stored and used as the
                baseline output; its gradient is stopped on evaluation.
            scale: Scale factor for the trainable offset.
            scaling_function: Optional callable that computes the scale factor
                from ``initial_data``. Takes precedence over ``scale``.
        """
        if scaling_function is not None:
            self._scale = scaling_function(initial_data)
        else:
            self._scale = scale
        self._reference_value = initial_data
        self._data = jnp.zeros_like(initial_data)
        self.output_shape = initial_data.shape

    def __call__(self) -> Shaped[Array, "..."]:
        """Return the parametrized array as ``data * scale + reference``.

        Gradients of the scale and of the reference value are stopped, so only
        the offset ``_data`` is trained.

        Returns:
            The sum of the scaled trainable offset and the fixed reference.
        """
        # Stop gradient prevents training of the reference value, but can also block
        # the gradient if new parametrization are created based on previously trained
        # arrays. Only use ArrayParametrizations on the model parameters, never on
        # intermediate parameters!
        def broadcast_to_match(a: Array, b: Array) -> Array:
            # A per-entry scale of shape (N,) gets trailing singleton axes so it
            # multiplies along the leading axis of ``a`` (e.g. (N, H, W)).
            # Scalars and other ranks are returned unchanged.
            if b.ndim == 1 and a.ndim > 1:
                return b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
            return b

        scale = broadcast_to_match(self._data, jax.lax.stop_gradient(jnp.array(self._scale)))
        return self._data * scale + jax.lax.stop_gradient(self._reference_value)
[docs]
class PhaseOnlyArrayParametrization(ArrayParametrization):
    r"""Parametrization that constrains an array to unit magnitude (phase-only).

    Stores a real-valued phase array and produces a complex array with unit
    magnitude via :math:`e^{i \phi}`. This is useful for representing phase
    screens or phase-only optical elements.

    Attributes:
        output_shape: Shape of the output complex array.
        phase: Trainable real-valued phase array in radians, of shape
            ``output_shape``.

    Example:
        >>> param = PhaseOnlyArrayParametrization(output_shape=(128, 128))
        >>> field = param()  # complex array with |field| == 1 everywhere
    """

    output_shape: tuple = eqx.field(static=True)
    phase: Float[Array, "..."]

    def __init__(
        self,
        output_shape: tuple,
        phase_initializer: Callable[[tuple], Float[Array, "..."]] = jnp.zeros,
    ) -> None:
        """Initialize the phase-only parametrization.

        Args:
            output_shape: Shape of the output complex array.
            phase_initializer: Callable that takes a shape tuple and returns
                the initial phase values. Defaults to zeros.
        """
        self.phase = phase_initializer(output_shape)
        self.output_shape = output_shape

    def __call__(self, **kwargs) -> Complex[Array, "..."]:
        r"""Compute the unit-magnitude complex array from the stored phase.

        Keyword arguments are accepted for interface compatibility and ignored.

        Returns:
            Complex array :math:`e^{i \phi}` with unit magnitude and the shape
            of ``phase``.
        """
        return phase_only_exp(self.phase)
[docs]
@gin.configurable()
class OuterProductArrayParametrization(ArrayParametrization):
    r"""Parametrization that represents an array as a sum of outer products.

    Decomposes an array of shape ``(..., M, N, d)`` into ``n_outer`` rank-1
    components per trailing channel. They are stored as column vectors of shape
    ``(..., M, n_outer, d)`` and row vectors of shape ``(n_outer, N, d)``
    (preceded by any leading axes of ``output_shape`` other than ``M``). The
    output is reconstructed via Einstein summation:
    :math:`A_{MNd} = \sum_s c_{Msd} \cdot r_{sNd}`.

    Per channel, this reduces the number of free parameters from
    :math:`M \times N` to :math:`n_{outer} \times (M + N)`, which can act as a
    regularizer.

    Attributes:
        output_shape: Shape of the reconstructed output array.
        column_vector: Column factors of shape ``(..., M, n_outer, d)``.
        row_vector: Row factors of shape ``(..., n_outer, N, d)``.

    Example:
        >>> param = OuterProductArrayParametrization(output_shape=(64, 64, 1), n_outer=4)
        >>> array = param()  # shape (64, 64, 1)
    """

    # Declared static, as in the base class and the other subclasses; a plain
    # ``tuple`` annotation would make the shape's integers ordinary pytree leaves.
    output_shape: tuple = eqx.field(static=True)
    column_vector: Shaped[Array, "... M s d"]
    row_vector: Shaped[Array, "... s N d"]

    def __init__(
        self,
        output_shape: tuple,
        n_outer: int,
        initializer: Optional[Callable[[tuple], Array]] = jnp.ones,
    ) -> None:
        """Initialize the outer product parametrization.

        Args:
            output_shape: Desired shape of the output array. Must have at least
                3 dimensions ``(..., M, N, d)``. Any sequence is accepted and
                stored as a tuple.
            n_outer: Number of rank-1 outer-product components.
            initializer: Callable that takes a shape tuple and returns an
                initial array. ``None`` falls back to ``jnp.ones``.

        Raises:
            ValueError: If ``output_shape`` has fewer than 3 dimensions.
        """
        # Tuple conversion lets the slicing/concatenation below work for lists too.
        output_shape = tuple(output_shape)
        if len(output_shape) < 3:
            raise ValueError(
                "Output shape must have at least 3 dimensions to use OuterProductArrayParametrization, "
                f"got {output_shape}. The last dimension may be 1."
            )
        if initializer is None:
            initializer = jnp.ones
        self.output_shape = output_shape
        # (..., M) + (n_outer,) + (d,)
        self.column_vector = initializer(output_shape[:-2] + (n_outer,) + output_shape[-1:])
        # (...,) + (n_outer,) + (N, d)
        self.row_vector = initializer(output_shape[:-3] + (n_outer,) + output_shape[-2:])

    def __call__(self) -> Shaped[Array, "... M N d"]:
        """Reconstruct the full array from the outer product of column and row
        vectors.

        Returns:
            The reconstructed array of shape ``output_shape``.
        """
        return jnp.einsum(
            "...Msd, ...sNd -> ...MNd",
            self.column_vector,
            self.row_vector,
            optimize="optimal",
        )
[docs]
@gin.configurable()
class WaveletArrayParametrization(ArrayParametrization):
    """Parametrization that stores an array in the wavelet domain.

    Applies a 2D Haar wavelet decomposition (level 1, via ``jwt.wavedec2``) to
    the initial data and stores the resulting coefficients as the trainable
    parameters. On evaluation, ``jwt.waverec2`` reconstructs the
    spatial-domain array.

    Operating in the wavelet domain can provide a multi-scale representation
    that encourages sparsity in the wavelet basis.

    Attributes:
        output_shape: Shape of the initial data.
        wavelet_coefficients: Trainable wavelet coefficients (Haar, level 1),
            as returned by ``jwt.wavedec2``.

    Example:
        >>> param = WaveletArrayParametrization(initial_array)
        >>> reconstructed = param()  # spatial-domain array
    """

    # Stored as a generic pytree: ``jwt.wavedec2`` returns a collection of
    # coefficient arrays rather than a single array.
    wavelet_coefficients: PyTree

    def __init__(self, data_or_initializer: Union[Array, Callable[[], Array]]) -> None:
        """Initialize the wavelet parametrization.

        Args:
            data_or_initializer: Either a JAX array to decompose, or a
                zero-argument callable that returns one.
        """
        if callable(data_or_initializer):
            initial_data = data_or_initializer()
        else:
            initial_data = data_or_initializer
        self.output_shape = initial_data.shape
        self.wavelet_coefficients = jwt.wavedec2(initial_data, "haar", level=1)

    def __call__(self, *args, **kwargs) -> Shaped[Array, "..."]:
        """Reconstruct the spatial-domain array from wavelet coefficients.

        Positional and keyword arguments are accepted and ignored.

        Returns:
            The inverse-wavelet-transformed array.
        """
        # NOTE(review): ``[0]`` takes the first entry along the leading axis of
        # jwt.waverec2's output (presumably a batch axis added by jaxwt). Confirm
        # against the jaxwt version in use, and that the result matches
        # ``output_shape`` (e.g. for odd spatial sizes).
        return jwt.waverec2(self.wavelet_coefficients, "haar")[0]
[docs]
@gin.configurable()
class DirectArrayParametrization(ArrayParametrization):
    """Identity parametrization that wraps an array without any transformation.

    This is a no-op parametrization: calling it simply returns the stored data
    unchanged. It is useful for providing a uniform
    :py:class:`~ptyrax.parametrizations.ArrayParametrization` interface around
    arrays that require no constraints.

    Attributes:
        _data: The stored array.

    Example:
        >>> param = DirectArrayParametrization(jnp.ones((64, 64)))
        >>> assert jnp.array_equal(param(), jnp.ones((64, 64)))
    """

    _data: Float[Array, "..."]

    def __init__(self, data_or_initializer: Union[Array, Callable[[], Array]]) -> None:
        """Initialize the direct parametrization.

        Args:
            data_or_initializer: Either a JAX array to store directly, or a
                zero-argument callable that returns one.
        """
        if callable(data_or_initializer):
            self._data = data_or_initializer()
        else:
            self._data = data_or_initializer

    def __call__(self) -> Float[Array, "..."]:
        """Return the stored array unchanged.

        Returns:
            The wrapped array.
        """
        return self._data

    @property
    def output_shape(self) -> tuple[int, ...]:
        """The shape of the stored array (derived from ``_data``)."""
        return self._data.shape
[docs]
class BinaryArrayParametrization(ArrayParametrization):
    """Parametrization that produces a binary array via thresholding.

    At construction time the data is rescaled to ``[0, 1]``. Complex input is
    first replaced by the phase of the mean-subtracted data. ``__call__`` then
    binarizes the stored values against ``threshold`` with ``binarize``.

    Attributes:
        _data: Stored data, rescaled to ``[0, 1]``.
        threshold: Binarization threshold (static field).
    """

    _data: Float[Array, "..."]
    threshold: float = eqx.field(static=True)

    def __init__(self, data_or_initializer: Union[Array, Callable[[], Array]], threshold: float = 0.5) -> None:
        """Initialize the binary array parametrization.

        Args:
            data_or_initializer: Either a JAX array to store, or a zero-argument
                callable that returns one.
            threshold: Threshold value for binarization. Output will be 1 where
                data > threshold and 0 elsewhere.
        """
        if callable(data_or_initializer):
            data = data_or_initializer()
        else:
            data = data_or_initializer
        # ``iscomplexobj`` checks the dtype. ``jnp.iscomplex`` is element-wise and
        # raises when used as a condition on arrays with more than one element.
        if jnp.iscomplexobj(data):
            data = jnp.angle(data - jnp.mean(data))
        lo = jnp.min(data)
        span = jnp.max(data) - lo
        # Normalize to [0, 1]. For constant data the numerator is zero everywhere,
        # so dividing by 1 instead of 0 yields zeros rather than NaN.
        data = (data - lo) / jnp.where(span > 0, span, 1.0)
        self._data = data
        self.output_shape = data.shape
        self.threshold = threshold

    def __call__(self) -> Bool[Array, "..."]:
        """Return a binary array obtained by thresholding the stored data.

        Returns:
            Binary array where values are 1 if data > threshold and 0 otherwise.
        """
        return binarize(self._data, self.threshold)
[docs]
def as_direct_array_parametrization(
    data: Union[Shaped[Array, "..."], ArrayParametrization],
) -> ArrayParametrization:
    """Wrap an array in a
    :py:class:`~ptyrax.parametrizations.DirectArrayParametrization` if needed.

    If ``data`` is already an
    :py:class:`~ptyrax.parametrizations.ArrayParametrization` (of any
    subclass), it is returned unchanged. Otherwise, it is wrapped in a
    :py:class:`~ptyrax.parametrizations.DirectArrayParametrization`.

    Args:
        data: An array or existing parametrization.

    Returns:
        An ``ArrayParametrization`` wrapping the input data. It is the input
        itself when that already was a parametrization.
    """
    if isinstance(data, ArrayParametrization):
        return data
    return DirectArrayParametrization(data)
# NOTE(review): redeclares ``T``, which is already defined earlier in this file.
# Harmless (an equivalent TypeVar named "T"), but redundant.
T = TypeVar("T")
[docs]
class IndexDependentParameter(eqx.Module, Generic[T]):
    """Abstract base class for parameters that vary with the dataset index.

    In ptychography, certain model parameters (e.g., scan positions, per-frame
    aberrations) differ for each diffraction pattern in the dataset. This class
    provides a uniform interface to access index-specific values either
    explicitly via :meth:`at_index` or, once an index has been bound, via
    :meth:`at_current_index`.

    Subclasses must implement :meth:`at_index`, :attr:`n`, and :attr:`all`.

    Attributes:
        _index: The currently bound index, or ``None`` if not yet resolved.
    """

    _index: int | None = eqx.field(static=True, default=None)

    @abstractmethod
    def at_index(self, index: int) -> T:
        """Return the parameter value at a specific dataset index.

        Args:
            index: The dataset index to retrieve.

        Returns:
            The parameter value for the given index.
        """

    def at(self, index: int) -> T:
        """Alias for :meth:`at_index`."""
        return self.at_index(index)

    def at_current_index(self) -> T:
        """Return the parameter value at the bound index ``_index``.

        Returns:
            ``self.at_index(self._index)``.

        Raises:
            ValueError: If no index has been bound (``_index`` is ``None``).
        """
        if self._index is None:
            raise ValueError(
                "IndexDependentParameter has not been resolved yet. "
                "Please use the 'resolve_parametrizations' function to resolve all parametrizations "
                "before calling the module inside the jit-boundary, i.e. `resolved_module = "
                "resolve_parametrizations(module, index); resolved_module(...)`. "
                "When calling outside the jit-boundary, you can use the 'at_index' method to get the parametrization "
                "at a specific index, or use the 'all' property to get all parametrizations."
            )
        return self.at_index(self._index)

    @property
    @abstractmethod
    def n(self) -> int:
        """The number of distinct index values this parameter supports."""

    @property
    @abstractmethod
    def all(self) -> T:
        """Return all parameter values across all indices."""
[docs]
class IndexSliceParameter(IndexDependentParameter[T]):
    """Index-dependent parameter that selects a slice along a given dimension.

    Stores a pytree of arrays where one dimension corresponds to the dataset
    index. Calling :meth:`at_index` slices along that dimension to extract
    the parameter for a single index.

    Attributes:
        parameters: The full pytree of parameters (with an index dimension).
        slice_dim: Static pytree matching ``parameters``. Each entry is the
            (negative) axis to slice for the corresponding array leaf, or
            ``None`` for non-array leaves.

    Example:
        >>> positions = jnp.zeros((100, 2))  # 100 scan positions, 2D
        >>> param = IndexSliceParameter(positions, dim=0)
        >>> param.at_index(5)  # shape (2,)
    """

    parameters: PyTree
    # slice_dim has same treedef as parameters
    slice_dim: PyTree = eqx.field(static=True)

    def __init__(self, parameters: PyTree, dim: int = 0) -> None:
        """Initialize the index-slice parameter.

        Args:
            parameters: A module, or array-like data (anything ``jnp.array``
                accepts, e.g. a NumPy array), with one dimension representing
                the dataset index.
            dim: The dimension along which to slice. A non-negative value is
                interpreted relative to the start of each array's shape and
                converted to a negative index internally, for robustness to
                leading-dimension changes. A negative value is used as is.

        Raises:
            ValueError: If ``parameters`` is a module that does not support
                indexing.
        """

        def get_slice_dim(leaf: Any) -> int | None:  # noqa: ANN401
            if not isinstance(leaf, Array):
                return None
            # Count the axis from the end. Due to other model transformations, the
            # parameters' shape may change at runtime; as long as only leading
            # dimensions are added, a negative axis still points at the index axis.
            return dim if dim < 0 else dim - leaf.ndim

        if isinstance(parameters, eqx.Module):
            resolved_parameters = resolve_array_parametrizations(parameters)
            if not hasattr(resolved_parameters, "__getitem__"):
                raise ValueError(
                    f"""The provided module {type(parameters)} does not support indexing.
All IndexSliceParameters must support indexing. """
                )
            self.slice_dim = jax.tree.map(get_slice_dim, resolved_parameters)
            self.parameters = parameters
            return

        # Convert first and derive slice_dim from the converted JAX array. A
        # NumPy array or list is not a jax.Array, so it would otherwise get
        # slice_dim=None and at_index would silently return every index.
        converted = jnp.array(parameters)
        self.slice_dim = get_slice_dim(converted)
        self.parameters = converted

    def at_index(self, index: int) -> Shaped[Array, "..."]:
        """Slice the parameters at the given dataset index.

        Args:
            index: Dataset index to select.

        Returns:
            The parameter pytree with the index dimension removed. Leaves whose
            ``slice_dim`` entry is ``None`` are returned unchanged.
        """

        def slice_at_dim(leaf: Shaped[Array, "..."], dim: int | None) -> Shaped[Array, "..."]:
            if dim is None:
                return leaf
            full_slice = [slice(None)] * len(leaf.shape)
            full_slice[dim] = index
            return leaf[tuple(full_slice)]

        return jax.tree.map(slice_at_dim, self.parameters, self.slice_dim)

    @property
    def all(self) -> Shaped[Array, "indices ..."]:
        """The full parameter pytree including the index dimension."""
        return self.parameters

    @property
    def n(self) -> int:
        """The number of indices, i.e. the size of the sliced dimension.

        Read from the axis recorded in ``slice_dim`` (not blindly axis 0), so
        it stays correct for ``dim != 0`` and after leading axes are added.

        Raises:
            ValueError: If the parameters contain no array leaf to index.
        """
        resolved = resolve_array_parametrizations(self.parameters)
        # None results (non-array leaves) are dropped by jax.tree.leaves.
        sizes = jax.tree.leaves(
            jax.tree.map(
                lambda leaf, dim: None if dim is None else leaf.shape[dim],
                resolved,
                self.slice_dim,
            )
        )
        if not sizes:
            raise ValueError("IndexSliceParameter has no array leaves to determine the number of indices from.")
        return sizes[0]

    def __getitem__(self, key: Any) -> Any:  # noqa: ANN401
        raise ValueError(
            "Using __getitem__ on IndexSliceParametrization. "
            "It is likely that the module you are trying to access has not resolved parametrizations yet. "
            "Please use the 'resolve_parametrizations' function to resolve all parametrizations before "
            "calling the module, "
            "i.e. 'resolved_module = resolve_parametrizations(module, index); resolved_module(...)'. "
            "Alternatively, when operating outside the jit-boundary, you can access all parametrizations "
            "via the 'all' property."
        )

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        # Only reached after normal attribute lookup failed; re-raise with guidance.
        try:
            return super().__getattribute__(name)
        except (AttributeError, TypeError) as e:
            raise AttributeError(
                f"""'IndexSliceParametrization' object has no attribute '{name}'.
It is likely that the module you are trying to access has not resolved parametrizations yet.
Please use the 'resolve_parametrizations' function to resolve all parametrizations before
calling the module,
i.e. 'resolved_module = resolve_parametrizations(module, index); resolved_module(...)'.
Alternatively, when operating outside the jit-boundary, you can access all parametrizations
via the 'all' property, i.e. 'all_parametrizations = module.foo.all'.
"""
            ) from e
[docs]
class BoundSliceParameter(eqx.Module):
    """A resolved index-dependent parameter bound to a single index.

    Created by :py:func:`~ptyrax.parametrizations.resolve_index_dependent_parameters`
    when an :py:class:`~ptyrax.parametrizations.IndexDependentParameter` is resolved
    at a specific index. Use :meth:`at_current_index` to access the bound value.

    Attributes:
        parameter: The resolved parameter value for the bound index.
        n: The total number of indices the original parameter supported.
    """

    parameter: PyTree
    n: int = eqx.field(static=True)

    def __init__(self, parameter: PyTree, n: int) -> None:
        """Initialize the bound slice parameter.

        Args:
            parameter: The resolved parameter value for a single index.
            n: Total number of indices in the original parameter.
        """
        self.parameter = parameter
        self.n = n

    def at_index(self, *args) -> T:
        """Always raise: a bound parameter no longer supports index lookup.

        Raises:
            ValueError: Always; use :meth:`at_current_index` instead.
        """
        # One concatenated message. A stray comma previously made the exception
        # args a tuple of two strings.
        raise ValueError(
            "Using at_index on BoundSliceParameter. "
            "Likely cause is attempting to use `at_index` while the model IndexableParameter has already been "
            "resolved. If this is the case, simply use the `at_current_index` method instead."
        )

    def at_current_index(self) -> T:
        """Return the parameter value for the bound index."""
        return self.parameter

    @property
    def all(self) -> T:
        """Always raise: a bound parameter only holds a single index.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            "BoundSliceParameter does not support accessing all parameters. "
            "The model has already been resolved, it is bound to a single index. "
            "Note: write your model without an index dimension, using the 'at_current_index' method to access "
            "any parameter which is a function of the dataset index."
        )
[docs]
def resolve_index_dependent_parameters(module: T, index: int) -> T:
    """Bind every index-dependent parameter in ``module`` to one dataset index.

    Each :py:class:`~ptyrax.parametrizations.IndexDependentParameter` in the
    pytree is first stripped of its array parametrizations. It is then replaced
    by a :py:class:`~ptyrax.parametrizations.BoundSliceParameter` that holds its
    value at ``index`` together with its total number of indices ``n``.

    Args:
        module: The module (or pytree) containing index-dependent parameters.
        index: The dataset index to resolve at.

    Returns:
        A copy of the module with all index-dependent parameters bound to
        the given index.
    """

    def _is_index_dependent(node: Any) -> bool:  # noqa: ANN401
        return isinstance(node, IndexDependentParameter)

    def _bind(node: Any) -> Any:  # noqa: ANN401
        if not _is_index_dependent(node):
            return node
        # Evaluate ArrayParametrizations inside the parameter so that slicing
        # and size lookups operate on plain arrays.
        evaluated = resolve_array_parametrizations(node)
        return BoundSliceParameter(evaluated.at(index), n=evaluated.n)

    return jax.tree.map(_bind, module, is_leaf=_is_index_dependent)
[docs]
def resolve_index_dependent_parameters_all(module: T) -> T:
    """Expand every index-dependent parameter in ``module`` to all indices.

    Each :py:class:`~ptyrax.parametrizations.IndexDependentParameter` in the
    pytree is first stripped of its array parametrizations and then replaced by
    its :attr:`all` value, i.e. the parameter for every index at once (for
    :py:class:`~ptyrax.parametrizations.IndexSliceParameter`, the full
    parameters including the index axis).

    Args:
        module: The module (or pytree) containing index-dependent parameters.

    Returns:
        A copy of the module with all index-dependent parameters replaced by
        their full (all-index) values.
    """

    def _is_index_dependent(node: Any) -> bool:  # noqa: ANN401
        return isinstance(node, IndexDependentParameter)

    def _expand(node: Any) -> Any:  # noqa: ANN401
        if not _is_index_dependent(node):
            return node
        return resolve_array_parametrizations(node).all

    return jax.tree.map(_expand, module, is_leaf=_is_index_dependent)