Source code for ptyrax.logger

from __future__ import annotations

import io
import pathlib
from typing import TYPE_CHECKING, Any, Union

import gin
import h5py
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Inexact

# Decode images into arrays (stack along axis 0)
from PIL import Image
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from tensorboardX import SummaryWriter

from ptyrax.parametrizations import resolve_parametrizations
from ptyrax.utils import plot

if TYPE_CHECKING:
    from collections.abc import Iterable

    from ptyrax.dataset import ImageDataset
    from ptyrax.models.ptychography import ImagePredictionModel


@gin.configurable
def log_image(writer: SummaryWriter, tag: str, tensor: Inexact[Array, "..."], step: int, **kwargs) -> None:
    """Render ``tensor`` with ``plot`` and store the resulting figure in TensorBoard.

    Args:
        writer: TensorBoard writer that receives the figure.
        tag: Tag under which the figure is stored.
        tensor: Image data to render.
        step: Step associated with the figure.
        **kwargs: Extra keyword arguments forwarded unchanged to ``plot``
            (callers in this module pass e.g. ``gamma``).

    Returns:
        None
    """
    # `plot` yields a 3-tuple; only its first element (the figure) is logged.
    figure, _axes, _artist = plot(tensor, show=False, **kwargs)
    writer.add_figure(tag, figure, step)
@gin.configurable
def log_on_train_start(
    writer: SummaryWriter,
    dataset: ImageDataset,
    debug: bool = False,
    diff_pat_gamma: float = 0.5,
) -> None:
    """Hook that logs the gin configuration and reference diffraction patterns at training start.

    The complete gin config string is written as a single text entry (``"gin config"``).
    In addition, every ``key = value`` binding line of the config string is resolved with
    ``gin.query_parameter`` and written as its own text entry. Its key has ``.`` replaced
    by ``/`` so TensorBoard groups the entries hierarchically.

    Args:
        writer: TensorBoard `SummaryWriter` instance.
        dataset: Dataset to log example images from. It is indexed with integers and,
            in debug mode, iterated over ``range(dataset.n)``.
        debug: If True, log every dataset item ``i`` under ``4_model/true_dif_pat/{i}``
            (the same per-index tag scheme used for ``4_model/pred_dif_pat/{i}``).
            Otherwise only item 0 is logged, under ``4_model/true_dif_pat/0``.
        diff_pat_gamma: Gamma applied to the diffraction patterns when visualizing.

    Returns:
        None
    """
    writer.add_text("gin config", gin.config_str(max_line_length=70))

    for line in gin.config.config_str().split("\n"):
        # Split only at the first separator, so values that themselves contain " = "
        # are not dropped (a plain split would fail to unpack them).
        key, sep, _ = line.partition(" = ")
        # Skip lines that cannot be bindings:
        # - lines with no separator (blank/import lines);
        # - comments;
        # - indented lines, presumably continuation lines of wrapped values.
        if not sep or line.startswith("#") or line[:1].isspace():
            continue
        try:
            value = gin.query_parameter(key)
        except ValueError:
            # gin could not resolve `key` as a bound parameter; nothing to log.
            continue
        writer.add_text(key.replace(".", "/"), str(value))

    # Reference patterns are logged once, at step 0. Debug mode logs all of them,
    # otherwise only index 0.
    indices = range(dataset.n) if debug else range(1)
    for i in indices:
        log_image(writer, f"4_model/true_dif_pat/{i}", dataset[i], 0, gamma=diff_pat_gamma)
