agedi.data.callbacks
====================

.. py:module:: agedi.data.callbacks


Classes
-------

.. autoapisummary::

   agedi.data.callbacks.GradNormLogger
   agedi.data.callbacks.EpochProgressPrinter
   agedi.data.callbacks.HParamsMetricLogger
   agedi.data.callbacks.TrainingPhase


Functions
---------

.. autoapisummary::

   agedi.data.callbacks._flatten_hparams


Module Contents
---------------

.. py:class:: GradNormLogger(log_every_n_steps: int = 50)

   Bases: :py:obj:`lightning.pytorch.callbacks.Callback`

   Logs the total gradient norm of the score-model parameters before each
   optimizer step.

   :param log_every_n_steps: Number of optimizer steps between gradient-norm
                             logs (default: ``50``). Set to ``1`` to log every
                             step.

   .. py:attribute:: log_every_n_steps
      :value: 50


   .. py:method:: on_before_optimizer_step(trainer, pl_module, optimizer)


.. py:class:: EpochProgressPrinter(print_epoch_interval: int = 10)

   Bases: :py:obj:`lightning.pytorch.callbacks.Callback`

   Prints epoch-level training progress to stdout at a configurable interval.

   :param print_epoch_interval: Number of epochs between printed summary
                                lines (default: ``10``).

   .. py:attribute:: print_epoch_interval
      :value: 10


   .. py:attribute:: _fit_start_time
      :type:  float
      :value: 0.0


   .. py:method:: on_fit_start(trainer, pl_module)


   .. py:method:: on_validation_epoch_end(trainer, pl_module)


   .. py:method:: on_fit_end(trainer, pl_module)


.. py:function:: _flatten_hparams(d: dict, prefix: str = '', sep: str = '/') -> dict

   Recursively flatten a nested dict into a flat dict whose keys are joined
   by ``sep``.

   Only scalar values (int, float, str, bool) are kept; lists and nested dicts
   are flattened recursively. This is required for TensorBoard's
   ``log_hyperparams``, which only accepts scalar values.

   :param d: The nested hyperparameter dict to flatten.
   :type d: dict
   :param prefix: Key prefix to prepend (used in recursion).
   :type prefix: str
   :param sep: Separator between key segments (default ``"/"``).
   :type sep: str

   :returns: Flat dict with scalar values only.
   :rtype: dict


.. py:class:: HParamsMetricLogger(hparams: Optional[Dict] = None)

   Bases: :py:obj:`lightning.pytorch.callbacks.Callback`

   Manages hyperparameter logging and populates the TensorBoard ``hp_metric``
   panel.

   When a full ``hparams`` dict is provided (including training metadata such
   as ``distribution``, ``prior``, ``sde``, ``conditioning``, ``batch_size``,
   etc.) it is written to ``hparams.yaml`` at training start, complementing
   the baseline written by
   :meth:`~agedi.diffusion.diffusion.Diffusion.on_fit_start`. When no dict is
   provided, the callback falls back to calling ``pl_module.get_hparams()``,
   which returns only the model-architecture config.

   For non-TensorBoard loggers (e.g. WandB) the resolved hparams dict is
   forwarded to ``log_hyperparams`` at training start.

   :param hparams: Full hyperparameter dictionary to log (architecture +
                   training metadata). When ``None``, the callback resolves
                   hparams from ``pl_module.get_hparams()``.

   .. py:attribute:: _hparams
      :value: None


   .. py:method:: _resolve_hparams(pl_module) -> Dict


   .. py:attribute:: _TB_EXCLUDE_KEYS


   .. py:method:: _flatten_for_tb(resolved: dict) -> dict

      Flatten hparams for TensorBoard, excluding keys that are not useful
      there.


   .. py:method:: on_train_start(trainer, pl_module)


   .. py:method:: on_validation_epoch_end(trainer, pl_module)


   .. py:method:: on_fit_end(trainer, pl_module)


.. py:class:: TrainingPhase(n_phases: int, epochs_per_phase: List[int], **kwargs)

   Bases: :py:obj:`lightning.pytorch.callbacks.Callback`

   Lightning callback that advances the dataset through training phases.

   Each phase can use a different set of data augmentation transforms (e.g.
   supercell repeats). The callback monitors the epoch count and calls
   :meth:`~agedi.data.Dataset.set_phase` on the datamodule when it is time to
   move to the next phase.

   .. py:attribute:: n_phases


   .. py:attribute:: epochs_per_phase


   .. py:attribute:: epoch_counter
      :value: 0


   .. py:attribute:: current_phase
      :value: 0


   .. py:method:: _prepare_epoch(trainer: lightning.Trainer, model: lightning.LightningModule) -> None

      Advance to the next training phase if enough epochs have elapsed.

      Called at the end of each validation epoch. When the epoch counter
      reaches the threshold for the current phase, the datamodule is
      instructed to switch to the next phase via
      :meth:`~agedi.data.Dataset.set_phase`.

      :param trainer: The active Lightning trainer.
      :type trainer: Trainer
      :param model: The model being trained (unused, required by Lightning
                    callback API).
      :type model: LightningModule


   .. py:method:: on_validation_end(trainer: lightning.Trainer, model: lightning.LightningModule) -> None

      Hook called by Lightning at the end of each validation epoch.

      Delegates to :meth:`_prepare_epoch` to check whether the current
      training phase should advance.

      :param trainer: The active Lightning trainer.
      :type trainer: Trainer
      :param model: The model being trained.
      :type model: LightningModule
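
The phase-advancing behaviour described for ``TrainingPhase`` can be illustrated with a small standalone sketch. It is not the agedi implementation: ``PhaseScheduler``, ``FakeDataModule`` and ``simulate`` are hypothetical names introduced here, the sketch does not import Lightning, and it assumes that the epoch counter restarts at zero whenever a phase changes and that the final phase runs until training stops. The source does not confirm either assumption.

```python
from typing import List


class FakeDataModule:
    """Hypothetical stand-in for a datamodule that exposes ``set_phase``."""

    def __init__(self) -> None:
        self.phase = 0

    def set_phase(self, phase: int) -> None:
        self.phase = phase


class PhaseScheduler:
    """Illustrative epoch-counting logic; not the agedi implementation."""

    def __init__(self, n_phases: int, epochs_per_phase: List[int]) -> None:
        if len(epochs_per_phase) != n_phases:
            raise ValueError("epochs_per_phase needs one entry per phase")
        self.n_phases = n_phases
        self.epochs_per_phase = epochs_per_phase
        self.epoch_counter = 0
        self.current_phase = 0

    def step(self, datamodule: FakeDataModule) -> None:
        """Call once after each finished validation epoch."""
        self.epoch_counter += 1
        # Assumption: the last phase has no successor, so nothing to advance.
        if self.current_phase >= self.n_phases - 1:
            return
        # Assumption: the counter is compared per phase and reset on a switch.
        if self.epoch_counter >= self.epochs_per_phase[self.current_phase]:
            self.current_phase += 1
            self.epoch_counter = 0
            datamodule.set_phase(self.current_phase)


def simulate(n_epochs: int, n_phases: int, epochs_per_phase: List[int]) -> List[int]:
    """Return the phase that was active during each simulated epoch."""
    datamodule = FakeDataModule()
    scheduler = PhaseScheduler(n_phases, epochs_per_phase)
    history = []
    for _ in range(n_epochs):
        history.append(datamodule.phase)  # phase used while this epoch runs
        scheduler.step(datamodule)        # mimics the end-of-validation hook
    return history


if __name__ == "__main__":
    print(simulate(6, 3, [2, 3, 1]))  # prints [0, 0, 1, 1, 1, 2]
```

Under these assumptions, two epochs run in phase 0 and three in phase 1, after which the schedule stays in phase 2. Keeping the counting logic separate from the hook, as ``_prepare_epoch`` and ``on_validation_end`` do in the documented class, makes the schedule easy to check without starting a trainer.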