agedi.api.training

Training orchestration.

Attributes

_TRAIN_FROM_ATOMS_KEYS

_TRAINER_KEYS

Functions

create_trainer(...) → lightning.Trainer

Create a Lightning trainer configured for AGeDi.

train(...) → lightning.Trainer

Train a diffusion model and return the trainer used.

train_from_atoms(...) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

Build (or restore), train, and return an AGeDi model from ASE Atoms data.

train_from_config(...) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

Train an AGeDi model from a YAML configuration file or dictionary.

Module Contents

agedi.api.training._TRAIN_FROM_ATOMS_KEYS
agedi.api.training._TRAINER_KEYS
agedi.api.training.create_trainer(*, epochs: int = -1, max_time: int | Dict | datetime.timedelta | None = 24, accelerator: str = 'auto', devices: int = 1, logger: str = 'tensorboard', log_dir: str = 'logs', project: str = 'agedi', name: str = 'agedi', log_interval: int = 10, gradient_clip_val: float = 10.0, progress_bar: bool = False, print_epoch_interval: int = 10, log_grad_norm: bool = True, repeat: int | None = None, repeat_epoch: int | None = None, hparams: Dict | None = None, extra_callbacks: List[lightning.pytorch.callbacks.Callback] | None = None) → lightning.Trainer

Create a Lightning trainer configured for AGeDi.

Parameters:
  • epochs – Maximum number of training epochs (-1 = unlimited).

  • max_time

    Wall-clock time limit for training. Accepts:

    • int – number of hours (e.g. 24 ≡ 24 hours).

    • dict – Lightning-style mapping, e.g. {"days": 0, "hours": 12, "minutes": 30, "seconds": 0}.

    • datetime.timedelta – a Python timedelta object.

    • None – no time limit.

  • accelerator – Hardware accelerator to use (e.g. "auto", "gpu", "cpu"). Default: "auto".

  • devices – Number of devices to train on. Default: 1.

  • logger – Logging backend: "tensorboard" (default) or "wandb".

  • log_dir – Root directory for logs and checkpoints. Default: "logs".

  • project – WandB project name (only used when logger="wandb").

  • name – Experiment display name used by TensorBoard and WandB as the run sub-directory / run name. Default: "agedi".

  • log_interval – How often (in steps) to log metrics. Default: 10.

  • gradient_clip_val – Maximum gradient norm for gradient clipping. Default: 10.0.

  • progress_bar – Whether to show a Lightning progress bar. Default: False.

  • print_epoch_interval – Print a one-line training summary to stdout every print_epoch_interval epochs. Set to 0 to disable. Default: 10.

  • log_grad_norm – Whether to log the total gradient norm during training. Disable for large models where the per-step overhead is undesirable. Default: True.

  • repeat – Number of repetition levels for cell-repeat data augmentation. Must be set together with repeat_epoch. When None (default), no repetition augmentation is applied.

  • repeat_epoch – How many epochs between repetition-level increases. Required when repeat is set.

  • hparams – Hyperparameters dict logged to hparams.yaml via HParamsMetricLogger. When None (default), no extra hyperparameter logging is performed.

  • extra_callbacks – Extra Lightning callbacks to append to the default callback list. When None (default) only the built-in callbacks are used.

Returns:

A configured Trainer, ready for trainer.fit(diffusion, dataset).

Return type:

lightning.Trainer
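
The accepted max_time forms listed above can be illustrated with a small normalisation routine. This is a minimal sketch, not AGeDi's implementation: normalize_max_time is a hypothetical helper name, and it assumes the dict form uses only keys that datetime.timedelta accepts (days, hours, minutes, seconds).

```python
from datetime import timedelta
from typing import Dict, Optional, Union

MaxTime = Union[int, Dict[str, int], timedelta, None]


def normalize_max_time(max_time: MaxTime) -> Optional[timedelta]:
    """Hypothetical helper: map each documented max_time form to a timedelta."""
    if max_time is None:
        return None  # no wall-clock limit
    if isinstance(max_time, timedelta):
        return max_time  # already in the target form
    if isinstance(max_time, dict):
        # Lightning-style mapping, e.g. {"days": 0, "hours": 12, ...}
        return timedelta(**max_time)
    if isinstance(max_time, bool) or not isinstance(max_time, int):
        raise TypeError(f"unsupported max_time: {max_time!r}")
    return timedelta(hours=max_time)  # a bare int is read as hours


limit = normalize_max_time({"days": 0, "hours": 12, "minutes": 30, "seconds": 0})
print(limit)
```

Reading a bare integer as hours is what makes the default of 24 equivalent to a one-day limit.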

agedi.api.training.train(diffusion: Agedi, dataset: agedi.data.Dataset, trainer: lightning.Trainer | None = None, ckpt_path: str | pathlib.Path | None = None, **trainer_kwargs) → lightning.Trainer

Train a diffusion model and return the trainer used.

Parameters:
  • diffusion – The diffusion model to train.

  • dataset – The dataset to train on.

  • trainer – A pre-configured Lightning Trainer. When None, a new trainer is created from trainer_kwargs.

  • ckpt_path – Path to a Lightning checkpoint (.ckpt) to resume training from. When provided, the full training state (model weights, optimiser, LR-scheduler, and epoch counter) is restored before fitting. Equivalent to passing ckpt_path to trainer.fit().

  • **trainer_kwargs – Additional keyword arguments forwarded to create_trainer() when trainer is None.
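
The two ways of supplying a trainer can be combined with checkpoint resumption roughly as follows. This is a sketch under assumptions: fit_or_resume is a hypothetical wrapper, the keyword values are illustrative, and the diffusion model and dataset are assumed to have been built elsewhere. Only the arguments documented on this page are used.

```python
from pathlib import Path
from typing import Any, Optional


def fit_or_resume(diffusion: Any, dataset: Any, ckpt: Optional[Path] = None) -> Any:
    """Hypothetical wrapper: build a trainer explicitly, then fit or resume."""
    # Imported lazily so this sketch can be read and loaded without agedi installed.
    from agedi.api.training import create_trainer, train

    # Explicit trainer; the values below are illustrative, not recommendations.
    trainer = create_trainer(
        epochs=50,
        max_time={"days": 0, "hours": 12, "minutes": 0, "seconds": 0},
        logger="tensorboard",
        log_dir="logs",
    )
    # With ckpt=None training starts fresh; otherwise the full state is restored.
    return train(diffusion, dataset, trainer=trainer, ckpt_path=ckpt)
```

When no custom trainer is needed, the same keyword arguments can instead be passed directly to train() as **trainer_kwargs.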

agedi.api.training.train_from_atoms(data: Sequence[ase.Atoms], *, model: str = 'PaiNN', cutoff: float = 6.0, feature_size: int = 64, n_blocks: int = 4, n_rbf: int = 30, noisers: Sequence[str] = ('CellPositions',), sde: str | SDE = 've', conditioning: str = 'none', conditioning_type: str = 'scalar', mask: str = 'none', confinement: Tuple[float, float] | None = None, force_field: bool = False, batch_size: int = 64, train_split: float | int = 0.9, val_split: float | int = 0.1, repeat: int | None = None, canonical_cell: bool = False, lr: float = 0.0001, lr_factor: float = 0.95, lr_patience: int = 100, weight_decay: float = 0.0, eps: float = 1e-05, guidance_weight: float = -1.0, data_path: str | None = None, regressor_data: Sequence[ase.Atoms] | None = None, checkpoint: str | pathlib.Path | None = None, trainer: lightning.Trainer | None = None, n_classes: int | None = None, **trainer_kwargs) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

Build (or restore), train, and return an AGeDi model from ASE Atoms data.

When a "Types" noiser is included and no checkpoint is given, the unique element types present in data are automatically detected and a compact type map is built so that the vocabulary size equals the number of distinct element types (plus the absorbing state at index 0). The n_classes parameter can be used to restrict the vocabulary to the n_classes most frequently occurring element types (sorted by atomic number).
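
One possible reading of the compact type map described above is sketched below. It is not the library's code: build_type_map is a hypothetical helper, structures are represented as plain lists of atomic numbers (in practice these could come from ASE's Atoms.get_atomic_numbers()), and the assumption is that n_classes keeps the most frequent types and then orders the survivors by atomic number.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Sequence

ABSORBING_INDEX = 0  # index 0 is reserved for the absorbing state


def build_type_map(
    structures: Iterable[Sequence[int]], n_classes: Optional[int] = None
) -> Dict[int, int]:
    """Hypothetical helper: map atomic numbers to compact indices 1..K."""
    counts = Counter(z for numbers in structures for z in numbers)
    if n_classes is None:
        kept = list(counts)  # keep every element type present in the data
    else:
        if n_classes > len(counts):
            raise ValueError("n_classes exceeds the number of distinct types")
        # Assumption: keep the most frequent types, ties go to lower atomic number.
        ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
        kept = [z for z, _ in ranked[:n_classes]]
    # Compact indices start at 1 because index 0 is the absorbing state.
    return {z: i for i, z in enumerate(sorted(kept), start=ABSORBING_INDEX + 1)}


water_and_methane = [[8, 1, 1], [6, 1, 1, 1, 1]]
type_map = build_type_map(water_and_methane)
vocabulary_size = len(type_map) + 1  # distinct types plus the absorbing state
```

For the two molecules above, three element types are present, so the vocabulary has four entries including the absorbing state.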

Parameters:
  • data – ASE Atoms objects to train on.

  • model – GNN backbone architecture name. Looked up in the model registry; use register_model() to add custom backends. Default: "PaiNN" (SchNetPack PaiNN).

  • cutoff – Neighbour-list cutoff radius in Å. Default: 6.0.

  • feature_size – Embedding / feature dimension. Default: 64.

  • n_blocks – Number of interaction blocks in the GNN backbone. Default: 4.

  • n_rbf – Number of radial basis functions. Default: 30.

  • noisers – Sequence of noiser identifiers. Recognised string identifiers: "Positions", "CellPositions", "ConfinedCellPositions", "Types" (snake_case aliases also accepted). Default: ("CellPositions",).

  • sde – SDE for position noisers. Short aliases: "ve" (default), "vp". Pass an instantiated SDE for full control.

  • conditioning – Per-structure property to condition on (read from atoms.info[conditioning] or atoms.get_<conditioning>()), or "none" for time-only conditioning (default).

  • conditioning_type – Type of the conditioning module: "scalar" (default) or "integer".

  • mask – Atom-masking strategy: "MaskFixed" (freeze atoms held by an ASE FixAtoms constraint) or "none" (default).

  • confinement – Z-direction confinement bounds (z_min, z_max) in Å. Required when using the "ConfinedCellPositions" noiser.

  • force_field – When True, attach a regressor head (sharing the backbone) that predicts per-atom forces and total energy. Enables force-field-guided sampling via ForcefieldGuidanceConfig. The training data must contain reference forces and energies (from DFT or another method). Default: False.

  • batch_size – Mini-batch size used during training. Default: 64.

  • train_split – Fraction or absolute count of structures for the training split. Default: 0.9.

  • val_split – Fraction or absolute count of structures for the validation split. Default: 0.1.

  • repeat – When given, augment the dataset by repeating each structure up to repeat times along the first two cell vectors. Requires repeat_epoch (passed via **trainer_kwargs) to specify how often the repetition level increases.

  • canonical_cell – Store unit cells in canonical lower-triangular form. Default: False.

  • lr – Learning rate. Default: 1e-4.

  • lr_factor – LR-scheduler reduction factor. Default: 0.95.

  • lr_patience – LR-scheduler patience (epochs). Default: 100.

  • weight_decay – Optimiser weight decay. Default: 0.0.

  • eps – Minimum diffusion time value. Default: 1e-5.

  • guidance_weight – Classifier-free guidance weight. Default: -1.0 (disabled).

  • data_path – String path to the training data file; stored in hparams.yaml for reference only. When None, no path metadata is saved.

  • regressor_data – Optional additional ASE Atoms objects used exclusively for training the force-field regressor head. Structures here are never passed through the diffusion loss. Each structure must have an ASE calculator with energy and forces attached.

  • checkpoint – Path to a previously saved run directory (containing hparams.yaml) or directly to a .ckpt checkpoint file. When provided, the model architecture and weights are loaded from the checkpoint instead of being built from the architecture parameters (model, cutoff, feature_size, etc.). The full training state (optimiser, LR-scheduler, epoch counter) is also restored so that training continues seamlessly. Pass new structures as data to continue training on new data, or reload the original data to resume on the same dataset.

  • trainer – A pre-configured Lightning Trainer. When None (default), a new trainer is built from **trainer_kwargs.

  • n_classes – Number of element-type classes to use for the Types noiser (not counting the absorbing state at index 0). When None (default), all distinct element types present in data are used. Must not exceed the number of distinct types in the training data. Ignored when checkpoint is provided (the vocabulary is loaded from the checkpoint).

  • **trainer_kwargs – Additional keyword arguments forwarded to create_trainer() when trainer is None. Common keys: epochs, max_time, logger, log_dir, gradient_clip_val, repeat_epoch.

Returns:

The trained diffusion model, the dataset, and the Lightning trainer.

Return type:

Tuple[Agedi, Dataset, Trainer]
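
Because train_split and val_split each accept either a fraction or an absolute count, it can help to see how the two forms might be resolved against a dataset size. The following is a minimal sketch with hypothetical helper names (resolve_split, resolve_splits); the rule that floats are fractions and ints are counts, and the use of floor rounding, are assumptions rather than documented behaviour.

```python
import math
from typing import Tuple, Union

SplitSize = Union[float, int]


def resolve_split(size: SplitSize, n_total: int) -> int:
    """Hypothetical helper: turn a fraction or an absolute count into a count."""
    if isinstance(size, bool):
        raise TypeError("split size must be a float or an int")
    if isinstance(size, int):
        count = size  # ints are read as absolute structure counts
    else:
        if not 0.0 <= size <= 1.0:
            raise ValueError("fractional split must lie in [0, 1]")
        count = math.floor(size * n_total)  # floats are read as fractions
    if not 0 <= count <= n_total:
        raise ValueError("split does not fit in the dataset")
    return count


def resolve_splits(train: SplitSize, val: SplitSize, n_total: int) -> Tuple[int, int]:
    """Resolve both splits and check that they do not overlap."""
    n_train = resolve_split(train, n_total)
    n_val = resolve_split(val, n_total)
    if n_train + n_val > n_total:
        raise ValueError("train and validation splits exceed the dataset size")
    return n_train, n_val


n_train, n_val = resolve_splits(0.9, 0.1, 1000)
```

Mixing the two forms is also possible under this reading, for example an absolute training count together with a fractional validation split.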

agedi.api.training.train_from_config(config: str | pathlib.Path | Dict) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

Train an AGeDi model from a YAML configuration file or dictionary.

This is the Hydra-style entry point. The configuration can be provided as:

  • a path to a YAML file (str or Path),

  • a plain Python dict,

  • a Hydra / OmegaConf DictConfig.

The function loads the training data from config["data_path"] (an ASE-readable file) and delegates to train_from_atoms() with the remaining configuration values.

The minimal required key is data_path. All other keys are optional and fall back to the same defaults as train_from_atoms().
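
The delegation described above can be pictured as separating data_path from the remaining keys. This sketch is an assumption about the overall shape, not the actual implementation: split_config is a hypothetical helper, and how the real function partitions keys (for instance via _TRAIN_FROM_ATOMS_KEYS and _TRAINER_KEYS) is not documented here.

```python
from typing import Any, Dict, Mapping, Tuple


def split_config(config: Mapping[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: separate the required data_path from the other keys."""
    if "data_path" not in config:
        raise KeyError("config must define 'data_path'")
    remaining = dict(config)  # copy so the caller's mapping is left untouched
    data_path = remaining.pop("data_path")
    return data_path, remaining


cfg = {"data_path": "train.traj", "epochs": 50, "feature_size": 128}
data_path, overrides = split_config(cfg)

# Assumed continuation (requires ase and agedi, so it is left as comments):
#   structures = ase.io.read(data_path, index=":")
#   train_from_atoms(structures, data_path=data_path, **overrides)
```

Keys absent from the mapping are simply not passed on, which is how the defaults of train_from_atoms() would take effect.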

A ready-to-edit template is shipped with the package at agedi/conf/train.yaml.

Parameters:

config – Configuration source: a YAML file path, a dict, or an OmegaConf DictConfig.

Returns:

The trained diffusion model, the dataset used, and the Lightning trainer.

Return type:

Tuple[Agedi, Dataset, Trainer]

Examples

Minimal Python usage:

from agedi import train_from_config
diffusion, dataset, trainer = train_from_config("conf/train.yaml")

Programmatic override:

from agedi import train_from_config
cfg = {"data_path": "train.traj", "epochs": 50, "feature_size": 128}
diffusion, _, _ = train_from_config(cfg)