Source code for ptyrax.experiment

import copy
import datetime
import itertools
import logging
import random
import re
import shutil
import subprocess
from pathlib import Path
from typing import Any, Dict, Generator, List

import numpy as np
from ruamel.yaml import YAML


# ----------------------
# Main experiment mode
# ----------------------
def run_experiment(experiment_name: str, dry_run: bool = False) -> None:
    """Run or queue a DVC experiment sweep from a config with ``!range`` tags.

    Parses the DVC stage's config files, expands ``!range`` tags into a Cartesian
    product of parameter combinations, and queues one DVC experiment per
    combination. While the per-combination configs are written, the dynamic
    config file is moved to ``<path>.backup``; it is restored afterwards, even
    if queuing fails.

    Args:
        experiment_name: Name of the DVC stage to run.
        dry_run: If True, log what would be queued without actually running.
            The sweep is still recorded in the sweep registry.

    Raises:
        ValueError: If the experiment name is invalid, multiple dynamic config
            files are found, or the ``!range`` values of an axis cannot be zipped.
        KeyboardInterrupt: If the user declines the confirmation for a large sweep.
    """
    if not _validate_experiment_name(experiment_name):
        raise ValueError(
            f"Experiment name '{experiment_name}' is invalid. "
            "Only alphanumeric characters, underscores and hyphens are allowed (max length 32)."
        )

    sweep_id = _generate_sweep_id()
    logging.info(f"Generated sweep ID: {sweep_id}")

    config_paths = _get_config_files_from_dvc_stage(experiment_name)
    dynamic_cfg_path = None
    dynamic_ranges = None
    original_cfg = None
    for cfg_path in config_paths:
        with open(cfg_path) as f:
            cfg = yaml.load(f)
        if r := _expand_ranges(cfg):
            if dynamic_cfg_path is not None:
                raise ValueError("Only one config file with ranges is allowed.")
            dynamic_cfg_path = cfg_path
            dynamic_ranges = r
            # Keep the parsed config; _expand_ranges does not mutate it, so
            # there is no need to read the file a second time.
            original_cfg = cfg

    if dynamic_cfg_path is None or dynamic_ranges is None:
        logging.error("No config file with !range found.")
        return

    # Validate the axes and get user confirmation *before* touching the
    # registry or the config file, so an invalid or aborted sweep leaves no trace.
    axis_order = list(dynamic_ranges.keys())
    axis_values = _build_axis_values(dynamic_ranges)
    all_combos = _generate_combinations_with_confirmation(axis_order, axis_values)

    _save_sweep_metadata(sweep_id, experiment_name, dynamic_ranges)

    backup_path = f"{dynamic_cfg_path}.backup"
    shutil.move(dynamic_cfg_path, backup_path)
    try:
        # _queue_experiment_variants logs the final "Queued N experiments" summary.
        _queue_experiment_variants(
            all_combos=all_combos,
            axis_order=axis_order,
            original_cfg=original_cfg,
            dynamic_cfg_path=dynamic_cfg_path,
            experiment_name=experiment_name,
            sweep_id=sweep_id,
            dry_run=dry_run,
        )
    finally:
        # Always put the user's original (range-bearing) config back in place.
        if Path(backup_path).exists():
            shutil.move(backup_path, dynamic_cfg_path)
def _build_axis_values(dynamic_ranges: Dict[str, Dict[str, List[Any]]]) -> Dict[str, List[Dict[str, Any]]]:
    """Zip the ``!range`` entries of each axis into per-step value dicts.

    Args:
        dynamic_ranges: Mapping ``axis -> {keypath -> values}`` (as returned by
            ``_expand_ranges``).

    Returns:
        Mapping ``axis -> [{keypath: value_i, ...} for each step i]``.

    Raises:
        ValueError: If the value lists within one axis differ in length, or an
            axis has no values at all (it would silently produce zero experiments).
    """
    axis_values: Dict[str, List[Dict[str, Any]]] = {}
    for axis, entries in dynamic_ranges.items():
        lens = {len(v) for v in entries.values()}
        if len(lens) != 1:
            raise ValueError(
                f"All !range values in axis '{axis}' must have equal length. "
                f"Got lengths: {[f'{k}: {len(v)}' for k, v in entries.items()]}"
            )
        if 0 in lens:
            raise ValueError(f"!range values in axis '{axis}' are empty; nothing to sweep over.")
        keypaths = list(entries)
        # Step i of the axis takes element i from every zipped range.
        axis_values[axis] = [dict(zip(keypaths, step)) for step in zip(*entries.values())]
    return axis_values