@gin.configurable
def log_on_epoch_end(
    writer: SummaryWriter,
    epoch: int,
    model: ImagePredictionModel,
    epoch_total_loss: float | None,
    all_epoch_losses: Iterable[Any] | None,
    debug: bool = False,
    diff_pat_gamma: float = 0.5,
    log_every: int = 1,
    **kwargs,
) -> None:
    """Hook executed at the end of each epoch to log predictions, scalars and other metrics.

    Args:
        writer: TensorBoard `SummaryWriter`.
        epoch: Current epoch index.
        model: Model used for prediction/logging. If it has a ``__log_epoch__``
            attribute, that is called as ``model.__log_epoch__(writer, epoch, **kwargs)``.
        epoch_total_loss: Total loss for the epoch, logged as ``0_loss/0_loss_total``.
            Skipped when None.
        all_epoch_losses: Iterable of NamedLoss-like objects exposing ``.tag`` and
            ``.value``. Each element of ``value`` is logged under
            ``0_loss/1_loss_terms/<tag>`` at step ``epoch * len(value) + i``; a 0-d
            value is treated as a single element. Skipped when None.
        debug: If True, log predictions for every index in ``range(model.n_indices)``;
            otherwise only index 0.
        diff_pat_gamma: Gamma applied to diffraction patterns when visualizing.
        log_every: Only run logging every `log_every` epochs.
        **kwargs: Forwarded to ``model.__log_epoch__``.

    Returns:
        None
    """
    if epoch % log_every != 0:
        return

    if hasattr(model, "__log_epoch__"):
        model.__log_epoch__(writer, epoch, **kwargs)

    # Predicted diffraction patterns: all indices in debug mode, otherwise index 0 only.
    indices = range(model.n_indices) if debug else range(1)
    for i in indices:
        pred_im = resolve_parametrizations(model, jnp.array(i))()
        log_image(writer, f"4_model/pred_dif_pat/{i}", pred_im, epoch, gamma=diff_pat_gamma)

    if epoch_total_loss is not None:
        writer.add_scalar("0_loss/0_loss_total", epoch_total_loss, epoch)

    if all_epoch_losses is not None:
        for loss in all_epoch_losses:
            # `value` presumably holds one entry per batch of the epoch (hence the step
            # epoch * n + i). atleast_1d makes a 0-d scalar behave as a single entry
            # instead of failing on `shape[0]` / iteration.
            values = np.atleast_1d(np.asarray(loss.value))
            n_entries = values.shape[0]
            for i, v in enumerate(values):
                writer.add_scalar(f"0_loss/1_loss_terms/{loss.tag}", v, epoch * n_entries + i)

    writer.flush()  # Ensure that everything is written to disk
def log_on_batch_end(
    writer: SummaryWriter,
    current_losses: float,
    model: ImagePredictionModel,
    step: int,
) -> None:
    """Record the batch loss and let the model log its own batch-level metrics.

    Args:
        writer: TensorBoard `SummaryWriter` that receives the scalar.
        current_losses: Loss value of the batch just processed, written under ``0_loss``.
        model: Model; if it has a ``__log_batch__`` attribute, that is called as
            ``model.__log_batch__(writer, step)``.
        step: Global step counter used as the TensorBoard step.

    Returns:
        None
    """
    writer.add_scalar("0_loss", current_losses, step)

    # Sentinel lookup: equivalent to `hasattr` (only AttributeError means "absent").
    absent = object()
    batch_hook = getattr(model, "__log_batch__", absent)
    if batch_hook is absent:
        return
    batch_hook(writer, step)
def _deduplicate_events(events: Any) -> list:  # noqa: ANN401
    """Return ``events`` without exact repeats, ordered by wall time then step.

    Events are repeats when they share both ``wall_time`` and ``step``; the last one wins.
    Keying on the step as well keeps distinct events that happen to carry the same
    timestamp (e.g. scalars written in a tight loop).
    """
    unique = {(e.wall_time, e.step): e for e in events}
    return sorted(unique.values(), key=lambda e: (e.wall_time, e.step))


def _write_arrays(grp: h5py.Group, name: str, arrays: list[np.ndarray]) -> None:
    """Write ``arrays`` to ``grp[name]``, stacked along axis 0 when all shapes agree.

    Arrays with differing shapes cannot be stacked. They are instead stored as a subgroup
    ``grp[name]`` with one dataset per array, named by its index (``"0"``, ``"1"``, ...).
    """
    if len({a.shape for a in arrays}) == 1:
        grp.create_dataset(name, data=np.stack(arrays))
        return
    sub = grp.create_group(name)
    for i, a in enumerate(arrays):
        sub.create_dataset(str(i), data=a)


