from typing import Callable, Literal, Self
import equinox as eqx
import gin
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from jax.scipy.ndimage import map_coordinates
from jax.scipy.spatial.transform import Rotation as JaxRotation
from jaxtyping import Array, ArrayLike, Bool, Complex, Float, Shaped
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.gridspec import SubplotSpec
from matplotlib.lines import Line2D
from ptyrax.parametrizations import (
ArrayParametrization,
resolve_array_parametrizations,
)
from ptyrax.utils import make_or_reuse_axes
[docs]
def matrix_to_six_dimensional_representation(rotation_matrix: Float[Array, "... 3 3"]) -> Float[Array, "... 6"]:
    """Flatten the first two rows of a rotation matrix into a 6D vector.

    This is the continuous rotation representation of Zhou et al.
    (NeurIPS 2019), which is well suited to gradient-based optimization.
    The third row is dropped because it can be rebuilt from the other two.

    Args:
        rotation_matrix: A shape ``(..., 3, 3)`` rotation matrix.

    Returns:
        A shape ``(..., 6)`` array: row 0 of the matrix followed by row 1.

    See Also:
        :py:func:`~ptyrax.spatial.six_dimensional_representation_to_matrix`:
            The inverse operation.
    """
    # TODO convert to extension of scipy.spatial.Rotation class
    matrices = jnp.atleast_1d(rotation_matrix)
    batch_shape = matrices.shape[:-2]
    # Collapse all leading batch axes so the row selection is a single slice.
    stacked = matrices.reshape((-1, 3, 3))
    first_two_rows = stacked[:, :2, :]
    # Row-major flattening of (2, 3) yields [row 0, row 1]; restore the batch axes.
    return first_two_rows.reshape(batch_shape + (6,))
[docs]
def six_dimensional_representation_to_matrix(representation: Float[Array, "... 6"]) -> Float[Array, "... 3 3"]:
    """Rebuild a 3x3 rotation matrix from its 6D representation.

    Based on:
        "On the continuity of rotation representations in neural networks"
        Yi Zhou, Connelly Barnes, Jingwan Lu, Jimei Yang, Hao Li.
        Conference on Neural Information Processing Systems (NeurIPS) 2019.

    The first three components are normalised to give the first row. The
    third row is the normalised cross product of that row with the last
    three components, and the second row is the cross product of the third
    and first rows.

    Args:
        representation: A shape ``(..., 6)`` array representing the rotation as a 6D vector.

    Returns:
        A shape ``(..., 3, 3)`` array whose rows are the three vectors above.
    """
    vector = jnp.atleast_1d(representation)
    first, second = vector[..., 0:3], vector[..., 3:6]

    def _unit(v):  # noqa: ANN001, ANN202
        # Normalise along the component axis, keeping batch axes intact.
        return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

    row_x = _unit(first)
    row_z = _unit(jnp.cross(row_x, second))
    # row_x and row_z are orthonormal, so this cross product needs no normalisation.
    row_y = jnp.cross(row_z, row_x)
    return jnp.stack((row_x, row_y, row_z), axis=-2)
[docs]
class Rotation(eqx.Module):
    """An equinox module wrapping a rotation using the 6D continuous
    representation.

    Stores rotations internally as a 6D vector (the first two rows of the rotation
    matrix), which provides a continuous, singularity-free parameterization suitable
    for gradient-based optimization. Based on Zhou et al. (NeurIPS 2019).

    Supports batched rotations via leading batch dimensions, composition via ``*``,
    relative rotation via ``/``, inversion via the :py:attr:`inv` property, and
    indexing via ``[]``. Binary operators broadcast over the batch dimensions, so
    they work for unbatched rotations as well as for any number of batch axes.

    Example:
        >>> import jax.numpy as jnp
        >>> from ptyrax.spatial import Rotation
        >>> rot = Rotation.from_matrix(jnp.eye(3))
        >>> rot.as_matrix()  # Returns identity matrix
    """

    _representation_6d: Float[Array, "* 6"] = eqx.field(
        default_factory=lambda: matrix_to_six_dimensional_representation(jnp.eye(3))
    )

    @classmethod
    def from_matrix(cls, matrix: Float[Array, "* 3 3"]) -> Self:
        """Construct a rotation from a 3x3 rotation matrix.

        Args:
            matrix: A shape ``(*, 3, 3)`` rotation matrix.

        Returns:
            A new :py:class:`~ptyrax.spatial.Rotation` instance.
        """
        return cls(matrix_to_six_dimensional_representation(matrix))

    def as_matrix(self) -> Float[Array, "* 3 3"]:
        """Convert this rotation to a 3x3 rotation matrix.

        Returns:
            A shape ``(*, 3, 3)`` orthonormal rotation matrix.
        """
        return six_dimensional_representation_to_matrix(self._representation_6d)

    def as_scipy_rotation(self) -> JaxRotation:
        """Convert this rotation to a JAX scipy ``Rotation`` object.

        Returns:
            A ``jax.scipy.spatial.transform.Rotation`` instance.
        """
        # Use the name imported at module level instead of relying on the
        # ``jax.scipy.spatial`` submodule being reachable through ``jax``.
        return JaxRotation.from_matrix(self.as_matrix())

    @property
    def shape(self) -> tuple[int, ...]:
        """Batch shape of the rotation."""
        return self._representation_6d.shape[:-1]

    @property
    def inv(self) -> Self:
        """The inverse rotation (transpose of the rotation matrix)."""
        return self.from_matrix(jnp.swapaxes(self.as_matrix(), -1, -2))  # Transpose = inverse

    def __mul__(self, other: Self) -> Self:
        """Compose two rotations: the matrix product ``self @ other``."""
        # ``@`` broadcasts over leading axes, so no vmap is needed; vmap would
        # also fail for unbatched rotations whose only axis has length 6.
        return self.from_matrix(self.as_matrix() @ other.as_matrix())

    def __truediv__(self, other: Self) -> Self:
        """Relative rotation ``self @ other^-1`` (the ``/`` operator)."""
        # swapaxes transposes only the two matrix axes; ``.T`` would reverse
        # every axis of a batched array.
        return self.from_matrix(self.as_matrix() @ jnp.swapaxes(other.as_matrix(), -1, -2))

    def __div__(self, other: Self) -> Self:
        """Legacy alias of :py:meth:`__truediv__`, kept for explicit callers."""
        return self.__truediv__(other)

    def __rdiv__(self, other: Self) -> Self:
        """Left division ``self^-1 @ other``.

        Kept under its original name for explicit callers. Python 3 never
        dispatches an operator to this method.
        """
        return self.from_matrix(jnp.swapaxes(self.as_matrix(), -1, -2) @ other.as_matrix())

    def __getitem__(self, item: slice) -> Self:
        """Index into the batch dimensions and return the selected rotation(s)."""
        return Rotation(self._representation_6d[item])

    def __str__(self) -> str:
        return f"representation: {jnp.array_str(self._representation_6d)}"