def _generate_combinations_with_confirmation(
    axis_order: List[str],
    axis_values: Dict[str, List[Dict[str, Any]]],
    confirm_threshold: int = 20,
) -> List[tuple[Dict[str, Any], ...]]:
    """Compute the Cartesian product across axes and confirm with the user if large.

    Args:
        axis_order: Axis names, in the order their values appear in each combo.
        axis_values: Per-axis step dicts (see ``_build_axis_values``).
        confirm_threshold: Ask for confirmation when more combinations than this
            are generated.

    Returns:
        One tuple of step dicts (ordered like ``axis_order``) per combination.

    Raises:
        KeyboardInterrupt: If the user answers anything other than yes.
    """
    all_combos = list(itertools.product(*(axis_values[a] for a in axis_order)))
    logging.info(f"Generated {len(all_combos)} experiment variants.")
    if len(all_combos) > confirm_threshold:
        # Pass the question to input() itself so it is visible even when the
        # logging configuration hides INFO messages.
        ans = input("Large number of experiments generated; are you sure you want to proceed? (Y/n)? ")
        if ans.strip().lower() not in ("y", "yes", ""):
            raise KeyboardInterrupt("Experiment queuing aborted by user.")
    return all_combos


def _queue_experiment_variants(
    all_combos: List[tuple[Dict[str, Any], ...]],
    axis_order: List[str],
    original_cfg: Dict[str, Any],
    dynamic_cfg_path: str,
    experiment_name: str,
    sweep_id: str,
    dry_run: bool,
) -> None:
    """Apply parameter combinations, write configs, and queue DVC experiments.

    For each combination, a deep copy of ``original_cfg`` gets the swept values.
    ``__main__.main.tag`` has ``<axis>_<first value>`` parts appended, and
    ``__main__.main.sweep_id`` is set. The result is written to
    ``dynamic_cfg_path``, and ``dvc exp run <stage> --queue`` is run (or only
    logged when ``dry_run``).
    """
    cmd = ["dvc", "exp", "run", experiment_name, "--queue"]
    for combo in all_combos:
        new_cfg = copy.deepcopy(original_cfg)
        variable_metadata_parts = []
        for d, axis in zip(combo, axis_order):
            variable_metadata_parts.append(f"{axis}_{next(iter(d.values()))}")
            for keypath, value in d.items():
                _set_nested_value(new_cfg, keypath.split("."), value)
        variable_metadata = "_".join(variable_metadata_parts)

        # Create the __main__.main section first; writing the tag into it
        # before it exists would raise KeyError.
        if "__main__" not in new_cfg:
            new_cfg["__main__"] = {}
        if "main" not in new_cfg["__main__"]:
            new_cfg["__main__"]["main"] = {}
        main_section = new_cfg["__main__"]["main"]
        existing_tag = main_section.get("tag")
        # A `tag:` key with a null value must not break string concatenation.
        tag = "" if existing_tag is None else str(existing_tag)
        main_section["tag"] = tag + variable_metadata
        main_section["sweep_id"] = sweep_id

        with open(dynamic_cfg_path, "w") as f:
            yaml.dump(new_cfg, f)

        if dry_run:
            logging.info(f"[Dry Run] Would queue experiment {experiment_name} with config: {new_cfg}")
            logging.info(f"[Dry Run] Command: {cmd}")
        else:
            subprocess.run(cmd, check=True, text=True, shell=False)  # noqa: S603
    logging.info(f"Queued {len(all_combos)} experiments (dry_run={dry_run}).")


