agedi.api.training
Training orchestration.
Attributes
Functions
| Function | Description |
| create_trainer() | Create a Lightning trainer configured for AGeDi. |
| train() | Train a diffusion model and return the trainer used. |
| train_from_atoms() | Build (or restore), train, and return an AGeDi model from ASE Atoms data. |
| train_from_config() | Train an AGeDi model from a YAML configuration file or dictionary. |
Module Contents
- agedi.api.training._TRAIN_FROM_ATOMS_KEYS
- agedi.api.training._TRAINER_KEYS
- agedi.api.training.create_trainer(*, epochs: int = -1, max_time: int | Dict | datetime.timedelta | None = 24, accelerator: str = 'auto', devices: int = 1, logger: str = 'tensorboard', log_dir: str = 'logs', project: str = 'agedi', name: str = 'agedi', log_interval: int = 10, gradient_clip_val: float = 10.0, progress_bar: bool = False, print_epoch_interval: int = 10, log_grad_norm: bool = True, repeat: int | None = None, repeat_epoch: int | None = None, hparams: Dict | None = None, extra_callbacks: List[lightning.pytorch.callbacks.Callback] | None = None) → lightning.Trainer
Create a Lightning trainer configured for AGeDi.
- Parameters:
  - epochs – Maximum number of training epochs (-1 = unlimited).
  - max_time – Wall-clock time limit for training. Accepts:
    - int – number of hours (e.g. 24 ≡ 24 hours).
    - dict – Lightning-style mapping, e.g. {"days": 0, "hours": 12, "minutes": 30, "seconds": 0}.
    - datetime.timedelta – a Python timedelta object.
    - None – no time limit.
  - accelerator – Hardware accelerator to use (e.g. "auto", "gpu", "cpu"). Default: "auto".
  - devices – Number of devices to train on. Default: 1.
  - logger – Logging backend: "tensorboard" (default) or "wandb".
  - log_dir – Root directory for logs and checkpoints. Default: "logs".
  - project – WandB project name (only used when logger="wandb").
  - name – Experiment display name, used by TensorBoard and WandB as the run sub-directory / run name. Default: "agedi".
  - log_interval – How often (in steps) to log metrics. Default: 10.
  - gradient_clip_val – Maximum gradient norm for gradient clipping. Default: 10.0.
  - progress_bar – Whether to show a Lightning progress bar. Default: False.
  - print_epoch_interval – Print a one-line training summary to stdout every print_epoch_interval epochs. Set to 0 to disable. Default: 10.
  - log_grad_norm – Whether to log the total gradient norm during training. Disable for large models where the per-step overhead is undesirable. Default: True.
  - repeat – Number of repetition levels for cell-repeat data augmentation. Must be set together with repeat_epoch. When None (default), no repetition augmentation is applied.
  - repeat_epoch – Number of epochs between repetition-level increases. Required when repeat is set.
  - hparams – Hyperparameters dict logged to hparams.yaml via HParamsMetricLogger. When None (default), no extra hyperparameter logging is performed.
  - extra_callbacks – Extra Lightning callbacks to append to the default callback list. When None (default), only the built-in callbacks are used.
- Returns:
  A configured Trainer, ready to call trainer.fit(diffusion, dataset).
- Return type:
  lightning.Trainer
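All four max_time forms listed above describe a single duration. The helper below is a hypothetical sketch (normalize_max_time is not part of AGeDi) showing one way to map them onto a datetime.timedelta; treating a bare int as hours follows the parameter description, but how create_trainer() performs the conversion internally is an assumption.

```python
import datetime
from typing import Dict, Optional, Union

# Hypothetical helper, not part of AGeDi: mirrors the documented max_time formats.
MaxTime = Union[int, Dict[str, int], datetime.timedelta, None]


def normalize_max_time(max_time: MaxTime) -> Optional[datetime.timedelta]:
    """Convert any documented max_time value to a timedelta (None = no limit)."""
    if max_time is None:
        return None
    if isinstance(max_time, datetime.timedelta):
        return max_time
    if isinstance(max_time, dict):
        # Lightning-style mapping, e.g. {"days": 0, "hours": 12, "minutes": 30}
        return datetime.timedelta(**max_time)
    if isinstance(max_time, int):
        # A bare integer is interpreted as a number of hours.
        return datetime.timedelta(hours=max_time)
    raise TypeError(f"unsupported max_time value: {max_time!r}")


print(normalize_max_time(24))  # 1 day, 0:00:00
print(normalize_max_time({"days": 0, "hours": 12, "minutes": 30, "seconds": 0}))  # 12:30:00
```

Lightning's own Trainer already accepts a timedelta or a dict for max_time, so in this sketch only the integer-hours form needs real translation.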
- agedi.api.training.train(diffusion: Agedi, dataset: agedi.data.Dataset, trainer: lightning.Trainer | None = None, ckpt_path: str | pathlib.Path | None = None, **trainer_kwargs) → lightning.Trainer
Train a diffusion model and return the trainer used.
- Parameters:
  - diffusion – The diffusion model to train.
  - dataset – The dataset to train on.
  - trainer – A pre-configured Lightning Trainer. When None, a new trainer is created from trainer_kwargs.
  - ckpt_path – Path to a Lightning checkpoint (.ckpt) to resume training from. When provided, the full training state (model weights, optimiser, LR scheduler, and epoch counter) is restored before fitting. Equivalent to passing ckpt_path to trainer.fit().
  - **trainer_kwargs – Additional keyword arguments forwarded to create_trainer() when trainer is None.
- Returns:
  The Lightning trainer used for fitting.
- Return type:
  lightning.Trainer
- agedi.api.training.train_from_atoms(data: Sequence[ase.Atoms], *, model: str = 'PaiNN', cutoff: float = 6.0, feature_size: int = 64, n_blocks: int = 4, n_rbf: int = 30, noisers: Sequence[str] = ('CellPositions',), sde: str | SDE = 've', conditioning: str = 'none', conditioning_type: str = 'scalar', mask: str = 'none', confinement: Tuple[float, float] | None = None, force_field: bool = False, batch_size: int = 64, train_split: float | int = 0.9, val_split: float | int = 0.1, repeat: int | None = None, canonical_cell: bool = False, lr: float = 0.0001, lr_factor: float = 0.95, lr_patience: int = 100, weight_decay: float = 0.0, eps: float = 1e-05, guidance_weight: float = -1.0, data_path: str | None = None, regressor_data: Sequence[ase.Atoms] | None = None, checkpoint: str | pathlib.Path | None = None, trainer: lightning.Trainer | None = None, n_classes: int | None = None, **trainer_kwargs) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]
Build (or restore), train, and return an AGeDi model from ASE Atoms data.
When a "Types" noiser is included and no checkpoint is given, the unique element types present in data are detected automatically and a compact type map is built, so that the vocabulary size equals the number of distinct element types plus the absorbing state at index 0. The n_classes parameter can be used to restrict the vocabulary to the n_classes most frequently occurring element types (sorted by atomic number).
- Parameters:
  - data – ASE Atoms objects to train on.
  - model – GNN backbone architecture name. Looked up in the model registry; use register_model() to add custom backends. Default: "PaiNN" (SchNetPack PaiNN).
  - cutoff – Neighbour-list cutoff radius in Å. Default: 6.0.
  - feature_size – Embedding / feature dimension. Default: 64.
  - n_blocks – Number of interaction blocks in the GNN backbone. Default: 4.
  - n_rbf – Number of radial basis functions. Default: 30.
  - noisers – Sequence of noiser identifiers. Recognised string identifiers: "Positions", "CellPositions", "ConfinedCellPositions", "Types" (snake_case aliases are also accepted). Default: ("CellPositions",).
  - sde – SDE for position noisers. Short aliases: "ve" (default), "vp". Pass an instantiated SDE for full control.
  - conditioning – Per-structure property to condition on (read from atoms.info[conditioning] or atoms.get_<conditioning>()), or "none" for time-only conditioning (default).
  - conditioning_type – Type of the conditioning module: "scalar" (default) or "integer".
  - mask – Atom-masking strategy: "MaskFixed" (freeze atoms constrained by ASE FixAtoms) or "none" (default).
  - confinement – Z-direction confinement bounds (z_min, z_max) in Å. Required when using the "ConfinedCellPositions" noiser.
  - force_field – When True, attach a regressor head (sharing the backbone) that predicts per-atom forces and total energy. This enables force-field-guided sampling via ForcefieldGuidanceConfig. The training data must contain reference forces and energies (e.g. from DFT). Default: False.
  - batch_size – Mini-batch size used during training. Default: 64.
  - train_split – Fraction or absolute count of structures for the training split. Default: 0.9.
  - val_split – Fraction or absolute count of structures for the validation split. Default: 0.1.
  - repeat – When given, augment the dataset by repeating each structure up to repeat times along the first two cell vectors. Requires repeat_epoch (passed via **trainer_kwargs) to specify how often the repetition level increases.
  - canonical_cell – Store unit cells in canonical lower-triangular form. Default: False.
  - lr – Learning rate. Default: 1e-4.
  - lr_factor – LR-scheduler reduction factor. Default: 0.95.
  - lr_patience – LR-scheduler patience (epochs). Default: 100.
  - weight_decay – Optimiser weight decay. Default: 0.0.
  - eps – Minimum diffusion time value. Default: 1e-5.
  - guidance_weight – Classifier-free guidance weight. Default: -1.0 (disabled).
  - data_path – Path to the training data file; stored in hparams.yaml for reference only. When None, no path metadata is saved.
  - regressor_data – Optional additional ASE Atoms objects used exclusively to train the force-field regressor head; these structures never contribute to the diffusion loss. Each structure must have an ASE calculator with energy and forces attached.
  - checkpoint – Path to a previously saved run directory (containing hparams.yaml) or directly to a .ckpt checkpoint file. When provided, the model architecture and weights are loaded from the checkpoint instead of being built from the architecture parameters (model, cutoff, feature_size, etc.). The full training state (optimiser, LR scheduler, epoch counter) is also restored so that training continues seamlessly. Pass new structures as data to continue training on them, or reload the original data to resume on the same dataset.
  - trainer – A pre-configured Lightning Trainer. When None (default), a new trainer is built from **trainer_kwargs.
  - n_classes – Number of element-type classes to use for the Types noiser (not counting the absorbing state at index 0). When None (default), all distinct element types present in data are used. Must not exceed the number of distinct types in the training data. Ignored when checkpoint is provided (the vocabulary is loaded from the checkpoint).
  - **trainer_kwargs – Additional keyword arguments forwarded to create_trainer() when trainer is None. Common keys: epochs, max_time, logger, log_dir, gradient_clip_val, repeat_epoch.
- Returns:
  The trained diffusion model, the dataset, and the Lightning trainer.
- Return type:
  Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]
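The automatic type-map construction described for the "Types" noiser can be sketched as follows. build_type_map is a hypothetical helper, not AGeDi's implementation: it takes one list of atomic numbers per structure (with ASE these could come from atoms.get_atomic_numbers()), reserves index 0 for the absorbing state, and, when n_classes is given, keeps the most frequent element types before ordering them by atomic number. Tie-breaking between equally frequent elements is an arbitrary choice of this sketch.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Sequence


def build_type_map(
    numbers_per_structure: Iterable[Sequence[int]],
    n_classes: Optional[int] = None,
) -> Dict[int, int]:
    """Map atomic number -> class index, leaving index 0 for the absorbing state."""
    # Count how often each element occurs across all atoms in all structures.
    counts = Counter(z for numbers in numbers_per_structure for z in numbers)
    if n_classes is None:
        kept = list(counts)  # every distinct element type
    elif n_classes > len(counts):
        raise ValueError(
            f"n_classes={n_classes} exceeds the {len(counts)} distinct element types"
        )
    else:
        # Keep the n_classes most frequent elements (ties keep first-seen order).
        kept = [z for z, _ in counts.most_common(n_classes)]
    # Order the retained elements by atomic number; class indices start at 1.
    return {z: index for index, z in enumerate(sorted(kept), start=1)}


structures = [[8, 1, 1], [8, 1, 1], [78, 78, 8]]  # atomic numbers: H=1, O=8, Pt=78
full = build_type_map(structures)                  # {1: 1, 8: 2, 78: 3}
top_two = build_type_map(structures, n_classes=2)  # {1: 1, 8: 2}
vocab_size = len(full) + 1                         # +1 for the absorbing state -> 4
```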
- agedi.api.training.train_from_config(config: str | pathlib.Path | Dict) → Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]
Train an AGeDi model from a YAML configuration file or dictionary.
This is the Hydra-style entry point. The configuration can be provided as:
  - a path to a YAML file (str or Path),
  - a plain Python dict,
  - a Hydra / OmegaConf DictConfig.
The function loads the training data from config["data_path"] (an ASE-readable file) and delegates to train_from_atoms() with the remaining configuration values.
The minimal required key is data_path. All other keys are optional and fall back to the same defaults as train_from_atoms().
A ready-to-edit template is shipped with the package at agedi/conf/train.yaml.
- Parameters:
  - config – Configuration source: a YAML file path, a dict, or an OmegaConf DictConfig.
- Returns:
  The trained diffusion model, the dataset used, and the Lightning trainer.
- Return type:
  Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]
Examples
Minimal Python usage:

    from agedi import train_from_config

    diffusion, dataset, trainer = train_from_config("conf/train.yaml")

Programmatic override:

    from agedi import train_from_config

    cfg = {"data_path": "train.traj", "epochs": 50, "feature_size": 128}
    diffusion, _, _ = train_from_config(cfg)
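The configuration handling described above (a file path or mapping as input, data_path required, other keys falling back to defaults) can be sketched as below. resolve_config and DEFAULTS are hypothetical names, not AGeDi API; DEFAULTS copies only a few defaults from the train_from_atoms() signature, the YAML branch assumes PyYAML is installed, and the sketch stops before the ASE data loading and the call to train_from_atoms().

```python
from pathlib import Path
from typing import Any, Dict, Mapping, Union

# A few train_from_atoms() defaults, copied from its signature (not exhaustive).
DEFAULTS: Dict[str, Any] = {
    "model": "PaiNN",
    "cutoff": 6.0,
    "feature_size": 64,
    "batch_size": 64,
    "lr": 1e-4,
}


def resolve_config(config: Union[str, Path, Mapping[str, Any]]) -> Dict[str, Any]:
    """Hypothetical sketch: load a config source and fill in missing keys."""
    if isinstance(config, (str, Path)):
        import yaml  # PyYAML; only needed when a file path is given

        with open(config, encoding="utf-8") as handle:
            config = yaml.safe_load(handle) or {}
    settings = dict(config)  # shallow copy of any mapping
    if "data_path" not in settings:
        raise KeyError("the configuration must define 'data_path'")
    # User-supplied values override the defaults.
    return {**DEFAULTS, **settings}


cfg = resolve_config({"data_path": "train.traj", "epochs": 50, "feature_size": 128})
# cfg["feature_size"] == 128 (overridden), cfg["cutoff"] == 6.0 (default)
```

Merging into a defaults dict is only one way to realise "fall back to the same defaults"; an implementation could equally just omit unspecified keys when calling train_from_atoms().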