agedi.data.callbacks

Classes

GradNormLogger

Logs the total gradient norm of the score-model parameters before each optimizer step.

EpochProgressPrinter

Prints epoch-level training progress to stdout at a configurable interval.

HParamsMetricLogger

Manages hyperparameter logging and populates the TensorBoard hp_metric panel.

TrainingPhase

Lightning callback that advances the dataset through training phases.

Functions

_flatten_hparams(→ dict)

Recursively flatten a nested dict into a flat dict whose keys join the nested key path with sep ("/" by default).

Module Contents

class agedi.data.callbacks.GradNormLogger(log_every_n_steps: int = 50)

Bases: lightning.pytorch.callbacks.Callback

Logs the total gradient norm of the score-model parameters before each optimizer step.

Parameters:

log_every_n_steps – Number of optimizer steps between gradient-norm log entries (default: 50). Set to 1 to log every step.
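A minimal sketch of what "total gradient norm" usually means, assuming the callback uses the global L2 norm (the convention of `torch.nn.utils.clip_grad_norm_` with its default `norm_type=2.0`) and gates logging on the optimizer step count. The helpers `total_grad_norm` and `should_log` are hypothetical. Plain Python lists stand in for gradient tensors so the example runs without PyTorch.

```python
import math
from typing import Iterable, Optional, Sequence


def total_grad_norm(grads: Iterable[Optional[Sequence[float]]]) -> float:
    """Global L2 norm over all gradients (hypothetical helper).

    Equals the L2 norm of every gradient concatenated into one vector,
    i.e. sqrt(sum_p ||g_p||^2). Parameters without a gradient (None) are skipped.
    """
    squared_sum = 0.0
    for g in grads:
        if g is None:
            continue  # e.g. frozen parameters have no gradient
        squared_sum += sum(x * x for x in g)
    return math.sqrt(squared_sum)


def should_log(global_step: int, log_every_n_steps: int = 50) -> bool:
    """Assumed gating rule: log when the step is a multiple of the interval."""
    return global_step % log_every_n_steps == 0


# Two "parameters" with gradients [3, 0] and [0, 4], plus one frozen parameter.
print(total_grad_norm([[3.0, 0.0], [0.0, 4.0], None]))  # 5.0
print(should_log(100), should_log(101))  # True False
```

A single global norm, rather than one value per parameter, keeps the log to one scalar per step. This is also the quantity that gradient clipping thresholds act on.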

log_every_n_steps = 50
on_before_optimizer_step(trainer, pl_module, optimizer)

class agedi.data.callbacks.EpochProgressPrinter(print_epoch_interval: int = 10)

Bases: lightning.pytorch.callbacks.Callback

Prints epoch-level training progress to stdout at a configurable interval.

Parameters:

print_epoch_interval – Number of epochs between printed summary lines (default: 10).
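The interval rule and the summary line can be sketched as follows. The sketch makes three assumptions. First, `epoch` is Lightning's 0-based `trainer.current_epoch`. Second, with the default interval, a line is printed after epochs 10, 20 and so on. Third, `_fit_start_time` holds a `time.time()` timestamp recorded in `on_fit_start`. The helper names `should_print` and `format_progress` and the line format are hypothetical.

```python
import time
from typing import Dict, Optional


def should_print(epoch: int, print_epoch_interval: int = 10) -> bool:
    """Assumed rule: print after epochs 10, 20, ... (epoch is 0-based)."""
    return (epoch + 1) % print_epoch_interval == 0


def format_progress(
    epoch: int,
    max_epochs: int,
    fit_start_time: float,
    metrics: Dict[str, float],
    now: Optional[float] = None,
) -> str:
    """Build one summary line in a hypothetical format."""
    # Elapsed wall-clock time since fit start; `now` is injectable for testing.
    elapsed = (time.time() if now is None else now) - fit_start_time
    metric_str = " ".join(f"{k}={v:.4f}" for k, v in sorted(metrics.items()))
    return f"epoch {epoch + 1}/{max_epochs} | {elapsed:.1f}s | {metric_str}"


start = 1000.0
for epoch in range(25):
    if should_print(epoch):
        print(format_progress(epoch, 25, start, {"val_loss": 0.1234}, now=start + 12.5))
# epoch 10/25 | 12.5s | val_loss=0.1234
# epoch 20/25 | 12.5s | val_loss=0.1234
```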

print_epoch_interval = 10
_fit_start_time: float = 0.0
on_fit_start(trainer, pl_module)
on_validation_epoch_end(trainer, pl_module)
on_fit_end(trainer, pl_module)

agedi.data.callbacks._flatten_hparams(d: dict, prefix: str = '', sep: str = '/') → dict

Recursively flatten a nested dict into a flat dict whose keys join the nested key path with sep ("/" by default).

Nested dicts and lists are flattened recursively, and only scalar leaf values (int, float, str, bool) are kept. This is required because TensorBoard’s log_hyperparams accepts only scalar values.

Parameters:
  • d (dict) – The nested hyperparameter dict to flatten.

  • prefix (str) – Key prefix to prepend (used in recursion).

  • sep (str) – Separator between key segments (default "/").

Returns:

Flat dict with scalar values only.

Return type:

dict
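The documented behaviour can be sketched as follows: nested dicts are flattened, scalar leaves are kept, and keys are joined with `sep`. This is a hypothetical reimplementation, not the module's code. The reference above does not say how list items are keyed or what happens to `None` and other non-scalar leaves. The sketch therefore assumes two things: list items are keyed by index, and non-scalar leaves are silently dropped.

```python
from typing import Any, Dict

SCALAR_TYPES = (bool, int, float, str)


def flatten_hparams(d: Dict[str, Any], prefix: str = "", sep: str = "/") -> Dict[str, Any]:
    """Flatten nested dicts/lists into {joined_key: scalar} (illustrative sketch)."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_hparams(value, full_key, sep))
        elif isinstance(value, (list, tuple)):
            # Assumption: list items are addressed by their position.
            as_dict = {str(i): item for i, item in enumerate(value)}
            flat.update(flatten_hparams(as_dict, full_key, sep))
        elif isinstance(value, SCALAR_TYPES):
            flat[full_key] = value
        # Anything else (None, tensors, arbitrary objects) is dropped.
    return flat


nested = {"model": {"hidden": 128, "act": "silu"}, "lr": 1e-3, "repeats": [1, 2], "ckpt": None}
print(flatten_hparams(nested))
# {'model/hidden': 128, 'model/act': 'silu', 'lr': 0.001, 'repeats/0': 1, 'repeats/1': 2}
```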

class agedi.data.callbacks.HParamsMetricLogger(hparams: Dict | None = None)

Bases: lightning.pytorch.callbacks.Callback

Manages hyperparameter logging and populates the TensorBoard hp_metric panel.

When a full hparams dict is provided (including training metadata such as distribution, prior, sde, conditioning, and batch_size), it is written to hparams.yaml at training start, complementing the baseline written by on_fit_start(). When no dict is provided, the callback falls back to calling pl_module.get_hparams(), which returns only the model-architecture config.

For non-TensorBoard loggers (e.g. WandB), the resolved hparams dict is forwarded to log_hyperparams at training start.

Parameters:

hparams – Full hyperparameter dictionary to log (architecture + training metadata). When None, the callback resolves hparams from pl_module.get_hparams().
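The fallback described above can be sketched with stand-in classes. `HParamsResolverSketch` and `FakeModule` are hypothetical, and the hyperparameter values are made up. The check uses `is not None` to match the documented "When None" wording, so an explicitly passed empty dict is still treated as provided.

```python
from typing import Any, Dict, Optional


class HParamsResolverSketch:
    """Mimics the documented fallback; not the real HParamsMetricLogger."""

    def __init__(self, hparams: Optional[Dict[str, Any]] = None) -> None:
        self._hparams = hparams

    def _resolve_hparams(self, pl_module: Any) -> Dict[str, Any]:
        # An explicit dict (architecture + training metadata) wins;
        # otherwise fall back to the model's architecture-only config.
        if self._hparams is not None:
            return dict(self._hparams)
        return dict(pl_module.get_hparams())


class FakeModule:
    """Stand-in for a LightningModule exposing get_hparams()."""

    def get_hparams(self) -> Dict[str, Any]:
        return {"hidden": 128}


full = {"hidden": 128, "batch_size": 32, "sde": "vp"}
print(HParamsResolverSketch(full)._resolve_hparams(FakeModule()))
# {'hidden': 128, 'batch_size': 32, 'sde': 'vp'}
print(HParamsResolverSketch()._resolve_hparams(FakeModule()))
# {'hidden': 128}
```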

_hparams = None
_resolve_hparams(pl_module) → Dict
_TB_EXCLUDE_KEYS
_flatten_for_tb(resolved: dict) → dict

Flatten hparams for TensorBoard, excluding keys that are not useful there.

on_train_start(trainer, pl_module)
on_validation_epoch_end(trainer, pl_module)
on_fit_end(trainer, pl_module)

class agedi.data.callbacks.TrainingPhase(n_phases: int, epochs_per_phase: List[int], **kwargs)

Bases: lightning.pytorch.callbacks.Callback

Lightning callback that advances the dataset through training phases.

Each phase can use a different set of data augmentation transforms (e.g. supercell repeats). The callback monitors the epoch count and calls set_phase() on the datamodule when it is time to move to the next phase.

n_phases
epochs_per_phase
epoch_counter = 0
current_phase = 0
_prepare_epoch(trainer: lightning.Trainer, model: lightning.LightningModule) → None

Advance to the next training phase if enough epochs have elapsed.

Called at the end of each validation epoch. When the epoch counter reaches the threshold for the current phase, the datamodule is instructed to switch to the next phase via set_phase().

Parameters:
  • trainer (Trainer) – The active Lightning trainer.

  • model (LightningModule) – The model being trained (unused; required by the Lightning callback API).
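The phase-advance step can be simulated without Lightning. The sketch makes four assumptions:

- `epoch_counter` counts epochs within the current phase and resets on each switch.
- `epochs_per_phase[i]` is the length of phase i.
- The final phase never advances.
- `set_phase()` takes the new phase index.

`FakeDataModule`, `FakeTrainer` and `TrainingPhaseSketch` are hypothetical stand-ins, and the sketch ignores the real constructor's `**kwargs`.

```python
from typing import List


class FakeDataModule:
    """Stand-in datamodule that records set_phase() calls."""

    def __init__(self) -> None:
        self.phases: List[int] = []

    def set_phase(self, phase: int) -> None:  # assumed signature
        self.phases.append(phase)


class FakeTrainer:
    """Stand-in exposing .datamodule, as a Lightning Trainer does."""

    def __init__(self, datamodule: FakeDataModule) -> None:
        self.datamodule = datamodule


class TrainingPhaseSketch:
    """Illustrative re-creation of the phase logic; not the real callback."""

    def __init__(self, n_phases: int, epochs_per_phase: List[int]) -> None:
        # Sanity check added for the sketch: one length per phase.
        if len(epochs_per_phase) != n_phases:
            raise ValueError("epochs_per_phase needs one entry per phase")
        self.n_phases = n_phases
        self.epochs_per_phase = epochs_per_phase
        self.epoch_counter = 0
        self.current_phase = 0

    def _prepare_epoch(self, trainer: FakeTrainer, model: object = None) -> None:
        self.epoch_counter += 1
        last_phase = self.current_phase == self.n_phases - 1
        if not last_phase and self.epoch_counter >= self.epochs_per_phase[self.current_phase]:
            self.current_phase += 1
            self.epoch_counter = 0  # assumption: count epochs within each phase
            trainer.datamodule.set_phase(self.current_phase)


dm = FakeDataModule()
trainer = FakeTrainer(dm)
cb = TrainingPhaseSketch(n_phases=3, epochs_per_phase=[2, 3, 5])
for _ in range(10):
    cb._prepare_epoch(trainer)
print(dm.phases, cb.current_phase)  # [1, 2] 2
```

With `epochs_per_phase=[2, 3, 5]`, the switches happen after the 2nd and 5th validation epochs. The remaining epochs stay in the last phase.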

on_validation_end(trainer: lightning.Trainer, model: lightning.LightningModule) → None

Hook called by Lightning at the end of each validation epoch.

Delegates to _prepare_epoch() to check whether the current training phase should advance.

Parameters:
  • trainer (Trainer) – The active Lightning trainer.

  • model (LightningModule) – The model being trained.