[docs]
class CoordinateSystem(eqx.Module):
    """A coordinate system defined by a translation and rotation in CXI
    coordinates.

    Represents a rigid-body transformation (rotation + translation) of a local
    coordinate frame relative to the global CXI frame. The CXI convention defines
    z along the incoming beam direction and y vertically.

    The rotation matrix maps global components to local components
    (``local = R @ global``), so its rows are the local axes expressed in the
    global frame.

    The translation can be stored in either the global frame or the local frame,
    controlled by ``model_in_local_frame``. When ``model_in_local_frame=True``,
    the internal translation is stored in the local (rotated) frame and converted
    to global coordinates on access via the :py:attr:`translation` property.

    Attributes:
        rotation: The :py:class:`~ptyrax.spatial.Rotation` of this coordinate system.
        model_in_local_frame: If ``True``, the internal translation is stored in the
            rotated local frame rather than the global frame.

    Example:
        >>> import jax.numpy as jnp
        >>> from ptyrax.spatial import CoordinateSystem, Rotation
        >>> cs = CoordinateSystem(
        ...     rotation=Rotation.from_matrix(jnp.eye(3)),
        ...     translation=jnp.array([1.0, 0.0, 0.0]),
        ... )
        >>> cs.translation
    """

    rotation: Rotation = eqx.field(default_factory=lambda: Rotation.from_matrix(jnp.eye(3)))
    _translation: Float[Array, "* 3"] = eqx.field(default_factory=lambda: jnp.array([0.0, 0.0, 0.0]))
    model_in_local_frame: bool = eqx.field(static=True, default=False)

    def __init__(
        self,
        rotation: Rotation = Rotation.from_matrix(jnp.eye(3)),
        translation: Float[Array, "* 3"] = jnp.array([0.0, 0.0, 0.0]),
        model_in_local_frame: bool = False,
        **kwargs,
    ) -> None:
        """Build a coordinate system from a rotation and a global-frame translation.

        Args:
            rotation: Orientation of the local frame.
            translation: Translation expressed in the global frame. It is
                converted to the local frame for storage when
                ``model_in_local_frame`` is ``True``.
            model_in_local_frame: Whether to store the translation in the local frame.
            **kwargs: Accepted and ignored.
        """
        if model_in_local_frame:
            # This is called outside of the jit-boundary, so we should resolve parametrizations here
            translation_resolved = resolve_array_parametrizations(translation)
            rotation_resolved = resolve_array_parametrizations(rotation)
            translation_local = jnp.einsum(
                "...ij, ...j -> ...i",
                rotation_resolved.as_matrix(),
                translation_resolved,
            )
            if isinstance(translation, ArrayParametrization):
                # keep the same type of parametrization
                translation_local = type(translation)(translation_local)
            translation = translation_local
        self._translation = translation
        self.rotation = rotation
        self.model_in_local_frame = model_in_local_frame

    @property
    def translation(self) -> Float[Array, "* 3"]:
        """Translation vector in the global CXI frame.

        When ``model_in_local_frame=True``, the stored local-frame translation
        is rotated into the global frame before being returned.
        """
        if not self.model_in_local_frame:
            return self._translation
        # Translate from local to global
        return jnp.einsum("...ij, ...j -> ...i", self.rotation.inv.as_matrix(), self._translation)

    @property
    def translation_internal(self) -> Float[Array, "* 3"]:
        """Translation vector in the local (rotated) frame.

        When ``model_in_local_frame=False``, the stored global-frame translation
        is rotated into the local frame before being returned.
        """
        if self.model_in_local_frame:
            return self._translation
        # Translate from global to local
        return jnp.einsum("...ij, ...j -> ...i", self.rotation.as_matrix(), self.translation)

    # The axis accessors index with ``[..., k, :]`` so that batched and
    # unbatched systems select the same thing: row ``k`` of each matrix.
    # ``[:, k]`` picked a column for a single matrix but a row for a batch.

    @property
    def x_axis(self) -> Float[Array, "* 3"]:
        """Unit vector along the local x-axis, expressed in global
        coordinates."""
        return self.rotation.as_matrix()[..., 0, :]

    @property
    def y_axis(self) -> Float[Array, "* 3"]:
        """Unit vector along the local y-axis, expressed in global
        coordinates."""
        return self.rotation.as_matrix()[..., 1, :]

    @property
    def z_axis(self) -> Float[Array, "* 3"]:
        """Unit vector along the local z-axis, expressed in global
        coordinates."""
        return self.rotation.as_matrix()[..., 2, :]

    @property
    def shape(self) -> tuple[int, ...]:
        """Batch shape of the coordinate system."""
        return self._translation.shape[:-1]

    def __getitem__(self, item: slice) -> Self:
        """Index into the batch dimensions and return the selected system(s).

        The global-frame translation is passed to the constructor, which
        converts it back to the local frame when ``model_in_local_frame`` is
        set, so the storage convention of the parent is preserved.
        """
        return CoordinateSystem(
            rotation=self.rotation[item],
            translation=self.translation[item],
            model_in_local_frame=self.model_in_local_frame,
        )

    def __str__(self) -> str:
        # f-strings are required here: a plain literal prints the braces verbatim.
        return (
            f"rotation: {jnp.array_str(self.rotation.as_matrix())}\n"
            f"translation: {jnp.array_str(self.translation)}"
        )