def _validate_experiment_name(experiment_name: str) -> bool:
    """Return True if the name is 1-32 chars of letters, digits, ``_`` or ``-``."""
    # fullmatch, unlike `^...$` with re.match, also rejects a trailing newline.
    return re.fullmatch(r"[a-zA-Z0-9_\-]{1,32}", experiment_name) is not None


def _generate_sweep_id() -> str:
    """Generate a unique sweep ID with timestamp and random word."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Simple random words for readability: roughly 280 adjectives x 290 nouns
    # (about 80k pairs; a few words appear twice in a list).
    # fmt: off
    adjectives = [
        "quick", "lazy", "happy", "sleepy", "brave", "clever", "wild", "calm", "bright", "dark",
        "bold", "cool", "crisp", "deft", "eager", "fair", "firm", "glad", "grand", "keen",
        "kind", "lean", "mild", "neat", "open", "pale", "pure", "rare", "safe", "soft",
        "tall", "tame", "vast", "warm", "wise", "zany", "agile", "alert", "ample", "avid",
        "basic", "brisk", "civic", "clean", "clear", "close", "cozy", "daily", "dense", "dizzy",
        "dorky", "dusty", "early", "empty", "equal", "exact", "extra", "faint", "fancy", "fatal",
        "fiery", "final", "first", "fixed", "fleet", "foggy", "fresh", "frugal", "funny", "fuzzy",
        "giant", "giddy", "given", "great", "green", "gross", "gusty", "handy", "harsh", "hasty",
        "hefty", "humid", "ideal", "inner", "ionic", "ivory", "jazzy", "jolly", "jumpy", "juicy",
        "large", "legal", "level", "light", "lofty", "loose", "loyal", "lucky", "lunar", "lusty",
        "magic", "major", "maple", "meek", "merry", "minor", "misty", "modal", "moist", "moral",
        "muddy", "murky", "muted", "naive", "naval", "noble", "noisy", "novel", "oaken", "olive",
        "outer", "papal", "pasty", "perky", "petty", "plain", "plush", "polar", "prime", "proud",
        "quiet", "rainy", "rapid", "ready", "regal", "rigid", "rocky", "roomy", "rough", "round",
        "royal", "rural", "rusty", "salty", "sandy", "sharp", "sheer", "shiny", "short", "silky",
        "slick", "smart", "smoky", "snowy", "solar", "solid", "sonic", "sorry", "spare", "spicy",
        "stark", "steep", "stiff", "still", "stony", "stout", "sunny", "super", "sweet", "swift",
        "tangy", "tense", "thick", "tight", "timid", "tiny", "tipsy", "total", "tough", "toxic",
        "tricky", "true", "ultra", "uncut", "undue", "upper", "urban", "usual", "utter", "valid",
        "vital", "vivid", "vocal", "wacky", "wavy", "weary", "weird", "white", "whole", "witty",
        "woody", "wordy", "woven", "wrong", "young", "zesty", "zippy", "adept", "alien", "amber",
        "antic", "balmy", "basal", "bland", "blank", "bleak", "bliss", "blunt", "bonny", "bossy",
        "bound", "bowed", "bulky", "burly", "catchy", "cheap", "chief", "chill", "civil", "comfy",
        "coral", "crazy", "curly", "dainty", "dandy", "dewy", "dingy", "dizzy", "downy", "dried",
        "dumpy", "dusky", "eerie", "elfin", "elite", "every", "fetid", "fishy", "flaky", "flint",
        "fluid", "foamy", "forte", "frail", "frank", "freed", "froze", "gaunt", "giddy", "glassy",
        "gooey", "grainy", "grimy", "gruff", "gutsy", "hairy", "hazel", "heady", "heard", "heavy",
        "huffy", "icing", "inane", "inert", "irate", "itchy", "ivory", "jaded", "jumbo", "kinky"
    ]
    nouns = [
        "fox", "dog", "cat", "bird", "fish", "bear", "wolf", "lion", "tiger", "eagle",
        "hawk", "deer", "hare", "dove", "crow", "swan", "goat", "seal", "frog", "moth",
        "crab", "clam", "newt", "mole", "vole", "wren", "lark", "lynx", "pike", "bass",
        "carp", "wasp", "toad", "slug", "snail", "crane", "finch", "mouse", "otter", "panda",
        "raven", "robin", "shark", "snake", "squid", "stork", "trout", "whale", "zebra", "bison",
        "camel", "cobra", "coral", "egret", "falcon", "gecko", "heron", "hyena", "iguana", "koala",
        "llama", "moose", "okapi", "orca", "osprey", "owl", "perch", "quail", "rhino", "sable",
        "stoat", "swift", "tapir", "viper", "yak", "cedar", "birch", "maple", "oak", "pine",
        "aspen", "elm", "ivy", "fern", "moss", "reed", "sage", "thyme", "basil", "mint",
        "cliff", "creek", "delta", "dune", "fjord", "glade", "gorge", "grove", "knoll", "marsh",
        "oasis", "plain", "ridge", "shoal", "slope", "thorn", "trail", "vale", "bloom", "brush",
        "cloud", "comet", "flare", "flame", "frost", "gleam", "grain", "gust", "jewel", "light",
        "mist", "ocean", "orbit", "pearl", "plume", "prism", "pulse", "quartz", "river", "shell",
        "shore", "spark", "spire", "spray", "star", "steam", "stone", "storm", "surge", "tide",
        "torch", "tower", "twist", "vault", "wave", "wind", "anvil", "arrow", "badge", "blade",
        "bolt", "charm", "crest", "crown", "drum", "flint", "forge", "glyph", "grain", "guard",
        "hatch", "horn", "ingot", "lance", "lever", "medal", "notch", "panel", "pivot", "plate",
        "prong", "quill", "rivet", "rune", "shard", "spear", "spoke", "staff", "stake", "strut",
        "tally", "token", "wedge", "wick", "arch", "basin", "beam", "cairn", "chord", "clasp",
        "cog", "dome", "edge", "facet", "flange", "grate", "hasp", "joint", "keel", "latch",
        "ledge", "link", "loom", "mesh", "node", "petal", "plank", "rail", "ramp", "ring",
        "rotor", "scale", "seam", "shaft", "slab", "slot", "span", "spool", "stem", "strut",
        "tread", "truss", "valve", "vane", "weld", "yoke", "agate", "amber", "beryl", "chalk",
        "clay", "cobalt", "crystal", "ember", "flint", "garnet", "glass", "gold", "jade", "jet",
        "lapis", "lead", "mica", "nickel", "onyx", "opal", "ore", "ruby", "rust", "sand",
        "silk", "slate", "steel", "tin", "topaz", "zinc", "atlas", "axiom", "canon", "cipher",
        "creed", "dogma", "draft", "edict", "epoch", "ethos", "flux", "focus", "glyph", "helix",
        "index", "locus", "maxim", "motif", "nexus", "pivot", "proxy", "realm", "scope", "sigma",
        "theta", "torus", "triad", "unity", "vapor", "verge", "vigor", "volta", "zenith", "zonal"
    ]
    # fmt: on
    random_word = f"{random.choice(adjectives)}-{random.choice(nouns)}"  # noqa: S311
    return f"{timestamp}_{random_word}"


def _save_sweep_metadata(sweep_id: str, stage_name: str, dynamic_ranges: Dict[str, Dict[str, List[Any]]]) -> None:
    """Save sweep metadata to the ``dvc_sweeps.yaml`` registry.

    The entry for ``sweep_id`` records the stage name, a timestamp, the first
    value of the first range per axis (``sweep_axes``), and the keypaths swept
    on each axis (``parameter_ranges``). Any existing entry is overwritten.
    """
    registry_path = Path("dvc_sweeps.yaml")

    # Load existing registry or create new
    if registry_path.exists():
        with open(registry_path) as f:
            registry = yaml.load(f) or {}
    else:
        registry = {}

    # Extract sweep axis definitions (tag name + first value)
    sweep_axes = {}
    for axis, entries in dynamic_ranges.items():
        first_entry = next(iter(entries.values()))
        sweep_axes[axis] = first_entry[0] if first_entry else None

    registry[sweep_id] = {
        "stage": stage_name,
        "timestamp": datetime.datetime.now().isoformat(),
        "sweep_axes": sweep_axes,
        "parameter_ranges": {axis: list(entries.keys()) for axis, entries in dynamic_ranges.items()},
    }

    with open(registry_path, "w") as f:
        yaml.dump(registry, f)


# ----------------------
# Custom YAML tag !range
# ----------------------
[docs] class RangeTag: """Custom YAML tag representing a parameter sweep range. Used in experiment config files with the ``!range`` YAML tag to specify values that should be swept over in a parameter study. Attributes: axis: Name of the sweep axis (used for grouping), or None for auto. values: List of values to sweep over. """ def __init__(self, axis: str | None, values: List[Any]) -> None: self.axis = axis self.values = values
# Shared YAML instance used for every config read/write in this module.
yaml: YAML = YAML()
# Accept repeated mapping keys when loading instead of rejecting the document.
yaml.allow_duplicate_keys = True
# Register !range constructor
def range_constructor(loader: Any, node: Any) -> RangeTag:  # noqa: ANN401
    """Build a ``RangeTag`` from a ``!range`` YAML sequence node.

    ``!range [name, v1, v2, ...]``: if the sequence has more than one element
    and its first element is a string, that string becomes the axis name and
    the remaining elements are the values. Otherwise every element is a value
    and the axis is ``None``.

    NOTE(review): as a result, a range made only of strings, e.g.
    ``!range [adam, sgd]``, is read as axis ``"adam"`` with the single value
    ``"sgd"``.
    """
    items = loader.construct_sequence(node)
    has_axis_name = isinstance(items, list) and len(items) > 1 and isinstance(items[0], str)
    if has_axis_name:
        return RangeTag(axis=items[0], values=items[1:])
    return RangeTag(axis=None, values=items)
yaml.constructor.add_constructor("!range", range_constructor)


def _range_representer(dumper: Any, data: RangeTag):  # noqa: ANN401
    """Serialise a RangeTag as a ``!range`` sequence, axis name first when set."""
    if data.axis is None:
        return dumper.represent_sequence("!range", data.values)
    return dumper.represent_sequence("!range", [data.axis] + data.values)


yaml.representer.add_representer(RangeTag, _range_representer)

# ----------------------
# Helper functions
# ----------------------


def _set_nested_value(cfg: Dict[str, Any], path: List[str], value: Any) -> None:  # noqa: ANN401
    """Set a value in a nested dict/list structure by following a key path.

    ``_expand_ranges`` records list positions as their index in string form
    (e.g. ``"layers.0.size"``), so path segments that address a list are
    converted to ``int`` before indexing.

    Args:
        cfg: Nested dictionary (possibly containing lists) to modify in-place.
        path: List of keys forming the path to the target.
        value: Value to assign at the terminal key.
    """
    node: Any = cfg
    for p in path[:-1]:
        node = node[int(p) if isinstance(node, list) else p]
    last = path[-1]
    node[int(last) if isinstance(node, list) else last] = value


# ----------------------
# Range expansion with axis grouping
# ----------------------
def _expand_ranges(obj: Any, prefix: List[str] | None = None) -> Dict[str, Dict[str, List[Any]]]:  # noqa: ANN401
    """Collect every ``!range`` tag in a parsed config, grouped by sweep axis.

    Args:
        obj: Parsed config node (mapping, list, RangeTag or anything else).
        prefix: Key path from the config root to ``obj``. List positions are
            recorded as their index in string form.

    Returns:
        Dict ``axis -> {keypath -> list_of_values}`` with dot-joined keypaths.
        A tag without an explicit axis gets its own axis ``__axis_<keypath>``.

    Rules:
        • Multiple ranges with the same axis are zipped (must have equal length).
        • Different axes form independent groups whose tensor product is taken.
        • A RangeTag that is an element of a list expands into copies of that
          list in which the tag is replaced, at its own position, by each value.
          Any other RangeTags in the same list are omitted from those copies.
    """
    if prefix is None:
        prefix = []
    grouped: Dict[str, Dict[str, List[Any]]] = {}

    def add(axis: str, key: str, vals: List[Any]) -> None:
        grouped.setdefault(axis, {})[key] = vals

    def merge(sub: Dict[str, Dict[str, List[Any]]]) -> None:
        for axis, d in sub.items():
            for kk, vv in d.items():
                add(axis, kk, vv)

    # Dict recursion
    if isinstance(obj, dict):
        for k, v in obj.items():
            merge(_expand_ranges(v, prefix + [k]))
        return grouped

    # List: detect RangeTag within
    if isinstance(obj, list):
        tags = [x for x in obj if isinstance(x, RangeTag)]
        if tags:
            keypath = ".".join(prefix)
            for tag in tags:
                axis = tag.axis or f"__axis_{keypath}"
                # Substitute each value at the tag's own position so the order
                # of the surrounding list elements is preserved.
                expanded = [
                    [v if x is tag else x for x in obj if x is tag or not isinstance(x, RangeTag)]
                    for v in tag.values
                ]
                add(axis, keypath, expanded)
            return grouped
        # Plain list: recurse elements
        for idx, v in enumerate(obj):
            merge(_expand_ranges(v, prefix + [str(idx)]))
        return grouped

    # Single RangeTag
    if isinstance(obj, RangeTag):
        axis = obj.axis or f"__axis_{'.'.join(prefix)}"
        add(axis, ".".join(prefix), obj.values)
    return grouped


# ----------------------
# Load DVC stage config
# ----------------------
def _get_config_files_from_dvc_stage(stage_name: str) -> List[str]:
    """Extract YAML config file paths referenced by a DVC stage.

    Reads ``dvc.yaml`` and collects ``.yaml``/``.yml`` paths from the stage's
    ``params`` (file-keyed mappings) and ``deps`` sections. Plain string params
    entries are ignored.

    Args:
        stage_name: Name of the DVC stage.

    Returns:
        Unique YAML file paths used by the stage, in first-seen order.

    Raises:
        ValueError: If ``dvc.yaml`` has no stage called ``stage_name``.
    """
    with open("dvc.yaml") as f:
        full_config = yaml.load(f) or {}

    stages = full_config.get("stages") or {}
    if stage_name not in stages:
        raise ValueError(f"Stage '{stage_name}' not found in dvc.yaml.")
    stage = stages[stage_name] or {}
    params = stage.get("params") or []
    deps = stage.get("deps") or []

    yaml_suffixes = (".yaml", ".yml")
    yaml_files: List[str] = []
    # Params entries that are dicts map file -> null or subkeys.
    for p in params:
        if isinstance(p, dict):
            yaml_files.extend(path for path in p if isinstance(path, str) and path.endswith(yaml_suffixes))
    yaml_files.extend(d for d in deps if isinstance(d, str) and d.endswith(yaml_suffixes))

    # A file listed under both params and deps must be scanned only once;
    # otherwise its ranges are seen twice and trip the single-dynamic-file check.
    return list(dict.fromkeys(yaml_files))


# ----------------------
# HDF5 pooled experiment utilities
# ----------------------
def iter_experiments(hdf5_path: str | Path) -> Generator[tuple[str, Any], None, None]:
    """Iterate over experiments in a pooled HDF5 file.

    Names come from the ``index/exp_names`` dataset. Names without a matching
    top-level group are skipped.

    Args:
        hdf5_path: Path to pooled experiment HDF5 file

    Yields:
        Tuple of (experiment_name, experiment_group). The group belongs to a
        file that is closed when the generator finishes or is closed, so read
        what you need inside the loop.

    Example:
        >>> for exp_name, exp_group in iter_experiments("experiment_sweep_xyz.hdf5"):
        ...     print(f"Processing {exp_name}")
        ...     losses = exp_group["training/scalars/0_loss/0_loss_total/value"][()]
    """
    import h5py

    with h5py.File(hdf5_path, "r") as f:
        # Names may be stored as fixed-length bytes (dtype kind "S") or as
        # variable-length strings, which h5py >= 3 returns as bytes objects in
        # an object array. Decode either form so callers always get ``str``.
        exp_names = [
            name.decode() if isinstance(name, bytes) else str(name)
            for name in f["index/exp_names"][()]
        ]
        for exp_name in exp_names:
            if exp_name in f:
                yield exp_name, f[exp_name]
def get_experiment_index(hdf5_path: str | Path) -> dict:
    """Load the experiment index as a dictionary.

    Every dataset in the ``index`` group is read in full. Byte strings are
    decoded to ``str``, whether stored as fixed-length (dtype kind ``"S"``),
    variable-length (object arrays of ``bytes``), or as a scalar.

    Args:
        hdf5_path: Path to pooled experiment HDF5 file

    Returns:
        Dict with keys: 'exp_names', 'indices', 'param_{name}' for each parameter

    Example:
        >>> index = get_experiment_index("experiment_sweep_xyz.hdf5")
        >>> print(index["exp_names"])
        >>> print(index["param_angle"])
    """
    import h5py

    index_data = {}
    with h5py.File(hdf5_path, "r") as f:
        index_group = f["index"]
        for key in index_group.keys():
            data = index_group[key][()]
            if isinstance(data, bytes):
                # Scalar string dataset: h5py returns a bare bytes object.
                data = data.decode()
            elif (
                isinstance(data, np.ndarray)
                and data.dtype.kind in ("S", "O")
                and any(isinstance(item, bytes) for item in data.flat)
            ):
                # Only rebuild arrays that actually hold bytes, so object arrays
                # of other things (e.g. ragged numeric data) are left untouched.
                decoded = [item.decode() if isinstance(item, bytes) else item for item in data.flat]
                data = np.array(decoded).reshape(data.shape)
            index_data[key] = data
    return index_data
def load_experiment_by_name(hdf5_path: str | Path, exp_name: str) -> dict:
    """Read one experiment group of a pooled HDF5 file into nested dicts.

    Subgroups become nested dicts, datasets are read in full with ``[()]``, and
    each level carries its HDF5 attributes under the ``"_attrs"`` key.

    Args:
        hdf5_path: Path to pooled experiment HDF5 file.
        exp_name: Experiment name to load.

    Returns:
        Nested dict with experiment data (scalars, images, etc.).

    Raises:
        KeyError: If no top-level group named ``exp_name`` exists.

    Example:
        >>> exp = load_experiment_by_name("experiment_sweep_xyz.hdf5", "leaky-lulu")
        >>> losses = exp["training"]["scalars"]["0_loss"]["0_loss_total"]["value"]
    """
    import h5py

    def _load_group(group: Any):  # noqa: ANN401
        """Recursively load HDF5 group to dict."""
        tree = {}
        for name in group.keys():
            member = group[name]
            tree[name] = _load_group(member) if isinstance(member, h5py.Group) else member[()]
        tree["_attrs"] = dict(group.attrs)
        return tree

    with h5py.File(hdf5_path, "r") as f:
        if exp_name not in f:
            raise KeyError(f"Experiment '{exp_name}' not found in {hdf5_path}")
        return _load_group(f[exp_name])
def get_sweep_axes(hdf5_path: str | Path) -> dict:
    """Get sweep axis information from pooled HDF5.

    Args:
        hdf5_path: Path to pooled experiment HDF5 file

    Returns:
        Dict with sweep metadata: ``sweep_id`` (file attribute, default
        ``"unknown"``), ``n_experiments`` (file attribute, default 0), and
        ``param_names``, the ``index`` dataset names with their leading
        ``param_`` removed.

    Example:
        >>> axes = get_sweep_axes("experiment_sweep_xyz.hdf5")
        >>> print(axes["sweep_id"])
        >>> print(axes["param_names"])
    """
    import h5py

    with h5py.File(hdf5_path, "r") as f:
        sweep_id = f.attrs.get("sweep_id", "unknown")
        # Fixed-length string attributes may come back as bytes.
        if isinstance(sweep_id, bytes):
            sweep_id = sweep_id.decode()
        sweep_info = {
            "sweep_id": sweep_id,
            "n_experiments": f.attrs.get("n_experiments", 0),
        }
        # Strip only the leading marker; str.replace would also remove "param_"
        # from inside names such as "param_my_param_x".
        sweep_info["param_names"] = [
            key.removeprefix("param_") for key in f["index"].keys() if key.startswith("param_")
        ]
    return sweep_info