agedi.data.callbacks
Classes
GradNormLogger | Logs the total gradient norm of the score-model parameters before each optimizer step.
EpochProgressPrinter | Prints epoch-level training progress to stdout at a configurable interval.
HParamsMetricLogger | Manages hyperparameter logging and populates the TensorBoard hp_metric panel.
TrainingPhase | Lightning callback that advances the dataset through training phases.
Functions
_flatten_hparams | Recursively flatten a nested dict into a flat dict whose keys are joined by a separator ("/" by default).
Module Contents
- class agedi.data.callbacks.GradNormLogger(log_every_n_steps: int = 50)
Bases: lightning.pytorch.callbacks.Callback
Logs the total gradient norm of the score-model parameters before each optimizer step.
- Parameters:
log_every_n_steps – Log the gradient norm once every log_every_n_steps optimizer steps (default: 50). Set to 1 to log every step.
- log_every_n_steps = 50
- on_before_optimizer_step(trainer, pl_module, optimizer)
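A plausible reading of "total gradient norm" is the global L2 norm that torch.nn.utils.clip_grad_norm_ also computes: the square root of the sum of squared per-parameter gradient norms. This page does not state the norm type or the exact step-gating rule, so both are assumptions here. The dependency-free sketch below uses plain float lists in place of gradient tensors; total_grad_norm and should_log are hypothetical helpers, not part of agedi.

```python
import math
from typing import Optional, Sequence


def total_grad_norm(grads: Sequence[Optional[Sequence[float]]]) -> float:
    """Global L2 norm over all parameter gradients (assumed norm type).

    Each entry stands in for one parameter's flattened gradient; None mimics
    a parameter whose .grad is unset, and such entries are skipped.
    """
    squared_sum = 0.0
    for g in grads:
        if g is None:
            continue
        # Accumulate ||g_p||^2 for this parameter
        squared_sum += sum(x * x for x in g)
    return math.sqrt(squared_sum)


def should_log(global_step: int, log_every_n_steps: int = 50) -> bool:
    """Hypothetical gating rule: log on steps 0, n, 2n, ..."""
    return global_step % log_every_n_steps == 0


if __name__ == "__main__":
    print(total_grad_norm([[3.0, 0.0], None, [4.0]]))  # 5.0
    print([s for s in range(120) if should_log(s)])  # [0, 50, 100]
```

Inside an actual on_before_optimizer_step hook, the same accumulation would presumably run over the score model's parameters, reading each parameter's .grad and skipping those that are None.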
- class agedi.data.callbacks.EpochProgressPrinter(print_epoch_interval: int = 10)
Bases: lightning.pytorch.callbacks.Callback
Prints epoch-level training progress to stdout at a configurable interval.
- Parameters:
print_epoch_interval – Print a summary line once every print_epoch_interval epochs (default: 10).
- print_epoch_interval = 10
- _fit_start_time: float = 0.0
- on_fit_start(trainer, pl_module)
- on_validation_epoch_end(trainer, pl_module)
- on_fit_end(trainer, pl_module)
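The interval logic can be sketched without Lightning as follows. The class below is a hypothetical stand-in, not agedi's EpochProgressPrinter: its hooks take the epoch index and a metrics dict directly, whereas a real callback would typically read trainer.current_epoch and trainer.callback_metrics. Treating the interval as 1-based (printing after epochs 10, 20, ...) and the line format are assumptions.

```python
import time
from typing import Dict, Optional


class EpochProgressSketch:
    """Stdlib-only sketch of interval-based epoch printing (hypothetical)."""

    def __init__(self, print_epoch_interval: int = 10) -> None:
        if print_epoch_interval < 1:
            raise ValueError("print_epoch_interval must be >= 1")
        self.print_epoch_interval = print_epoch_interval
        self._fit_start_time = 0.0

    def on_fit_start(self) -> None:
        # Record a monotonic reference point for elapsed-time reporting
        self._fit_start_time = time.perf_counter()

    def on_validation_epoch_end(
        self, current_epoch: int, metrics: Dict[str, float]
    ) -> Optional[str]:
        # current_epoch is 0-based, so this prints after epochs n, 2n, ...
        if (current_epoch + 1) % self.print_epoch_interval != 0:
            return None
        elapsed = time.perf_counter() - self._fit_start_time
        parts = [f"{k}={v:.4f}" for k, v in sorted(metrics.items())]
        line = f"epoch {current_epoch + 1} | {elapsed:.1f}s | " + " ".join(parts)
        print(line)
        return line


if __name__ == "__main__":
    printer = EpochProgressSketch(print_epoch_interval=2)
    printer.on_fit_start()
    for epoch in range(4):
        # Prints only after the 2nd and 4th epochs
        printer.on_validation_epoch_end(epoch, {"val_loss": 1.0 / (epoch + 1)})
```

Returning the formatted line in addition to printing it is only there to make the sketch easy to check; a Lightning callback hook returns None.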
- agedi.data.callbacks._flatten_hparams(d: dict, prefix: str = '', sep: str = '/') -> dict
Recursively flatten a nested dict into a flat dict whose keys are joined by sep ("/" by default).
Only scalar values (int, float, str, bool) are kept; lists and nested dicts are flattened recursively. This is required for TensorBoard’s log_hyperparams, which only accepts scalar values.
- Parameters:
d (dict) – The nested hyperparameter dict to flatten.
prefix (str) – Key prefix to prepend (used in recursion).
sep (str) – Separator between key segments (default "/").
- Returns:
Flat dict with scalar values only.
- Return type:
dict
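The flattening described above can be sketched as a short recursive function. This is an illustrative reimplementation named flatten_hparams_sketch, not agedi's code: keying list elements by their index and silently dropping None and other non-scalar leaves are assumptions the docstring does not spell out.

```python
from typing import Any, Dict


def flatten_hparams_sketch(d: Dict[Any, Any], prefix: str = "", sep: str = "/") -> Dict[str, Any]:
    """Flatten nested dicts/lists into {"a/b/0": value} form (hypothetical)."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Nested dict: recurse with the extended prefix
            flat.update(flatten_hparams_sketch(value, full_key, sep))
        elif isinstance(value, (list, tuple)):
            # Assumption: list elements are keyed by their index
            indexed = {str(i): v for i, v in enumerate(value)}
            flat.update(flatten_hparams_sketch(indexed, full_key, sep))
        elif isinstance(value, (int, float, str, bool)):
            # Scalar leaf: keep as-is
            flat[full_key] = value
        # Any other type (None, objects, ...) is dropped
    return flat


if __name__ == "__main__":
    nested = {"model": {"hidden": 128, "layers": [2, 3]}, "lr": 1e-3, "notes": None}
    print(flatten_hparams_sketch(nested))
    # {'model/hidden': 128, 'model/layers/0': 2, 'model/layers/1': 3, 'lr': 0.001}
```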
- class agedi.data.callbacks.HParamsMetricLogger(hparams: Dict | None = None)
Bases: lightning.pytorch.callbacks.Callback
Manages hyperparameter logging and populates the TensorBoard hp_metric panel.
When a full hparams dict is provided (including training metadata such as distribution, prior, sde, conditioning, batch_size, etc.), it is written to hparams.yaml at training start, complementing the baseline written by on_fit_start(). When no dict is provided, the callback falls back to calling pl_module.get_hparams(), which returns only the model-architecture config.
For non-TensorBoard loggers (e.g. WandB), the resolved hparams dict is forwarded to log_hyperparams at training start.
- Parameters:
hparams – Full hyperparameter dictionary to log (architecture + training metadata). When None, the callback resolves hparams from pl_module.get_hparams().
- _hparams = None
- _resolve_hparams(pl_module) -> Dict
- _TB_EXCLUDE_KEYS
- _flatten_for_tb(resolved: dict) -> dict
Flatten hparams for TensorBoard, excluding keys that are not useful there.
- on_train_start(trainer, pl_module)
- on_validation_epoch_end(trainer, pl_module)
- on_fit_end(trainer, pl_module)
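The hparams resolution order described for this callback (an explicit dict wins, otherwise pl_module.get_hparams()) can be sketched as below. The actual contents of _TB_EXCLUDE_KEYS are not documented on this page, so TB_EXCLUDE_KEYS here is a hypothetical placeholder; filtering only top-level keys, before any flattening, is likewise an assumption. DummyModule is a stand-in for a LightningModule that exposes get_hparams().

```python
from typing import Any, Dict, FrozenSet, Optional

# Hypothetical placeholder; the real _TB_EXCLUDE_KEYS contents are unknown here.
TB_EXCLUDE_KEYS: FrozenSet[str] = frozenset({"notes"})


def resolve_hparams(explicit: Optional[Dict[str, Any]], pl_module: Any) -> Dict[str, Any]:
    # An explicitly supplied dict (even an empty one) takes precedence;
    # only None triggers the architecture-only fallback.
    if explicit is not None:
        return dict(explicit)
    return dict(pl_module.get_hparams())


def filter_for_tb(
    resolved: Dict[str, Any], exclude: FrozenSet[str] = TB_EXCLUDE_KEYS
) -> Dict[str, Any]:
    # Drop top-level keys that are not useful in the TensorBoard HParams view
    return {k: v for k, v in resolved.items() if k not in exclude}


class DummyModule:
    """Stand-in for a LightningModule with a get_hparams() method."""

    def get_hparams(self) -> Dict[str, Any]:
        return {"hidden_dim": 128, "n_layers": 4}


if __name__ == "__main__":
    module = DummyModule()
    print(resolve_hparams(None, module))  # falls back to get_hparams()
    print(filter_for_tb(resolve_hparams({"batch_size": 32, "notes": "run A"}, module)))
```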
- class agedi.data.callbacks.TrainingPhase(n_phases: int, epochs_per_phase: List[int], **kwargs)
Bases: lightning.pytorch.callbacks.Callback
Lightning callback that advances the dataset through training phases.
Each phase can use a different set of data augmentation transforms (e.g. supercell repeats). The callback counts elapsed epochs and calls set_phase() on the datamodule when it is time to move to the next phase.
- n_phases
- epochs_per_phase
- epoch_counter = 0
- current_phase = 0
- _prepare_epoch(trainer: lightning.Trainer, model: lightning.LightningModule) -> None
Advance to the next training phase if enough epochs have elapsed.
Called at the end of each validation epoch. When the epoch counter reaches the threshold for the current phase, the datamodule is instructed to switch to the next phase via set_phase().
- Parameters:
trainer (Trainer) – The active Lightning trainer.
model (LightningModule) – The model being trained (unused, required by Lightning callback API).
- on_validation_end(trainer: lightning.Trainer, model: lightning.LightningModule) -> None
Hook called by Lightning at the end of each validation epoch.
Delegates to _prepare_epoch() to check whether the current training phase should advance.
- Parameters:
trainer (Trainer) – The active Lightning trainer.
model (LightningModule) – The model being trained.
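The phase bookkeeping can be sketched with a recording stand-in for the datamodule. This page does not say whether epochs_per_phase holds per-phase durations or cumulative boundaries; the sketch assumes per-phase durations, resets epoch_counter on every advance, and never moves past the last phase. PhaseSchedulerSketch and RecordingDataModule are hypothetical names; in a real Lightning hook the datamodule would usually be reached through trainer.datamodule.

```python
from typing import List


class RecordingDataModule:
    """Stand-in datamodule that records every set_phase() call."""

    def __init__(self) -> None:
        self.calls: List[int] = []

    def set_phase(self, phase: int) -> None:
        self.calls.append(phase)


class PhaseSchedulerSketch:
    """Hypothetical stand-in for TrainingPhase's epoch bookkeeping."""

    def __init__(self, n_phases: int, epochs_per_phase: List[int]) -> None:
        if len(epochs_per_phase) < n_phases:
            raise ValueError("need one epoch budget per phase")
        self.n_phases = n_phases
        self.epochs_per_phase = epochs_per_phase
        self.epoch_counter = 0
        self.current_phase = 0

    def prepare_epoch(self, datamodule: RecordingDataModule) -> None:
        # Called once per validation run (mirrors on_validation_end -> _prepare_epoch)
        self.epoch_counter += 1
        in_last_phase = self.current_phase == self.n_phases - 1
        if not in_last_phase and self.epoch_counter >= self.epochs_per_phase[self.current_phase]:
            self.current_phase += 1
            self.epoch_counter = 0  # assumption: the counter restarts each phase
            datamodule.set_phase(self.current_phase)


if __name__ == "__main__":
    dm = RecordingDataModule()
    sched = PhaseSchedulerSketch(n_phases=3, epochs_per_phase=[2, 3, 5])
    for _ in range(10):
        sched.prepare_epoch(dm)
    print(dm.calls, sched.current_phase)  # [1, 2] 2
```

With budgets [2, 3, 5], the switch to phase 1 happens after the 2nd epoch and the switch to phase 2 after the 5th; the final budget only matters if more phases follow.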