[docs]
class SamplingGrid(eqx.Module):
    """A discrete sampling grid with pixel sizes and optional FFT-shift
    awareness.

    Represents a regular 2D or 3D grid centered at the origin, with coordinates
    following CXI convention: x along dimension 0, y along dimension 1, using
    numpy ``'ij'`` indexing.

    Grid coordinates range from ``-N//2`` to ``(N+1)//2 - 1`` along each axis,
    scaled by the corresponding pixel size. When ``fftshifted=True``, the
    coordinate arrays are reordered to match FFT output layout.

    ``origin_shift`` is only applied by :py:attr:`bounds` (and the properties
    derived from it); the per-axis coordinates and the meshgrid are not shifted.

    Attributes:
        shape: Grid dimensions as ``(n_x, n_y)`` or ``(n_x, n_y, n_z)``.
        pixel_size: Physical size of each pixel along each dimension, shape ``(*, d)``.
        origin_shift: Optional shift of the grid origin, last axis of length ``d``.
        fftshifted: If ``True``, coordinate arrays are fftshifted.

    Example:
        >>> import jax.numpy as jnp
        >>> from ptyrax.spatial import SamplingGrid
        >>> grid = SamplingGrid(shape=(64, 64), pixel_size=jnp.array([1e-6, 1e-6]))
        >>> grid.x  # 1D coordinates along x
        >>> grid.meshgrid  # Full 2D coordinate meshgrid
    """

    shape: tuple[int, ...] = eqx.field(static=True, converter=lambda x: tuple(x))
    pixel_size: Float[Array, "* d"] | np.ndarray
    origin_shift: Float[Array, " d"] | np.ndarray | None = eqx.field(default=None)
    fftshifted: bool = eqx.field(static=True, default=False)

    @classmethod
    def from_tuples(
        cls,
        shape: tuple[int | float, ...],
        pixel_size: tuple[float, ...],
        origin_shift: tuple[float, ...] | None = None,
        fftshifted: bool = False,
    ) -> "SamplingGrid":
        """Construct a :py:class:`~ptyrax.spatial.SamplingGrid` from plain
        tuples.

        Convenience constructor that converts tuple arguments to the appropriate
        array types.

        Args:
            shape: Grid dimensions, e.g. ``(64, 64)``.
            pixel_size: Physical pixel size per dimension, e.g. ``(1e-6, 1e-6)``.
            origin_shift: Optional origin offset per dimension.
            fftshifted: Whether the grid coordinates should be FFT-shifted.

        Returns:
            A new :py:class:`~ptyrax.spatial.SamplingGrid` instance.
        """
        return cls(
            shape=tuple(int(el) for el in shape),
            pixel_size=jnp.array(pixel_size),
            # Convert to an array as well: ``bounds`` indexes the shift with
            # array semantics, which a plain tuple does not support.
            origin_shift=None if origin_shift is None else jnp.array(origin_shift),
            fftshifted=fftshifted,
        )

    def __post_init__(self) -> None:
        """Validate that ``shape`` and ``pixel_size`` describe the same 2D/3D grid.

        Raises:
            ValueError: If the last axis of ``pixel_size`` does not have one
                entry per grid dimension, or the grid is not 2D or 3D.
        """
        # Only the last axis carries the per-dimension sizes; leading axes are
        # batch dimensions and must not count towards the dimensionality.
        pixel_dims = np.shape(self.pixel_size)[-1] if np.ndim(self.pixel_size) else 1
        if len(self.shape) != pixel_dims:
            raise ValueError(
                f"The sampling grid shape ({len(self.shape)}D) and pixel_size ({pixel_dims}D) "
                "must have the same length."
            )
        if len(self.shape) not in (2, 3):
            raise ValueError(f"The sampling grid shape must be 2D or 3D, got {len(self.shape)}D.")

    def __str__(self) -> str:
        return f"shape: {self.shape} \n pixel_size: {self.pixel_size}"

    def __getitem__(self, idx: int) -> "SamplingGrid":
        """Select one entry of a batched grid.

        ``pixel_size`` and, when present, ``origin_shift`` are indexed with
        ``idx``; ``shape`` and ``fftshifted`` are carried over unchanged.
        """
        # NOTE(review): origin_shift is indexed exactly as before, which
        # assumes it shares the batch axes of pixel_size — confirm.
        origin_shift = None if self.origin_shift is None else self.origin_shift[idx]
        return SamplingGrid(
            shape=self.shape,
            pixel_size=self.pixel_size[idx],
            origin_shift=origin_shift,
            fftshifted=self.fftshifted,
        )

    # ------------------------------------------------------------------
    # Per-axis helpers
    # ------------------------------------------------------------------

    def _n_axis(self, axis: int) -> int:
        """Number of grid points along a given axis."""
        if axis >= len(self.shape):
            raise ValueError(f"The grid is {self.ndim}D, axis {axis} is not defined")
        return self.shape[axis]

    def _n_axis_max(self, axis: int) -> int:
        """Maximum index (inclusive) in the centered grid for a given axis."""
        n = self._n_axis(axis)
        return (n + 1) // 2 - 1  # off-by-one so this is actually the largest value in arange

    def _n_axis_min(self, axis: int) -> int:
        """Minimum index in the centered grid for a given axis."""
        n = self._n_axis(axis)
        return -(n // 2)

    def _axis_pixel_size(self, axis: int) -> float:
        """Physical pixel size along a given axis."""
        if axis >= len(self.shape):
            raise ValueError(f"The grid is {self.ndim}D, axis {axis} pixel_size is not defined")
        return self.pixel_size[..., axis]

    def _get_coordinates_1d(self, axis: int) -> Float[Array, " n"]:
        """1D array of physical coordinates along a given axis."""
        coords = jnp.arange(self._n_axis_min(axis), self._n_axis_max(axis) + 1) * self._axis_pixel_size(axis)
        return jnp.fft.fftshift(coords) if self.fftshifted else coords

    @property
    def ndim(self) -> int:
        """Number of spatial dimensions (2 or 3)."""
        return len(self.shape)

    # ------------------------------------------------------------------
    # x-axis
    # ------------------------------------------------------------------

    @property
    def n_x(self) -> int:
        """Number of grid points along the x-axis."""
        return self.shape[0]

    @property
    def n_x_max(self) -> int:
        """Maximum x index (inclusive) in the centered grid."""
        return self._n_axis_max(0)

    @property
    def n_x_min(self) -> int:
        """Minimum x index in the centered grid."""
        return self._n_axis_min(0)

    @property
    def x_pixel_size(self) -> float:
        """Physical pixel size along the x-axis."""
        return self._axis_pixel_size(0)

    @property
    def x_max(self) -> float:
        """Maximum physical x-coordinate."""
        return self._n_axis_max(0) * self.x_pixel_size

    @property
    def x_min(self) -> float:
        """Minimum physical x-coordinate."""
        return self._n_axis_min(0) * self.x_pixel_size

    @property
    def x(self) -> Float[Array, "n_x"]:
        """1D array of physical x-coordinates."""
        return self._get_coordinates_1d(0)

    # ------------------------------------------------------------------
    # y-axis
    # ------------------------------------------------------------------

    @property
    def n_y(self) -> int:
        """Number of grid points along the y-axis."""
        return self.shape[1]

    @property
    def n_y_max(self) -> int:
        """Maximum y index (inclusive) in the centered grid."""
        return self._n_axis_max(1)

    @property
    def n_y_min(self) -> int:
        """Minimum y index in the centered grid."""
        return self._n_axis_min(1)

    @property
    def y_pixel_size(self) -> float:
        """Physical pixel size along the y-axis."""
        return self._axis_pixel_size(1)

    @property
    def y_max(self) -> float:
        """Maximum physical y-coordinate."""
        return self._n_axis_max(1) * self.y_pixel_size

    @property
    def y_min(self) -> float:
        """Minimum physical y-coordinate."""
        return self._n_axis_min(1) * self.y_pixel_size

    @property
    def y(self) -> Float[Array, "n_y"]:
        """1D array of physical y-coordinates."""
        return self._get_coordinates_1d(1)

    # ------------------------------------------------------------------
    # z-axis (3D grids only)
    # ------------------------------------------------------------------

    @property
    def n_z(self) -> int:
        """Number of grid points along the z-axis.

        Raises:
            ValueError: If the grid is 2D.
        """
        if self.ndim == 2:
            raise ValueError("The grid is 2D, N_z is not defined")
        return self.shape[2]

    @property
    def n_z_max(self) -> int:
        """Maximum z index (inclusive) in the centered grid."""
        return self._n_axis_max(2)

    @property
    def n_z_min(self) -> int:
        """Minimum z index in the centered grid."""
        return self._n_axis_min(2)

    @property
    def z_pixel_size(self) -> float:
        """Physical pixel size along the z-axis.

        Raises:
            ValueError: If the grid is 2D.
        """
        return self._axis_pixel_size(2)

    @property
    def z_max(self) -> float:
        """Maximum physical z-coordinate."""
        return self.n_z_max * self.z_pixel_size

    @property
    def z_min(self) -> float:
        """Minimum physical z-coordinate."""
        return self.n_z_min * self.z_pixel_size

    @property
    def z(self) -> Float[Array, "n_z"]:
        """1D array of physical z-coordinates."""
        return self._get_coordinates_1d(2)

    # ------------------------------------------------------------------
    # Meshgrids
    # ------------------------------------------------------------------

    @property
    def meshgrid(self) -> Float[Array, "d N_x N_y"] | Float[Array, "d N_x N_y N_z"]:
        """Full coordinate meshgrid with shape ``(d, N_x, N_y)`` or ``(d, N_x,
        N_y, N_z)``."""
        if self.ndim == 2:
            return jnp.stack(jnp.meshgrid(self.x, self.y, indexing="ij"), axis=0)
        return jnp.stack(jnp.meshgrid(self.x, self.y, self.z, indexing="ij"), axis=0)

    @property
    def xx(self) -> Float[Array, "N_x N_y"] | Float[Array, "N_x N_y N_z"]:
        """X-component of the coordinate meshgrid."""
        return self.meshgrid[0]

    @property
    def yy(self) -> Float[Array, "N_x N_y"] | Float[Array, "N_x N_y N_z"]:
        """Y-component of the coordinate meshgrid."""
        return self.meshgrid[1]

    @property
    def zz(self) -> Float[Array, "N_x N_y N_z"]:
        """Z-component of the coordinate meshgrid (3D grids only)."""
        return self.meshgrid[2]

    @property
    def rr(self) -> Float[Array, "N_x N_y"] | Float[Array, "N_x N_y N_z"]:
        """Radial distance from the origin at each grid point."""
        if self.ndim == 2:
            return jnp.sqrt(self.xx**2 + self.yy**2)
        return jnp.sqrt(self.xx**2 + self.yy**2 + self.zz**2)

    @property
    def rr2(self) -> Float[Array, "N_x N_y"] | Float[Array, "N_x N_y N_z"]:
        """Squared radial distance from the origin at each grid point."""
        if self.ndim == 2:
            return self.xx**2 + self.yy**2
        return self.xx**2 + self.yy**2 + self.zz**2

    # ------------------------------------------------------------------
    # Bounds
    # ------------------------------------------------------------------

    @property
    def x_bounds(self) -> Float[Array, "... 2"]:
        """Min and max physical x-coordinates, stacked along a new last axis."""
        return jnp.stack([self.x_min, self.x_max], axis=-1)

    @property
    def y_bounds(self) -> Float[Array, "... 2"]:
        """Min and max physical y-coordinates, stacked along a new last axis."""
        return jnp.stack([self.y_min, self.y_max], axis=-1)

    @property
    def z_bounds(self) -> Float[Array, "... 2"]:
        """Min and max physical z-coordinates, stacked along a new last axis."""
        return jnp.stack([self.z_min, self.z_max], axis=-1)

    @property
    def bounds(self) -> Float[Array, "... d 2"]:
        """Coordinate bounds as a ``(*, d, 2)`` array of ``[min, max]`` per
        axis.

        If ``origin_shift`` is set, it is added to the bounds.
        """
        if self.ndim == 2:
            bounds = jnp.stack([self.x_bounds, self.y_bounds], axis=-2)
        else:
            bounds = jnp.stack([self.x_bounds, self.y_bounds, self.z_bounds], axis=-2)
        if self.origin_shift is not None:
            # Append the [min, max] axis at the end so a (d,) shift becomes
            # (d, 1) and a batched (*, d) shift becomes (*, d, 1).
            bounds = bounds + jnp.asarray(self.origin_shift)[..., np.newaxis]
        return bounds

    @property
    def field_of_view(self) -> Float[Array, "... d 2"]:
        """Alias for :py:attr:`bounds`."""
        return self.bounds

    @property
    def extent(self) -> Float[Array, "... 2d"]:
        """Flattened bounds as a ``(*, 2d)`` array, e.g. ``(x_min, x_max,
        y_min, y_max)``."""
        bounds = self.bounds
        return bounds.reshape(bounds.shape[:-2] + (2 * self.ndim,))

    # ------------------------------------------------------------------
    # Geometry helpers for visualisation
    # ------------------------------------------------------------------

    def edges(self, n_per_edge: int = 100) -> Float[Array, "N 3"]:
        """Generate 3D coordinates along the four edges of the 2D grid
        boundary.

        Returns points tracing the rectangular boundary of the grid in the
        local x-y plane (z=0), useful for visualization. One extra point at
        the origin ``(0, 0, 0)`` is appended after the four edges.

        Args:
            n_per_edge: Number of sample points along each edge.

        Returns:
            An ``(4 * n_per_edge + 1, 3)`` array of coordinates with z=0.
        """
        x_edge = jnp.concatenate(
            [
                jnp.tile(jnp.array([self.x_min]), n_per_edge),  # Constant x as a function of y
                jnp.tile(jnp.array([self.x_max]), n_per_edge),
                jnp.linspace(self.x_min, self.x_max, n_per_edge),
                jnp.linspace(self.x_min, self.x_max, n_per_edge),
                jnp.array([0.0]),
            ]
        )
        y_edge = jnp.concatenate(
            [
                jnp.linspace(self.y_min, self.y_max, n_per_edge),
                jnp.linspace(self.y_min, self.y_max, n_per_edge),
                jnp.tile(jnp.array([self.y_min]), n_per_edge),
                jnp.tile(jnp.array([self.y_max]), n_per_edge),
                jnp.array([0.0]),
            ]
        )
        z_edge = jnp.zeros_like(x_edge)
        return jnp.stack([x_edge, y_edge, z_edge], axis=-1)

    def subgrid_coordinates(
        self,
        n_x: int = 20,
        n_y: int = 20,
    ) -> Float[Array, "N 3"]:
        """Generate a coarse subgrid of 3D coordinates spanning the grid
        extent.

        Useful for visualization or sampling a sparse set of points across the grid.

        Args:
            n_x: Number of points along the x-axis.
            n_y: Number of points along the y-axis.

        Returns:
            An ``(n_x * n_y, 3)`` array of coordinates with z=0.
        """
        x = jnp.linspace(self.x_min, self.x_max, n_x, endpoint=True)
        y = jnp.linspace(self.y_min, self.y_max, n_y, endpoint=True)
        xx, yy = jnp.meshgrid(x, y, indexing="ij")
        xx, yy = (xx.flatten(), yy.flatten())
        return jnp.stack([xx, yy, jnp.zeros_like(xx)], axis=-1)

    def lines_for_plot(self, n_x: int = 20, n_y: int = 20) -> Float[Array, "L N 2"]:
        """Generate grid lines for 2D plotting.

        Produces ``n_x`` lines of constant x followed by ``n_y`` lines of
        constant y, suitable for overlaying on a plot. A constant-x line has
        ``n_y`` points and a constant-y line has ``n_x`` points, so the lines
        only stack into one array when ``n_x == n_y``.

        Args:
            n_x: Number of constant-x grid lines.
            n_y: Number of constant-y grid lines.

        Returns:
            An ``(n_x + n_y, N, 2)`` array; each line is an ``(N, 2)`` array
            of ``(x, y)`` points.
        """
        x = jnp.linspace(self.x_min, self.x_max, n_x, endpoint=True)
        y = jnp.linspace(self.y_min, self.y_max, n_y, endpoint=True)
        lines = []
        lines.extend(jnp.stack([jnp.full_like(y, x_i), y], axis=-1) for x_i in x)
        lines.extend(jnp.stack([x, jnp.full_like(x, y_i)], axis=-1) for y_i in y)
        return jnp.asarray(lines)

    def to_far_field(
        self,
        wavelength: Float[Array, ""] | float = 1.0,
        propagation_distance: Float[Array, ""] | float = 1.0,
        fftshifted: bool = False,
    ) -> Self:
        r"""Compute the reciprocal-space (far-field) grid corresponding to this
        grid.

        Applies the Fraunhofer diffraction relation to convert pixel sizes from
        real space to reciprocal space:

        $$\Delta q_i = \frac{\lambda \cdot z}{\Delta x_i \cdot N_i}$$

        where $\lambda$ is the wavelength, $z$ is the propagation distance,
        $\Delta x_i$ is the real-space pixel size, and $N_i$ is the number of
        pixels along axis $i$.

        The returned grid carries no ``origin_shift``.

        Args:
            wavelength: Radiation wavelength (scalar).
            propagation_distance: Propagation distance to the far-field plane (scalar).
            fftshifted: Whether the returned grid should be FFT-shifted.

        Returns:
            A new :py:class:`~ptyrax.spatial.SamplingGrid` with reciprocal-space pixel sizes.

        Raises:
            ValueError: If ``wavelength`` is not scalar.
        """
        # jnp.ndim handles Python ints and floats as well as arrays, unlike
        # reading ``.shape`` directly.
        if jnp.ndim(wavelength) != 0:
            raise ValueError(
                "Wavelength was not scalar in computing far-field. If multiple wavelengths are "
                "used, these must be vmapped over"
            )
        # Index the last axis explicitly so a batched (*, d) pixel_size is
        # paired with the matching grid dimension rather than with a batch entry.
        new_pixel_size = jnp.stack(
            [
                wavelength * propagation_distance / (self.pixel_size[..., axis] * n_points)
                for axis, n_points in enumerate(self.shape)
            ],
            axis=-1,
        )
        return SamplingGrid.from_tuples(shape=self.shape, pixel_size=new_pixel_size, fftshifted=fftshifted)

    def __plot__(
        self,
        fig: Figure = None,
        gs: SubplotSpec = None,
        show: bool = True,
        **kwargs,
    ) -> tuple[Figure, SubplotSpec, list[Line2D]]:
        """Scatter the grid points in the x-y plane.

        Args:
            fig: Figure to draw into. When omitted, the figure owning ``gs``
                is used if there is one, otherwise a new figure is created.
            gs: Subplot spec to place the axes in. When omitted, a single
                full-figure cell is created.
            show: Whether to call ``plt.show()`` before returning.
            **kwargs: Forwarded to ``Axes.plot``.

        Returns:
            The figure, the subplot spec used, and the list of plotted lines.
        """
        if fig is None:
            fig = gs.get_gridspec().figure if gs is not None else None
            if fig is None:
                fig = plt.figure()
        if gs is None:
            gs = fig.add_gridspec(1, 1)[0]
        # A SubplotSpec is passed to add_subplot directly; it cannot be indexed.
        ax = fig.add_subplot(gs)
        # One row per coordinate component; only x and y are drawn, so a 3D
        # grid is shown as its projection onto the x-y plane.
        points = self.meshgrid.reshape(self.ndim, -1)
        line = ax.plot(points[0], points[1], ".", **kwargs)
        if show:
            plt.show()
        return fig, gs, line
[docs]
class YOnlyRotation(Rotation):
    """A rotation constrained to rotate only about the y-axis.

    Useful for samples in simple reflection geometries where only a rotation
    around the vertical (y) axis is presumed to be present. The y-axis component
    of the 6D representation is fixed to ``[0, 1, 0]``; only the x-axis
    direction stored in ``representation`` is free.

    Args:
        representation_initializer: A callable returning a shape ``(3,)`` array
            representing the x-axis direction of the rotation.
    """

    # Declared with an annotation so it is a dataclass field; a plain
    # assignment would only bind the type object as a class attribute.
    representation: Float[Array, "3"]

    def __init__(self, representation_initializer: Callable[[], Float[Array, "3"]]) -> None:
        self.representation = representation_initializer()
        # NOTE(review): the inherited ``_representation_6d`` field is filled
        # once here so that ``shape``, ``[]`` and ``str()`` from the parent
        # have something to read. It is not kept in sync afterwards;
        # ``as_matrix`` below always uses the constrained value.
        self._representation_6d = self.representation_6d

    @property
    def representation_6d(self) -> Float[Array, "6"]:
        """Full 6D representation: the free x-part followed by ``[0, 1, 0]``."""
        # Broadcast the fixed y-part to the shape of the stored x-part.
        y_axis_part = jnp.broadcast_to(jnp.array([0.0, 1.0, 0.0]), self.representation.shape)
        return jnp.concatenate([self.representation, y_axis_part], axis=-1)

    def as_matrix(self) -> Float[Array, "* 3 3"]:
        """Convert this rotation to a 3x3 matrix using the constrained 6D vector."""
        # NOTE(review): ``from_matrix`` (and therefore ``inv`` and the
        # operators) still calls ``cls(array)``, which does not match this
        # class's initializer-based constructor — confirm the intended behaviour.
        return six_dimensional_representation_to_matrix(self.representation_6d)
[docs]
def meshgrid(
    shape: tuple[int, int],
    pixel_size: Float[Array, "2"] = np.array([1.0, 1.0]),
) -> Float[Array, "2 n m"]:
    """Build a centered 2D coordinate meshgrid.

    Each axis spans ``[-N/2, N/2)`` in unit steps and is then multiplied by
    its pixel size.

    Args:
        shape: Grid dimensions as ``(n_rows, n_cols)``.
        pixel_size: Physical pixel size as ``(dy, dx)``.

    Returns:
        A ``(2, n, m)`` array where ``[0]`` holds the x-coordinates (varying
        along columns) and ``[1]`` the y-coordinates (varying along rows).
    """
    n_rows, n_cols = shape[0], shape[1]
    row_coords = jnp.arange(-n_rows / 2, n_rows / 2) * pixel_size[0]
    col_coords = jnp.arange(-n_cols / 2, n_cols / 2) * pixel_size[1]
    # Default 'xy' indexing: both outputs have shape (n_rows, n_cols).
    grid_x, grid_y = jnp.meshgrid(col_coords, row_coords)
    return jnp.stack((grid_x, grid_y), axis=0)
[docs]
def convert_coordinates_into_indices(
    coordinates: Float[Array, "2 ..."],
    target_sampling: SamplingGrid,
) -> tuple[Float[Array, "2 ..."], Bool[Array, "..."]]:
    """Map physical coordinates to array indices on a target grid.

    Converts physical (x, y) coordinates into floating-point array indices
    relative to the target :py:class:`~ptyrax.spatial.SamplingGrid`:
    ``index = coordinate / pixel_size + N / 2`` per axis. Also returns a
    boolean mask indicating which coordinates fall within the grid
    boundaries, i.e. ``0 <= index < N`` on both axes.

    Args:
        coordinates: Physical coordinates with shape ``(2, ...)`` where the
            leading dimension indexes ``(x, y)``.
        target_sampling: The target grid; its ``shape`` and ``pixel_size``
            attributes are read.

    Returns:
        A tuple of:

        - ``indices``: Floating-point array indices, shape ``(2, ...)``.
        - ``mask``: Boolean array of shape ``(...)``, ``True`` where both
          indices are in bounds.
    """
    # Reshape the per-axis quantities to (d, 1, ..., 1) so they broadcast
    # against the trailing dimensions of ``coordinates``.
    per_axis = (slice(None),) + (jnp.newaxis,) * (len(coordinates.shape) - 1)
    pixel_number = jnp.array(target_sampling.shape)[per_axis]
    pixel_size = jnp.asarray(target_sampling.pixel_size)[per_axis]

    indices = coordinates / pixel_size + pixel_number / 2
    indices = indices[:2]  # keep the (x, y) components only

    # Index 0 is the first sample of the grid and therefore in bounds; a
    # strict ``> 0`` would reject coordinates on the grid's minimum edge.
    in_bounds = jnp.logical_and(indices >= 0, indices < pixel_number)
    mask = jnp.all(in_bounds, axis=0)
    return indices, mask
@jax.custom_jvp
def angle_safe(z: Complex[Array, "..."], default: float = 0.0, eps: float = 1e-12) -> Float[Array, "..."]:
    """Phase angle of ``z`` with a fixed fallback where ``z`` is (near) zero.

    Args:
        z: complex array
        default: scalar angle returned wherever ``|z|^2 <= eps``
        eps: threshold on ``|z|^2`` below which ``z`` is treated as zero

    Returns:
        ``jnp.angle(z)`` elementwise, replaced by ``default`` at near-zero
        entries.
    """
    # NOTE(review): this is a custom_jvp function; its JVP rule is not defined
    # in this part of the file — presumably registered elsewhere via defjvp.
    squared_magnitude = jnp.abs(z) ** 2
    treat_as_zero = squared_magnitude <= eps
    phase = jnp.angle(z)
    return jnp.where(treat_as_zero, default, phase)
def _sliced_rounded(
x: Float[Array, "... w0 h0"], target_center_idx: Float[Array, "2"], target_shape: tuple[int, int]
) -> Float[Array, "... w1 h1"]:
"Shift by slicing (no interpolation). Pixel size should be equal."
(*_, w0, h0) = x.shape
(w1, h1) = target_shape
start_x = jnp.round(target_center_idx[0] + w0 / 2 - w1 / 2).astype(jnp.int32)
start_y = jnp.round(target_center_idx[1] + h0 / 2 - h1 / 2).astype(jnp.int32)
sliced_x = jax.lax.dynamic_slice(
x,
start_indices=(0,) * (x.ndim - 2) + (start_x, start_y),
slice_sizes=x.shape[:-2] + (w1, h1),
)
return sliced_x
def _sliced_interpolated(
x: Float[Array, "... w0 h0"],
target_center_idx: Float[Array, "2"],
target_shape: tuple[int, int],
mode: Literal["real_imaginary", "amplitude_phase"] = "real_imaginary",
) -> Float[Array, "... w1 h1"]:
(*_, w0, h0) = x.shape
(w1, h1) = target_shape
start_x = target_center_idx[0] + w0 / 2 - w1 / 2
start_y = target_center_idx[1] + h0 / 2 - h1 / 2
left_x = jnp.floor(start_x).astype(jnp.int32)
top_y = jnp.floor(start_y).astype(jnp.int32)
right_x = left_x + 1
bottom_y = top_y + 1
sliced_top_left = jax.lax.dynamic_slice(
x,
start_indices=(0,) * (x.ndim - 2) + (left_x, top_y),
slice_sizes=x.shape[:-2] + (w1, h1),
)
sliced_top_right = jax.lax.dynamic_slice(
x,
start_indices=(0,) * (x.ndim - 2) + (right_x, top_y),
slice_sizes=x.shape[:-2] + (w1, h1),
)
sliced_bottom_left = jax.lax.dynamic_slice(
x,
start_indices=(0,) * (x.ndim - 2) + (left_x, bottom_y),
slice_sizes=x.shape[:-2] + (w1, h1),
)
sliced_bottom_right = jax.lax.dynamic_slice(
x,
start_indices=(0,) * (x.ndim - 2) + (right_x, bottom_y),
slice_sizes=x.shape[:-2] + (w1, h1),
)
alpha_x = start_x - left_x
alpha_y = start_y - top_y
def _interp(corners): # noqa: ANN001
(top_left, top_right, bottom_left, bottom_right) = corners
top = (1 - alpha_x) * top_left + alpha_x * top_right
bottom = (1 - alpha_x) * bottom_left + alpha_x * bottom_right
interpolated = (1 - alpha_y) * top + alpha_y * bottom
return interpolated
if mode == "real_imaginary":
interpolated = _interp((sliced_top_left, sliced_top_right, sliced_bottom_left, sliced_bottom_right))
elif mode == "amplitude_phase":
interpolated_amplitude = _interp(
(
jnp.abs(sliced_top_left),
jnp.abs(sliced_top_right),
jnp.abs(sliced_bottom_left),
jnp.abs(sliced_bottom_right),
)
)
interpolated_total = _interp(
(
sliced_top_left,
sliced_top_right,
sliced_bottom_left,
sliced_bottom_right,
)
)
interpolated = interpolated_amplitude * phase_only_exp(angle_safe(interpolated_total))
return interpolated
[docs]
def shift_with_interpolation_equal_pixel_size(
    x: Float[Array, "... w0 h0"],
    target_center_idx: Float[Array, "2"],
    target_shape: tuple[int, int],
    interpolation_mode: Literal["rounded", "real_imaginary", "amplitude_phase"] = "rounded",
) -> Float[Array, "... w1 h1"]:
    """Extract a shifted sub-region from an array when source and target pixel
    sizes match.

    Crops a region of size ``target_shape`` from ``x``, centered at
    ``target_center_idx`` (in pixel units relative to the center of ``x``).
    The shift can be rounded to the nearest pixel or interpolated.

    Args:
        x: Source array of shape ``(..., w0, h0)``. If 3D, the leading
            dimension is treated as independent channels.
        target_center_idx: Sub-pixel center offset as ``(dx, dy)`` in index space.
        target_shape: Output spatial dimensions ``(w1, h1)``.
        interpolation_mode: One of ``'rounded'`` (nearest-pixel slice),
            ``'real_imaginary'`` (bilinear on real/imag parts), or
            ``'amplitude_phase'`` (bilinear on amplitude, phase-aware).

    Returns:
        The shifted sub-region with shape ``(..., w1, h1)``.

    Raises:
        ValueError: If ``interpolation_mode`` is not one of the supported values.
    """

    def interpolate_single_channel(channel: Float[Array, "..."]) -> Float[Array, "..."]:
        if interpolation_mode == "rounded":
            return _sliced_rounded(channel, target_center_idx, target_shape)
        if interpolation_mode in ("real_imaginary", "amplitude_phase"):
            # Forward the mode explicitly; otherwise the helper's default
            # ('real_imaginary') would be used for 'amplitude_phase' as well.
            return _sliced_interpolated(channel, target_center_idx, target_shape, mode=interpolation_mode)
        raise ValueError(
            "interpolation mode for equal pixels must be one of ['rounded', 'real_imaginary', 'amplitude_phase']. "
            f"Got {interpolation_mode} instead"
        )

    if x.ndim == 3:
        # Map over the leading channel axis.
        return jax.vmap(interpolate_single_channel)(x)
    return interpolate_single_channel(x)
[docs]
@gin.configurable()
def shift_with_interpolation_unequal_pixel_size(
    x: Float[Array, "... w0 h0"],
    original_pixel_size: Float[Array, "2"] | tuple[float, float],
    target_center_idx: Float[Array, "2"],
    target_shape: tuple[int, int],
    target_pixel_size: Float[Array, "2"] | tuple[float, float],
    interpolation_mode: Literal["real_imaginary", "amplitude_phase"] = "real_imaginary",
) -> Float[Array, "... w1 h1"]:
    """Resample a shifted sub-region of ``x`` onto a grid with a different pixel size.

    Builds the target grid's sample positions in source-index space and
    evaluates ``x`` there with first-order
    ``jax.scipy.ndimage.map_coordinates`` (``mode='constant'``).

    Args:
        x: Source array of shape ``(..., w0, h0)``. A 3D input is processed
            one leading-axis channel at a time.
        original_pixel_size: Pixel size of the source array as ``(px_x, px_y)``.
        target_center_idx: Center offset in source-pixel units as ``(dx, dy)``.
        target_shape: Output spatial dimensions ``(w1, h1)``.
        target_pixel_size: Pixel size of the target grid as ``(px_x, px_y)``.
        interpolation_mode: Either ``'real_imaginary'`` (interpolate the values
            directly) or ``'amplitude_phase'`` (interpolate the magnitude and
            take the phase from the directly interpolated values).

    Returns:
        The resampled sub-region with shape ``(..., w1, h1)``.

    Raises:
        ValueError: If ``interpolation_mode`` is not one of the supported values.
    """
    (*_, source_w, source_h) = x.shape
    (target_w, target_h) = target_shape

    # Target pixel offsets about the target center, expressed in source pixels.
    scale_x = target_pixel_size[0] / original_pixel_size[0]
    scale_y = target_pixel_size[1] / original_pixel_size[1]
    offsets_x = (jnp.arange(target_w) - target_w / 2) * scale_x
    offsets_y = (jnp.arange(target_h) - target_h / 2) * scale_y
    mesh_x, mesh_y = jnp.meshgrid(offsets_x, offsets_y, indexing="ij")

    # TODO include rotation of the indexing grid

    # Translate to source index space: center of source + center offset.
    sample_x = mesh_x + target_center_idx[0] + source_w / 2
    sample_y = mesh_y + target_center_idx[1] + source_h / 2
    sample_points = jnp.stack([sample_x, sample_y], axis=0)

    def _resample(values):  # noqa: ANN001, ANN202
        return map_coordinates(values, sample_points, order=1, mode="constant")

    def interpolate_single_channel(channel: Float[Array, "..."]) -> Float[Array, "..."]:
        if interpolation_mode == "real_imaginary":
            return _resample(channel)
        if interpolation_mode == "amplitude_phase":
            interpolated_amplitude = _resample(jnp.abs(channel))
            interpolated_total = _resample(channel)
            return interpolated_amplitude * phase_only_exp(angle_safe(interpolated_total))
        raise ValueError(
            "interpolation mode for unequal pixels must be one of ['real_imaginary', 'amplitude_phase']."
            f"Got {interpolation_mode}"
        )

    if x.ndim == 3:
        return jax.vmap(interpolate_single_channel)(x)
    return interpolate_single_channel(x)
[docs]
def interpolate_grid_to_grid(
    samples: Complex[Array, "* m0 n0"],
    source_grid: SamplingGrid = None,
    source_coordinate_system: CoordinateSystem = CoordinateSystem(),
    target_grid: SamplingGrid = None,
    target_coordinate_system: CoordinateSystem = CoordinateSystem(),
    interpolation_mode: Literal["real_imaginary", "amplitude_phase", "sliced_rounded", "sliced_interpolated"] = None,
    equal_pixel_size: bool = False,
) -> Complex[Array, "* m1 n1"]:
    """Resample ``samples`` from a source grid onto a target grid.

    The first two components of ``source_coordinate_system.translation_internal``
    are divided by the source pixel size to obtain a shift in source-pixel
    units. That shift is then handed to either
    ``shift_with_interpolation_equal_pixel_size`` or
    ``shift_with_interpolation_unequal_pixel_size``.

    Args:
        samples: Complex array of shape ``(*, m0, n0)`` on the source grid.
        source_grid: The :py:class:`~ptyrax.spatial.SamplingGrid` of the
            source data. Required.
        source_coordinate_system: The
            :py:class:`~ptyrax.spatial.CoordinateSystem` whose
            ``translation_internal`` supplies the shift.
        target_grid: The :py:class:`~ptyrax.spatial.SamplingGrid` to
            interpolate onto. When omitted, ``source_grid`` is reused and the
            equal-pixel-size path is forced.
        target_coordinate_system: Accepted for interface symmetry; this
            function does not currently read it.
        interpolation_mode: Forwarded to the selected shift routine only when
            it is not ``None``; otherwise that routine's own default applies.
        equal_pixel_size: If ``True``, use the equal-pixel-size routine.

    Returns:
        The interpolated samples on the target grid, shape ``(*, m1, n1)``.

    Raises:
        ValueError: If ``source_grid`` is not provided.
    """
    if source_grid is None:
        raise ValueError("Source_grid must be provided in call to interpolate_grid_to_grid")

    # Reusing the source grid as target means both pixel sizes match by construction.
    if target_grid is None:
        target_grid, equal_pixel_size = source_grid, True

    # Only forward the mode when the caller chose one explicitly.
    mode_kwargs = {} if interpolation_mode is None else {"interpolation_mode": interpolation_mode}

    in_plane_translation = source_coordinate_system.translation_internal[:2]
    source_pixel_size = jnp.array((source_grid.x_pixel_size, source_grid.y_pixel_size))
    # Physical translation expressed in units of source pixels.
    translation_idx = in_plane_translation / source_pixel_size

    if equal_pixel_size:
        return shift_with_interpolation_equal_pixel_size(
            samples,
            translation_idx,
            target_grid.shape,
            **mode_kwargs,
        )

    target_pixel_size = jnp.array((target_grid.x_pixel_size, target_grid.y_pixel_size))
    return shift_with_interpolation_unequal_pixel_size(
        samples,
        source_pixel_size,
        translation_idx,
        target_grid.shape,
        target_pixel_size,
        **mode_kwargs,
    )
[docs]
def draw_axis_arrow(
    axis_label: str,
    xy_start: tuple[float, float],
    xy_end: tuple[float, float],
    ax: Axes,
    **kwargs,
) -> Axes:
    """Annotate ``ax`` with an arrow and a text label for a coordinate axis.

    Two ``ax.annotate`` calls are made: an empty-text annotation that draws
    the arrow, and a second one that places ``axis_label`` next to
    ``xy_start`` with a fixed ``(5, 5)`` point offset.

    Args:
        axis_label: Text label for the axis (e.g. ``'$x_s$'``).
        xy_start: ``(x, y)`` passed as ``xy`` to ``ax.annotate``, i.e. the
            point the arrow points at (the arrow tip).
        xy_end: ``(x, y)`` passed as ``xytext`` to ``ax.annotate``, i.e. the
            arrow tail.
        ax: The matplotlib axes to draw on.
        **kwargs: Extra keyword arguments for the arrow annotation. A
            ``color`` entry is also used for the label and for the default
            arrow style; an ``arrowprops`` entry replaces the default arrow
            style entirely.

    Returns:
        The axes with the arrow drawn.
    """
    color = kwargs.get("color", "black")
    # Remove "arrowprops" from kwargs so it is not passed twice below.
    arrow_style = kwargs.pop("arrowprops", dict(arrowstyle="->", color=color))

    ax.annotate("", xy=xy_start, xytext=xy_end, arrowprops=arrow_style, **kwargs)
    ax.annotate(axis_label, xy=xy_start, xytext=(5, 5), textcoords="offset points", color=color)
    return ax
[docs]
def plot_geometry(
    sample_coordinates: CoordinateSystem,
    detector_coordinates: CoordinateSystem,
    detector_coordinate_edges: ArrayLike,
    arrow_length_ratio: float = 0.1,
    fig: Figure = None,
    gs: SubplotSpec = None,
) -> tuple[Figure, Axes]:
    """Plot sample and detector positions with their local axes in the x-z plane.

    The detector edge points and the sample translations are drawn as dots,
    then the local x (red) and z (blue) axes of the sample and the detector
    are drawn as labelled arrows. Only the first rotation matrix of each
    coordinate system and the first detector translation are used.

    Args:
        sample_coordinates: The :py:class:`~ptyrax.spatial.CoordinateSystem`
            describing sample positions and orientations. Its translations
            are multiplied by 100 before plotting (presumably a unit
            conversion -- confirm against the caller).
        detector_coordinates: The :py:class:`~ptyrax.spatial.CoordinateSystem`
            describing detector position and orientation.
        detector_coordinate_edges: Array indexed as ``[0, :, 0]`` (x) and
            ``[0, :, 2]`` (z) to obtain the detector points to plot.
        arrow_length_ratio: Arrow length as a fraction of the larger of the
            two axis ranges, measured after the points have been plotted.
        fig: Optional existing matplotlib figure.
        gs: Optional matplotlib gridspec subplot.

    Returns:
        A tuple of the matplotlib ``(Figure, Axes)``.
    """
    fig, ax = make_or_reuse_axes(fig, gs)

    # Scatter the detector outline and the sample positions (x horizontal, z vertical).
    ax.plot(detector_coordinate_edges[0, :, 0], detector_coordinate_edges[0, :, 2], ".")
    sample_positions = sample_coordinates.translation
    ax.plot(sample_positions[:, 0] * 100, sample_positions[:, 2] * 100, ".")

    unit_x = jnp.array([1, 0, 0])
    unit_z = jnp.array([0, 0, 1])

    sample_frame = sample_coordinates.rotation.as_matrix()[0].T
    sample_x = sample_frame @ unit_x
    sample_z = sample_frame @ unit_z

    # Arrow length follows the data extent, so it must be read after plotting.
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    arrow_length = max(x_high - x_low, y_high - y_low) * arrow_length_ratio
    ax.set_aspect("equal", adjustable="box")

    # Sample axes are anchored at the plot origin (0, 0).
    for label, direction, color in (("$x_s$", sample_x, "red"), ("$z_s$", sample_z, "blue")):
        tip = (arrow_length * direction[0], arrow_length * direction[2])
        ax = draw_axis_arrow(label, tip, (0, 0), ax, color=color)

    detector_frame = detector_coordinates.rotation.as_matrix()[0].T
    detector_x = detector_frame @ unit_x
    detector_z = detector_frame @ unit_z
    x_d, _, z_d = detector_coordinates.translation[0]

    # Detector axes are anchored at the detector position.
    for label, direction, color in (("$x_d$", detector_x, "red"), ("$z_d$", detector_z, "blue")):
        tip = (x_d + arrow_length * direction[0], z_d + arrow_length * direction[2])
        ax = draw_axis_arrow(label, tip, (x_d, z_d), ax, color=color)

    ax.axis("equal")
    ax.set_xlabel("x")
    # Reverse the horizontal axis direction.
    ax.set_xlim(np.flip(ax.get_xlim()))
    ax.set_ylabel("z")
    return fig, ax
[docs]
def R_x(angle: float, **kwargs) -> Float[Array, "3 3"]:  # noqa: N802
    """Return the 3x3 matrix of a rotation by ``angle`` about the x-axis.

    Args:
        angle: Rotation angle in degrees.
        **kwargs: Accepted but ignored.

    Returns:
        A ``(3, 3)`` rotation matrix.
    """
    # Rotation vector: direction is the axis, magnitude is the angle.
    rotation_vector = [angle, 0, 0]
    rotation = JaxRotation.from_rotvec(rotation_vector, degrees=True)
    return rotation.as_matrix()
[docs]
def R_y(angle: float, **kwargs) -> Float[Array, "3 3"]:  # noqa: N802
    """Return the 3x3 matrix of a rotation by ``angle`` about the y-axis.

    Args:
        angle: Rotation angle in degrees.
        **kwargs: Accepted but ignored.

    Returns:
        A ``(3, 3)`` rotation matrix.
    """
    # Rotation vector: direction is the axis, magnitude is the angle.
    rotation_vector = [0, angle, 0]
    rotation = JaxRotation.from_rotvec(rotation_vector, degrees=True)
    return rotation.as_matrix()
[docs]
def R_z(angle: float, **kwargs) -> Float[Array, "3 3"]:  # noqa: N802
    """Return the 3x3 matrix of a rotation by ``angle`` about the z-axis.

    Args:
        angle: Rotation angle in degrees.
        **kwargs: Accepted but ignored.

    Returns:
        A ``(3, 3)`` rotation matrix.
    """
    # Rotation vector: direction is the axis, magnitude is the angle.
    rotation_vector = [0, 0, angle]
    rotation = JaxRotation.from_rotvec(rotation_vector, degrees=True)
    return rotation.as_matrix()
[docs]
@gin.configurable("interpolated_shift")
def shift_with_interpolation(
    x: Float[Array, "..."],
    center: Float[Array, "2"],
    target_shape: tuple[int, int],
) -> Array:
    """Crop a shifted window out of ``x`` with linear interpolation.

    A window of ``target_shape`` is sampled from every channel of ``x``. The
    window is centered at the array center plus ``center``, so ``center`` may
    be fractional. Sampling uses ``jax.scipy.ndimage.map_coordinates`` with
    ``order=1`` and ``mode="nearest"``, and the leading axis of ``x`` is
    mapped over with ``jax.vmap``.

    Args:
        x: Source array of shape ``(C, w, h)`` where ``C`` is the number of
            channels.
        center: Offset ``(cx, cy)`` of the window center from the array
            center, in index units.
        target_shape: Output spatial dimensions ``(w_out, h_out)``.

    Returns:
        The interpolated sub-region with shape ``(C, w_out, h_out)``.
    """
    _, width, height = x.shape

    # Source index of the window's first sample along each axis.
    corner = center - jnp.array(target_shape) / 2
    first_row = width / 2 + corner[0]
    first_col = height / 2 + corner[1]

    # Integer window indices, translated into source index space.
    rows, cols = jnp.meshgrid(
        jnp.arange(target_shape[0]),
        jnp.arange(target_shape[1]),
        indexing="ij",
    )
    sample_points = jnp.stack([rows + first_row, cols + first_col], axis=0)

    def sample_channel(channel: Float[Array, "..."]) -> Float[Array, "..."]:
        return map_coordinates(channel, sample_points, order=1, mode="nearest")

    return jax.vmap(sample_channel)(x)
[docs]
@angle_safe.defjvp
def angle_safe_jvp(
    primals: Complex[Array, "..."], tangents: Complex[Array, "..."]
) -> tuple[Float[Array, "..."], Float[Array, "..."]]:
    """JVP rule registered for ``angle_safe``.

    Where ``|z|**2 <= eps`` the primal output is ``default`` and the tangent
    is zero. Elsewhere the primal is ``jnp.angle(z)`` and the tangent is
    ``Im(conj(z) * dz) / |z|**2``. Tangents of ``default`` and ``eps`` are
    ignored.

    Args:
        primals: The triple ``(z, default, eps)``.
        tangents: The matching triple of tangents; only the first is used.

    Returns:
        The pair ``(primal_out, tangent_out)``.
    """
    z, default, eps = primals
    z_dot, _default_dot, _eps_dot = tangents

    squared_modulus = jnp.abs(z) ** 2
    # Note: the threshold is compared against |z|^2, not |z|.
    near_zero = squared_modulus <= eps

    angle_out = jnp.where(near_zero, default, jnp.angle(z))

    # Clamp the denominator so the division stays finite even where it is
    # masked out afterwards.
    clamped_modulus = jnp.maximum(squared_modulus, eps)
    # Im(conj(z) * dz) is real-valued.
    raw_tangent = jnp.imag(jnp.conj(z) * z_dot) / clamped_modulus
    tangent_out = jnp.where(near_zero, 0.0, raw_tangent)

    return angle_out, tangent_out
[docs]
def rotation_matrix_from_angles(angles: Float[Array, "3 ..."]) -> Float[Array, "... 3 3"]:
    r"""Build a rotation matrix from rotation angles about the x, y and z axes.

    Constructs the combined rotation matrix as $R = R_z \cdot R_y \cdot R_x$
    from three rotation angles given in degrees. Applied to a column vector,
    this rotates about x first, then y, then z, all about the fixed axes
    (extrinsic x-y-z, equivalently intrinsic z-y'-x'').

    Any trailing dimensions of ``angles`` are treated as batch dimensions.

    Args:
        angles: A shape ``(3, ...)`` array (or nested sequence) of
            ``(theta_x, theta_y, theta_z)`` angles in degrees.

    Returns:
        A ``(..., 3, 3)`` rotation matrix, with one matrix per batch entry.
    """
    theta_x, theta_y, theta_z = jnp.deg2rad(jnp.asarray(angles))
    cos_x, sin_x = jnp.cos(theta_x), jnp.sin(theta_x)
    cos_y, sin_y = jnp.cos(theta_y), jnp.sin(theta_y)
    cos_z, sin_z = jnp.cos(theta_z), jnp.sin(theta_z)

    # The constant entries must carry the same batch shape as the
    # trigonometric entries, otherwise they cannot be stacked together.
    one = jnp.ones_like(cos_x)
    zero = jnp.zeros_like(cos_x)

    def assemble(rows):
        # Entries go on the last axis, rows on the second-to-last axis.
        return jnp.stack([jnp.stack(row, axis=-1) for row in rows], axis=-2)

    r_x = assemble(((one, zero, zero), (zero, cos_x, -sin_x), (zero, sin_x, cos_x)))
    r_y = assemble(((cos_y, zero, sin_y), (zero, one, zero), (-sin_y, zero, cos_y)))
    r_z = assemble(((cos_z, -sin_z, zero), (sin_z, cos_z, zero), (zero, zero, one)))
    return r_z @ r_y @ r_x
[docs]
def phase_only_exp(phase: Float[Array, "... m n"]) -> Complex[Array, "... m n"]:
    r"""Return the unit-magnitude complex number $e^{i\phi}$ for each phase.

    The result is assembled from ``jnp.cos`` and ``jnp.sin`` of the real
    phase rather than by calling ``jnp.exp(1j * phase)``.

    Args:
        phase: Real-valued phase array in radians.

    Returns:
        Complex array with unit magnitude and the given phase.

    Raises:
        ValueError: If ``phase`` is complex-valued.
    """
    if not jnp.iscomplexobj(phase):
        real_part = jnp.cos(phase)
        imaginary_part = jnp.sin(phase)
        return real_part + 1j * imaginary_part
    raise ValueError("Input to phase_only_exp should be real-valued phase array.")
[docs]
def phase_only_exp_np(phase: Float[Array, "... m n"]) -> Complex[Array, "... m n"]:
    r"""Compute a unit-magnitude complex exponential $e^{i\phi}$ using NumPy.

    NumPy equivalent of :py:func:`~ptyrax.spatial.phase_only_exp`, useful
    outside of JAX-traced contexts.

    Args:
        phase: Real-valued phase array in radians.

    Returns:
        Complex array with unit magnitude and the given phase.

    Raises:
        ValueError: If ``phase`` is complex-valued.
    """
    if np.iscomplexobj(phase):
        # Name this function (not the JAX variant) so the error points at the right call site.
        raise ValueError("Input to phase_only_exp_np should be real-valued phase array.")
    return np.cos(phase) + 1j * np.sin(phase)
[docs]
def reflect(a: Shaped[Array, "... d"], n: Shaped[Array, " d"]) -> Shaped[Array, "... d"]:
    r"""Reflect a vector ``a`` across a plane with unit normal ``n``.

    Computes the reflection $a - 2(a \cdot n)n$ for every vector along the
    last axis of ``a``.

    Args:
        a: Vector(s) to reflect, shape ``(..., d)``.
        n: Unit normal of the reflection plane, shape ``(d,)``. The formula
            is only a reflection when ``n`` has unit length; this is not
            checked.

    Returns:
        The reflected vector(s), same shape as ``a``.
    """
    # jnp.dot contracts the last axis and yields shape (...,). A trailing
    # axis is restored so each projection scales its own copy of n instead of
    # being broadcast element-wise against the components of n.
    projection = jnp.dot(a, n)[..., jnp.newaxis]
    return a - 2 * projection * n
[docs]
def is_orthogonal(a: ArrayLike, b: ArrayLike) -> bool:
    """Return whether the dot product ``a @ b`` is close to zero.

    Closeness is judged by ``jnp.isclose`` with its default tolerances.
    """
    dot_product = a @ b
    return jnp.isclose(dot_product, 0.0)
[docs]
def are_mutually_orthogonal(*args: tuple[ArrayLike, "..."]) -> bool:
    """Return whether every pair of the given vectors is orthogonal.

    Each unordered pair is tested once with ``is_orthogonal``; testing stops
    at the first pair that fails.

    Args:
        *args: The vectors to test. With fewer than two vectors there are no
            pairs and the result is ``True``.

    Returns:
        ``True`` if every pair of vectors is orthogonal, ``False`` otherwise.
    """
    return all(
        is_orthogonal(first, second)
        for position, first in enumerate(args)
        for second in args[position + 1 :]
    )
[docs]
def are_in_plane(a: ArrayLike, b: ArrayLike, c: ArrayLike) -> bool:
    """Return whether three vectors are coplanar.

    The vectors are coplanar when the scalar triple product
    ``(a x b) . c`` is close to zero (``jnp.isclose`` default tolerances).
    """
    plane_normal = jnp.cross(a, b)
    triple_product = jnp.dot(plane_normal, c)
    return jnp.isclose(triple_product, 0.0)
[docs]
def is_parallel(a: ArrayLike, b: ArrayLike) -> bool:
    """Return whether two vectors point in the same direction.

    The test is ``a @ b == |a| * |b|`` up to ``jnp.isclose`` default
    tolerances. Anti-parallel vectors give ``a @ b == -|a| * |b|`` and are
    therefore reported as not parallel.
    """
    dot_product = a @ b
    norm_product = jnp.linalg.norm(a) * jnp.linalg.norm(b)
    return jnp.isclose(dot_product, norm_product)
[docs]
def is_normalized(a: ArrayLike) -> bool:
    """Return whether the norm of ``a`` is close to one.

    Uses ``jnp.linalg.norm`` and ``jnp.isclose`` with default tolerances.
    """
    length = jnp.linalg.norm(a)
    return jnp.isclose(length, 1.0)
[docs]
def rotate_about_point(
    input: Float[Array, "3"],
    rotation: Float[Array, "3 3"],
    center_of_rotation: Float[Array, "3"],
) -> Float[Array, "3"]:
    """Rotate a 3D point around an arbitrary center.

    The point is expressed relative to the center, multiplied by the
    rotation matrix, and shifted back by the center.

    Args:
        input: The 3D point to rotate.
        rotation: A ``(3, 3)`` rotation matrix.
        center_of_rotation: The 3D point to rotate around.

    Returns:
        The rotated 3D point.
    """
    offset_from_center = input - center_of_rotation
    rotated_offset = rotation @ offset_from_center
    return rotated_offset + center_of_rotation
[docs]
def three_dimensional_representation_to_matrix(representation: Float[Array, "... 3"]) -> Float[Array, "... 3 3"]:
    """Convert a 3D rotation representation (with fixed z-axis) to a 3x3 matrix.

    The representation is normalized and used as the first row ``x``. The
    third row is the fixed axis ``z = (0, 0, 1)`` and the second row is
    ``y = z x x`` (cross product). Leading dimensions of ``representation``
    are treated as batch dimensions.

    Based on:
    "On the continuity of rotation representations in neural networks"
    Yi Zhou, Connelly Barnes, Jingwan Lu, Jimei Yang, Hao Li.
    Conference on Neural Information Processing Systems (NeurIPS) 2019.

    Note:
        The rows are orthonormal only when the representation has no
        z-component; the input is not projected onto the x-y plane here.

    Args:
        representation: A shape ``(..., 3)`` array holding the (unnormalized)
            first row of the rotation matrix.

    Returns:
        A shape ``(..., 3, 3)`` array whose rows are ``x``, ``y`` and ``z``.
    """
    representation = jnp.atleast_1d(representation)
    x = representation / jnp.linalg.norm(representation, axis=-1, keepdims=True)
    # Give the fixed z-axis the same batch shape as x; jnp.stack requires all
    # rows to share one shape, which a bare (3,) vector does not for batches.
    z = jnp.broadcast_to(jnp.array([0, 0, 1.0]), x.shape)
    y = jnp.cross(z, x)
    return jnp.stack((x, y, z), axis=-2)