ptyrax.training

Functions

adam_thresholded(learning_rate[, b1, b2, ...])

Adam optimizer factory with thresholding to ignore very small updates.

bases_l1(field[, weight])

Computes the L1 norm of the base arrays used to compute the field samples.

fft_l1(field_or_array, weight)

Computes the L1 norm of the FFT of the field data.

fft_l2(field, weight)

Computes the L2 norm of the 2-D FFT of the field data.

initialize_optimizer_and_state(model, *[, ...])

Initialize partitioned optimizers and their state for a model based on OptimizerSpecification objects.

l1(field[, weight])

Computes the L1 norm regularization term.

l2(field[, weight])

Computes the L2 norm regularization term.

loop_schedule(schedule, schedule_duration)

Wraps a schedule so that it restarts after every schedule_duration steps.

loss(model, data_indices, target_image[, ...])

Computes the loss for the model.

make_optimizer_specification(base_optimizer, ...)

Construct an OptimizerSpecification grouping optimizer, learning-rate schedule, and matching patterns.

mean_error(y_true, y_pred)

Computes the mean error between the true and predicted values.

mean_square_error(y_true, y_pred)

Computes the mean square error between the true and predicted values.

mixed_mean_square_error(y_true, y_pred[, eps])

Computes the relative mean square error between the true and predicted values.

parameter_l1(pytree[, weight])

Computes the L1 norm of the parameters.

regularize(model)

Computes the regularization term for the model.

relative_mean_square_error(y_true, y_pred[, eps])

Computes the relative mean square error between the true and predicted values.

relative_mean_square_error_per_pixel(y_true, ...)

Computes the relative mean square error between the true and predicted values on a per-pixel basis.

schedule_product(*schedules)

Combines multiple schedules by multiplying their outputs.

support_l1(field[, weight])

Computes the L1 norm of the support matrix.

support_l2(field[, weight])

Computes the L2 norm of the field within a support region.

support_overlap(field[, weight])

Computes the support overlap regularization term.

threshold_relative_mean_square_error(y_true, ...)

Computes the relative mean square error between the true and predicted values.

tv(field[, weight, tv_mode])

Computes the total variation regularization term.

Classes

OptimizerSpecification(name, match_patterns, ...)

Groups an optimizer, learning-rate schedule, and parameter-matching patterns.

ThresholdedScaleByAdamState(count, mu, nu)

State for the Adam algorithm.

class ptyrax.training.OptimizerSpecification(name, match_patterns, optimizer, learn_rate_schedule=<function constant_schedule.<locals>.<lambda>>)

Bases: object

Groups an optimizer, learning-rate schedule, and parameter-matching patterns.

Instances are created by make_optimizer_specification() and consumed by initialize_optimizer_and_state() to partition model parameters into optimizer groups.

Variables:
  • name – Human-readable label (usually the gin scope name).

  • match_patterns – Regex patterns that select which parameter paths use this optimizer.

  • optimizer – The composed optax.GradientTransformation.

  • learn_rate_schedule – Learning-rate schedule applied to the optimizer.

Parameters:
  • name (str)

  • match_patterns (list[str])

  • optimizer (GradientTransformation)

  • learn_rate_schedule (Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int])

learn_rate_schedule()
match_patterns: list[str]
name: str
optimizer: GradientTransformation
class ptyrax.training.ThresholdedScaleByAdamState(count, mu, nu)

Bases: NamedTuple

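State for the thresholded Adam update produced by adam_thresholded().

Parameters:
  • count (PyTree) – Number of update steps taken so far.

  • mu (PyTree) – Exponential moving average of the gradients (first moment).

  • nu (PyTree) – Exponential moving average of the squared gradients (second moment).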

count: PyTree

Alias for field number 0

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

mu: PyTree

Alias for field number 1

nu: PyTree

Alias for field number 2

ptyrax.training.adam_thresholded(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=1e-08, mu_dtype=None, *, nesterov=False)

Adam optimizer factory with thresholding to ignore very small updates.

Parameters:
  • learning_rate (float | Array | Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]) – Scalar or schedule for the learning rate.

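  • b1 (float) – Exponential decay rate for the first-moment (mean) estimates of the gradient.

  • b2 (float) – Exponential decay rate for the second-moment estimates of the gradient.

  • eps (float) – Small constant added to the denominator outside the square root, for numerical stability.

  • eps_root (float) – Small constant added to the second-moment estimate inside the square root.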

  • threshold (float) – Minimum update magnitude to consider.

  • mu_dtype (Any | None) – Optional dtype for first-moment accumulators.

  • nesterov (bool) – Whether to use Nesterov-style correction.

Returns:

An optax.GradientTransformation implementing the thresholded Adam update.

Return type:

GradientTransformation
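The update that adam_thresholded() performs can be illustrated with a plain-Python sketch of a single Adam step. It assumes the standard Adam (and optax.adam) roles for b1, b2, eps and eps_root, and it assumes the thresholding rule is "replace any update component whose magnitude is below threshold with zero"; the actual criterion, and the handling of mu_dtype and nesterov (omitted here), may differ. thresholded_adam_step is a hypothetical helper, not part of ptyrax.

```python
import math


def thresholded_adam_step(params, grads, mu, nu, count, learning_rate=1e-3,
                          b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0,
                          threshold=1e-8):
    """One Adam step on flat lists of floats (hypothetical helper).

    Assumption: update components with |u| < threshold are replaced by 0.
    """
    count += 1
    # Exponential moving averages of the gradient and of the squared gradient.
    mu = [b1 * m + (1 - b1) * g for m, g in zip(mu, grads)]
    nu = [b2 * v + (1 - b2) * g * g for v, g in zip(nu, grads)]
    # Bias correction for the zero-initialized moment estimates.
    mu_hat = [m / (1 - b1 ** count) for m in mu]
    nu_hat = [v / (1 - b2 ** count) for v in nu]
    # eps_root is added inside the square root, eps outside it.
    updates = [-learning_rate * m / (math.sqrt(v + eps_root) + eps)
               for m, v in zip(mu_hat, nu_hat)]
    # Assumed thresholding rule: drop very small update components.
    updates = [u if abs(u) >= threshold else 0.0 for u in updates]
    new_params = [p + u for p, u in zip(params, updates)]
    return new_params, mu, nu, count
```

Starting from zeros with gradients [1.0, 1e-12] and threshold=1e-6, the first parameter moves by about -1e-3, while the second, whose raw update is roughly -1e-7 because its gradient is comparable to eps, stays at 0.0.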

ptyrax.training.bases_l1(field, weight=0.0)

Computes the L1 norm of the base arrays used to compute the field samples. For example, if the data is modeled by an OuterProductArrayParametrization, the L1 norm is computed over the arrays that form the outer product rather than over the assembled field. This is mostly useful with wavelet bases, since natural images are expected to be sparse (that is, to have a small L1 norm) in these bases.

Parameters:
  • field (CoherentField) – The field to compute the L1 norm for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed L1 norm of the bases.

Return type:

float
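For an outer-product parametrization, penalizing the bases differs from penalizing the assembled field, as a small sketch shows. outer_product_field and bases_l1_sketch are hypothetical helpers; the sketch assumes the regularizer is weight times the sum of the L1 norms of all base arrays, which may not match exactly how ptyrax combines them.

```python
def outer_product_field(row_base, col_base):
    """Field samples f[i][j] = row_base[i] * col_base[j] (outer-product parametrization)."""
    return [[r * c for c in col_base] for r in row_base]


def bases_l1_sketch(bases, weight=0.0):
    """Weighted sum of the L1 norms of the base arrays, not of the assembled field.

    abs() also gives the modulus for complex-valued bases.
    """
    return weight * sum(sum(abs(x) for x in base) for base in bases)
```

For bases [1.0, -2.0] and [0.5, 0.0, 3.0] with weight=2.0, the sketch returns 2.0 * (3.0 + 3.5) = 13.0, whereas the L1 norm of the 2 x 3 assembled field is 10.5. Penalizing the n + m base values rather than the n * m field samples is also cheaper to evaluate.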

ptyrax.training.fft_l1(field_or_array, weight)

Computes the L1 norm of the FFT of the field data.

Parameters:
  • field_or_array (CoherentField | Array | ndarray | bool | number | int | float | complex | LiteralArray) – The field, or a raw array, whose FFT L1 norm is computed.

  • weight (float) – The weight of the regularization term.

Returns:

The computed FFT L1 norm regularization term.

Return type:

float
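The quantity fft_l1() describes can be sketched in plain Python for the case where a raw 2-D array is passed: the weighted sum of magnitudes of the unnormalized 2-D discrete Fourier transform (the numpy.fft.fft2 convention). The transform axes and normalization ptyrax uses are assumptions here, and dft2 and fft_l1_sketch are hypothetical helpers; a naive O(N^2) DFT stands in for a real FFT to keep the example dependency-free.

```python
import cmath


def dft2(data):
    """Unnormalized 2-D DFT of a list of rows (same sign convention as numpy.fft.fft2)."""
    rows, cols = len(data), len(data[0])
    return [
        [
            sum(
                data[m][n] * cmath.exp(-2j * cmath.pi * (kr * m / rows + kc * n / cols))
                for m in range(rows)
                for n in range(cols)
            )
            for kc in range(cols)
        ]
        for kr in range(rows)
    ]


def fft_l1_sketch(data, weight):
    """weight * sum of |F[k]| over all spatial frequencies k."""
    return weight * sum(abs(v) for row in dft2(data) for v in row)
```

A point-like 2 x 2 array [[1, 0], [0, 0]] has a flat spectrum, so its Fourier L1 norm is 4 (2.0 with weight=0.5), while a constant array [[1, 1], [1, 1]] puts everything into the zero frequency and also gives 4 from a single coefficient.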

ptyrax.training.fft_l2(field, weight)

Computes the L2 norm of the 2-D FFT of the field data.

Applies a 2-D FFT to field.data and returns the weighted L2 norm, encouraging low energy content in the Fourier domain.

Parameters:
  • field (CoherentField) – The coherent field whose Fourier-space L2 norm is computed.

  • weight (float) – Scalar multiplier for the regularization term.

Returns:

The weighted L2 norm of the FFT of the field data.

Return type:

Float[Array, ""]

Example

```python
from ptyrax.training import fft_l2
reg = fft_l2(probe_field, weight=1e-4)  # probe_field: an existing CoherentField
```

ptyrax.training.initialize_optimizer_and_state(model, *, optimizers=None)

Initialize partitioned optimizers and their state for a model based on OptimizerSpecification objects.

Parameters:
  • model (PtychographyModel) – The model whose parameters will be partitioned and optimized.

  • optimizers (list[OptimizerSpecification] | None) – List of OptimizerSpecification objects describing optimizer assignments.

Returns:

A tuple of (optimizer_state, partitioned_optimizer, dynamic_variable_spec).

Return type:

tuple[PyTree[ptyrax.models.ptychography.PtychographyModel], GradientTransformation]
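One way to picture the partitioning step is to label every parameter path with the name of the first OptimizerSpecification whose match_patterns hit it; labels of this kind can then be arranged into the param_labels argument of optax.multi_transform, which routes each parameter to the transform registered under its label. The sketch covers only the labeling. SpecSketch, assign_labels, the dotted path format, the first-match rule, and the "frozen" label for unmatched parameters are all hypothetical and are not taken from ptyrax.

```python
import re
from dataclasses import dataclass


@dataclass
class SpecSketch:
    """Hypothetical stand-in for OptimizerSpecification (matching fields only)."""
    name: str
    match_patterns: list


def assign_labels(param_paths, specs, default="frozen"):
    """Map each parameter path to the name of the first spec with a matching regex."""
    labels = {}
    for path in param_paths:
        labels[path] = next(
            (spec.name for spec in specs
             if any(re.search(pattern, path) for pattern in spec.match_patterns)),
            default,  # assumed label for parameters that no spec claims
        )
    return labels


specs = [SpecSketch("probe", [r"^probe\."]), SpecSketch("object", [r"^object\."])]
labels = assign_labels(["probe.data", "object.data", "positions"], specs)
# labels == {'probe.data': 'probe', 'object.data': 'object', 'positions': 'frozen'}
```

Using a first-match rule keeps the assignment deterministic when patterns overlap; whether ptyrax resolves overlaps this way is not stated here.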

ptyrax.training.l1(field, weight=0.0)

Computes the L1 norm regularization term.

Parameters:
  • field (CoherentField) – The field to compute the L1 norm for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed L1 norm regularization term.

Return type:

float

ptyrax.training.l2(field, weight=0.0)

Computes the L2 norm regularization term.

Parameters:
  • field (CoherentField) – The field to compute the L2 norm for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed L2 norm regularization term.

Return type:

float
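A minimal sketch of the l1() and l2() penalties on a 2-D array of complex samples, assuming the weight simply multiplies the norm and that l2() is the unsquared Euclidean norm; ptyrax may instead use the squared norm or a mean. Under this form the default weight=0.0 makes the term vanish, so the regularizer only has an effect once a positive weight is configured. l1_sketch and l2_sketch are hypothetical helpers.

```python
import math


def l1_sketch(data, weight=0.0):
    """weight * sum of |z| over all samples; abs() gives the modulus of complex values."""
    return weight * sum(abs(z) for row in data for z in row)


def l2_sketch(data, weight=0.0):
    """weight * sqrt(sum of |z|**2), i.e. the Frobenius norm of the sample array."""
    return weight * math.sqrt(sum(abs(z) ** 2 for row in data for z in row))
```

For [[1, 1], [1, 1]] with weight=0.5, the sketch gives 2.0 for L1 and 1.0 for L2. The L1 penalty pushes individual samples toward exactly zero (sparsity), while the L2 penalty shrinks all samples together.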

ptyrax.training.loop_schedule(schedule, schedule_duration)

Wraps a schedule so that it restarts after every schedule_duration steps.

Parameters:
  • schedule (optax.Schedule) – The base schedule to loop.

  • schedule_duration (int | float) – Number of steps before the schedule resets.

Returns:

A new schedule that cycles the base schedule.

Return type:

optax.Schedule

Example

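```python
import optax
from ptyrax.training import loop_schedule
cosine = optax.cosine_decay_schedule(1e-3, 100)  # decay from 1e-3 over 100 steps
looped = loop_schedule(cosine, schedule_duration=100)  # restart every 100 steps
```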
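The restart behavior can be expressed by feeding the base schedule the step count modulo schedule_duration. The plain-Python sketch assumes this is how loop_schedule() works; loop_schedule_sketch and the toy linear_decay schedule are hypothetical stand-ins. Inside a JAX-traced optimizer the step arrives as an array, and the % operator works on JAX arrays as well.

```python
def loop_schedule_sketch(schedule, schedule_duration):
    """Return a schedule that restarts from step 0 every `schedule_duration` steps."""
    def looped(step):
        return schedule(step % schedule_duration)
    return looped


def linear_decay(step, init_value=1.0, decay_steps=10):
    """Toy schedule: decays linearly from init_value to 0 over decay_steps."""
    frac = min(step, decay_steps) / decay_steps
    return init_value * (1.0 - frac)


looped = loop_schedule_sketch(linear_decay, 10)
# looped(0) == 1.0, looped(5) == 0.5, looped(10) == 1.0, looped(15) == 0.5
```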

ptyrax.training.loss(model, data_indices, target_image, loss_fn=<function mean_square_error>, batch_reduction_fn=<function sum>)

Computes the loss for the model.

Parameters:
  • model (PtychographyModel) – The model to compute the loss for.

  • data_indices (Integer[Array, '']) – Indices of the scanning positions to evaluate.

  • target_image (Float[Array, '']) – The target diffraction pattern(s) for those positions.

  • loss_fn (Callable) – The loss function to use, called as loss_fn(y_true, y_pred). Defaults to mean_square_error.

  • batch_reduction_fn (Callable[[Array], float]) – Function that reduces the batch of losses to a single value. Defaults to sum.

Returns:

The computed loss.

Return type:

float
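Read together, the parameters suggest a two-stage computation: loss_fn compares each target pattern with the model's prediction, and batch_reduction_fn collapses the per-position values into one scalar. The sketch makes that assumption explicit. predict is a hypothetical callable standing in for the model's forward pass, and mse_sketch and loss_sketch are illustrative helpers, not the ptyrax implementation.

```python
def mse_sketch(y_true, y_pred):
    """Mean of squared differences over a flat list of pixel values."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def loss_sketch(predict, data_indices, target_images,
                loss_fn=mse_sketch, batch_reduction_fn=sum):
    """Assumed structure: one loss per scan position, then a batch reduction."""
    per_position = [
        loss_fn(target, predict(index))  # loss_fn(y_true, y_pred) argument order
        for index, target in zip(data_indices, target_images)
    ]
    return batch_reduction_fn(per_position)
```

With the default sum reduction the loss grows with batch size; passing a mean instead (for example lambda xs: sum(xs) / len(xs)) keeps its scale independent of how many positions are in the batch.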

ptyrax.training.make_optimizer_specification(base_optimizer, match_patterns, learn_rate=None, learn_rate_schedule=None, optimizer_pre_transforms=(), optimizer_wrappers=(), optimizer_post_transforms=(), schedule_wrappers=(), **kwargs)

Construct an OptimizerSpecification grouping optimizer, learning-rate schedule, and matching patterns.

This factory accepts either an optax.GradientTransformation or a callable that builds one from a learning rate. It also accepts optional pre-transforms, post-transforms, optimizer wrappers, and schedule wrappers, and records the gin scope name as the specification name.

Parameters:
  • base_optimizer (Callable[[float | Array | Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]], GradientTransformation] | GradientTransformation) – Optimizer factory or ready GradientTransformation.

  • match_patterns (list[str]) – List of regex patterns selecting parameters for this optimizer.

  • learn_rate (float | None) – Constant learning rate. Mutually exclusive with learn_rate_schedule.

  • learn_rate_schedule (Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int] | None) – optax schedule to use for this optimizer. Mutually exclusive with learn_rate.

  • optimizer_pre_transforms (tuple[GradientTransformation, ...])

  • optimizer_wrappers (tuple[Callable[[GradientTransformation], GradientTransformation], ...])

  • optimizer_post_transforms (tuple[GradientTransformation, ...])

  • schedule_wrappers (tuple[Callable[[Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]], Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]], ...])

Returns:

OptimizerSpecification describing optimizer and schedule for partitioning.

Return type:

OptimizerSpecification

ptyrax.training.mean_error(y_true, y_pred)

Computes the mean error between the true and predicted values.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

Returns:

The computed mean error.

Return type:

float

ptyrax.training.mean_square_error(y_true, y_pred)

Computes the mean square error between the true and predicted values.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

Returns:

The computed mean square error.

Return type:

float

ptyrax.training.mixed_mean_square_error(y_true, y_pred, eps=0.01)

Computes the relative mean square error between the true and predicted values.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

  • eps (float) – A small value to avoid division by zero.

Returns:

The computed relative mean square error.

Return type:

float

ptyrax.training.parameter_l1(pytree, weight=0.0)

Computes the L1 norm of the parameters.

Parameters:
  • pytree (PyTree) – The pytree to compute the L1 norm for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed L1 norm of the parameters.

Return type:

float

ptyrax.training.regularize(model)

Computes the regularization term for the model.

Parameters:

model (ImagePredictionModel) – The model to regularize.

Returns:

The computed regularization term.

Return type:

float

ptyrax.training.relative_mean_square_error(y_true, y_pred, eps=0.01)

Computes the relative mean square error between the true and predicted values.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

  • eps (float) – A small value to avoid division by zero.

Returns:

The computed relative mean square error.

Return type:

float

ptyrax.training.relative_mean_square_error_per_pixel(y_true, y_pred, eps=1e-08)

Computes the relative mean square error between the true and predicted values on a per-pixel basis.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

  • eps (float) – A small value to avoid division by zero.

Returns:

The computed relative mean square error.

Return type:

float
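The difference between relative_mean_square_error() and its per-pixel variant can be illustrated with two plausible definitions: one normalizes the total squared error by the total target energy, the other normalizes each pixel by its own target value before averaging. Both formulas, including where eps enters, are assumptions for illustration, and the helper names are hypothetical.

```python
def relative_mse_global(y_true, y_pred, eps=0.01):
    """Assumed form: total squared error divided by total target energy (plus eps)."""
    num = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    den = sum(t ** 2 for t in y_true) + eps
    return num / den


def relative_mse_per_pixel(y_true, y_pred, eps=1e-8):
    """Assumed form: each pixel's squared error divided by its own target**2, then averaged."""
    terms = [(t - p) ** 2 / (t ** 2 + eps) for t, p in zip(y_true, y_pred)]
    return sum(terms) / len(terms)
```

For y_true = [10.0, 0.1] and y_pred = [9.0, 0.2], the global form gives about 0.0101 while the per-pixel form gives about 0.505. For diffraction data with a bright center and dim tails, the per-pixel form therefore weights errors in dim pixels far more heavily.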

ptyrax.training.schedule_product(*schedules)

Combines multiple schedules by multiplying their outputs.

Parameters:

*schedules (Iterable[Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]]) – The schedules to combine.

Returns:

A schedule that is the product of the input schedules.

Return type:

Callable[[Array | ndarray | bool | number | float | int], Array | ndarray | bool | number | float | int]
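A common use is combining a warm-up ramp with a decay. The sketch assumes schedule_product() evaluates every schedule at the same step and multiplies the results; schedule_product_sketch, warmup and decay are hypothetical. The same pattern applies to optax schedules, since each one maps a step to a scalar.

```python
import math


def schedule_product_sketch(*schedules):
    """Return a schedule whose value is the product of all input schedules at each step."""
    def product(step):
        return math.prod(schedule(step) for schedule in schedules)
    return product


def warmup(step):
    """Linear ramp from 0 to 1 over the first 10 steps."""
    return min(1.0, step / 10)


def decay(step):
    """Halves the value every 100 steps."""
    return 0.5 ** (step / 100)


combined = schedule_product_sketch(warmup, decay)
# combined(0) == 0.0 (warm-up not started); combined(200) == 0.25 (decay only)
```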

ptyrax.training.support_l1(field, weight=0.0, **kwargs)

Computes the L1 norm of the support matrix.

Parameters:
  • field (CoherentField) – The field to compute the L1 norm for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed L1 norm of the support matrix.

Return type:

float

ptyrax.training.support_l2(field, weight=0.0, **kwargs)

Computes the L2 norm of the field within a support region.

Multiplies the field data by a support mask and returns the weighted L2 norm. The support mask is generated from the field’s sampling grid using support().

Parameters:
  • field (CoherentField) – The coherent field whose data is masked and normed.

  • weight (float) – Scalar multiplier for the regularization term. Defaults to 0.0.

  • **kwargs – Additional keyword arguments forwarded to support().

Returns:

The weighted L2 norm of the support-masked field.

Return type:

float

Example

```python
from ptyrax.training import support_l2
reg = support_l2(probe_field, weight=1e-3, radius=0.5)  # radius is forwarded to support()
```
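The masking step can be sketched on a small grid. The sketch assumes support() returns a binary disk of the given radius on a grid spanning [-1, 1] along each axis, and that the result is weight times the unsquared L2 norm; the real support() signature, its coordinate convention, and whether its mask selects the inside or the outside of the region are not stated here. circular_support and support_l2_sketch are hypothetical.

```python
import math


def circular_support(rows, cols, radius=0.5):
    """Hypothetical support(): 1.0 inside a disk of `radius`, 0.0 outside.

    Coordinates run from -1 to 1 along each axis (assumed convention).
    """
    def coord(index, size):
        return -1.0 + 2.0 * index / max(size - 1, 1)

    return [
        [1.0 if math.hypot(coord(j, cols), coord(i, rows)) <= radius else 0.0
         for j in range(cols)]
        for i in range(rows)
    ]


def support_l2_sketch(data, weight=0.0, **kwargs):
    """weight * L2 norm of data * mask; kwargs are forwarded to the mask builder."""
    mask = circular_support(len(data), len(data[0]), **kwargs)
    energy = sum(
        abs(z * m) ** 2
        for data_row, mask_row in zip(data, mask)
        for z, m in zip(data_row, mask_row)
    )
    return weight * math.sqrt(energy)
```

On a 3 x 3 grid with radius=0.5 only the center sample lies inside the disk, so a center value of 3+4j yields 5.0 * weight regardless of the other samples.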

ptyrax.training.support_overlap(field, weight=0.0, **kwargs)

Computes the support overlap regularization term.

Parameters:
  • field (CoherentField) – The field to compute the support overlap for.

  • weight (float) – The weight of the regularization term.

Returns:

The computed support overlap regularization term.

Return type:

float

ptyrax.training.threshold_relative_mean_square_error(y_true, y_pred, eps=0.01, threshold=0.0004)

Computes the relative mean square error between the true and predicted values.

Parameters:
  • y_true (array) – The true values.

  • y_pred (array) – The predicted values.

  • eps (float) – A small value to avoid division by zero.

  • threshold (float)

Returns:

The computed relative mean square error.

Return type:

float

ptyrax.training.tv(field, weight=0.0, tv_mode='real_imag')

Computes the total variation regularization term.

Parameters:
  • field (CoherentField) – The field to compute the total variation for.

  • weight (float) – The weight of the regularization term.

  • tv_mode (str) – The mode of total variation computation. Can be “real_imag” or “mag_phase”.

Returns:

The computed total variation regularization term.

Return type:

float

Raises:

ValueError – If the tv_mode is not one of “real_imag” or “mag_phase”.
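The two tv_mode options differ only in which pair of real-valued images is penalized. The sketch assumes anisotropic total variation (the sum of absolute forward differences along both axes) applied separately to each component and summed; ptyrax may use an isotropic form, a mean, or different boundary handling. _tv_2d and tv_sketch are hypothetical.

```python
import cmath


def _tv_2d(values):
    """Anisotropic TV: sum of absolute forward differences along both axes."""
    rows, cols = len(values), len(values[0])
    total = 0.0
    for i in range(rows):
        for j in range(cols):
            if i + 1 < rows:
                total += abs(values[i + 1][j] - values[i][j])
            if j + 1 < cols:
                total += abs(values[i][j + 1] - values[i][j])
    return total


def tv_sketch(data, weight=0.0, tv_mode="real_imag"):
    """Weighted TV of a complex 2-D array, split into two real components."""
    if tv_mode == "real_imag":
        parts = ([[z.real for z in row] for row in data],
                 [[z.imag for z in row] for row in data])
    elif tv_mode == "mag_phase":
        # Phase differences are not unwrapped in this sketch.
        parts = ([[abs(z) for z in row] for row in data],
                 [[cmath.phase(z) for z in row] for row in data])
    else:
        raise ValueError(f"Unknown tv_mode: {tv_mode!r}")
    return weight * sum(_tv_2d(part) for part in parts)
```

Because the sketch does not unwrap the phase, a phase jump across ±π counts as a large variation in "mag_phase" mode even when the underlying field is smooth.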