def tensorboard_to_hdf5(tb_dir: Union[str, pathlib.Path], out_filepath: Union[str, pathlib.Path]) -> None:
    """Convert TensorBoard event files to a single HDF5 archive.

    Reads scalars, histograms, text tensors, and images from the TensorBoard log directory
    and writes them into a structured HDF5 file for archival and post-hoc analysis. Every
    event is held in memory during the conversion.

    Output layout (a ``<tag>`` containing ``/`` maps to nested groups):

    - ``scalars/<tag>/{value, step, wall_time}``
    - ``histograms/<tag>/{min, max, sum, num, sum_squares, bucket, bucket_limit, step, wall_time}``
    - ``text/<tag>/{data, step, wall_time}``: ``data`` holds serialized ``TensorProto`` bytes.
    - ``images/<tag>/{data, step, wall_time}``: ``data`` holds the decoded pixel arrays.

    Ragged data is handled as follows:

    - ``bucket``, ``bucket_limit`` and image ``data`` are stacked along axis 0 when every
      event has the same shape.
    - Otherwise they are subgroups with one dataset per event (``"0"``, ``"1"``, ...).

    Scalars, text and images are de-duplicated on ``(wall_time, step)``; histogram events
    are kept as-is.

    Args:
        tb_dir: Path to the directory containing TensorBoard event files.
        out_filepath: Destination HDF5 file path (truncated if it already exists).
    """
    tb_dir = pathlib.Path(tb_dir)
    out_file = pathlib.Path(out_filepath)

    # A size of 0 keeps every event of that type. The EventAccumulator defaults
    # reservoir-sample each tag (a single histogram, 4 images, 10 tensors, 10000
    # scalars), which silently truncated the archive; e.g. only one histogram per
    # tag was ever exported.
    ea = EventAccumulator(
        str(tb_dir),
        size_guidance={"scalars": 0, "histograms": 0, "images": 0, "tensors": 0},
    )
    ea.Reload()
    tags = ea.Tags()

    # `require_group` instead of `create_group` throughout: a tag may already exist
    # as an intermediate group of another tag (e.g. "0_loss" vs "0_loss/0_loss_total").
    with h5py.File(out_file, "w") as f:
        # ------------------ Scalars ------------------
        scalars_grp = f.create_group("scalars")
        for tag in tags.get("scalars", []):
            events = _deduplicate_events(ea.Scalars(tag))
            if not events:
                continue
            tag_grp = scalars_grp.require_group(tag)
            tag_grp.create_dataset("value", data=np.array([e.value for e in events]))
            tag_grp.create_dataset("step", data=np.array([e.step for e in events]))
            tag_grp.create_dataset("wall_time", data=np.array([e.wall_time for e in events]))

        # ------------------ Histograms ------------------
        hist_grp = f.create_group("histograms")
        for tag in tags.get("histograms", []):
            events = ea.Histograms(tag)  # Do NOT deduplicate; keep all events
            if not events:
                continue
            hists = [e.histogram_value for e in events]
            tag_grp = hist_grp.require_group(tag)
            for field in ("min", "max", "sum", "num", "sum_squares"):
                tag_grp.create_dataset(field, data=np.array([getattr(h, field) for h in hists]))
            # The number of buckets may differ between events, so these may not stack.
            _write_arrays(tag_grp, "bucket", [np.array(h.bucket) for h in hists])
            _write_arrays(tag_grp, "bucket_limit", [np.array(h.bucket_limit) for h in hists])
            tag_grp.create_dataset("step", data=np.array([e.step for e in events]))
            tag_grp.create_dataset("wall_time", data=np.array([e.wall_time for e in events]))

        # ------------------ Text ------------------
        text_grp = f.create_group("text")
        for tag in tags.get("tensors", []):
            if not tag.endswith("text_summary"):
                continue
            events = _deduplicate_events(ea.Tensors(tag))
            if not events:
                continue
            serialized = [e.tensor_proto.SerializeToString() for e in events]
            tag_grp = text_grp.require_group(tag)
            tag_grp.create_dataset("data", data=np.array(serialized, dtype="S"))
            tag_grp.create_dataset("step", data=np.array([e.step for e in events]))
            tag_grp.create_dataset("wall_time", data=np.array([e.wall_time for e in events]))

        # ------------------ Images ------------------
        img_grp = f.create_group("images")
        for tag in tags.get("images", []):
            events = _deduplicate_events(ea.Images(tag))
            if not events:
                continue
            imgs = []
            for e in events:
                with Image.open(io.BytesIO(e.encoded_image_string)) as img:
                    imgs.append(np.array(img))
            tag_grp = img_grp.require_group(tag)
            # Images of differing size/mode cannot be stacked; see `_write_arrays`.
            _write_arrays(tag_grp, "data", imgs)
            tag_grp.create_dataset("step", data=np.array([e.step for e in events]))
            tag_grp.create_dataset("wall_time", data=np.array([e.wall_time for e in events]))