agedi.api.training
==================

.. py:module:: agedi.api.training

.. autoapi-nested-parse::

   Training orchestration.


Attributes
----------

.. autoapisummary::

   agedi.api.training._TRAIN_FROM_ATOMS_KEYS
   agedi.api.training._TRAINER_KEYS


Functions
---------

.. autoapisummary::

   agedi.api.training.create_trainer
   agedi.api.training.train
   agedi.api.training.train_from_atoms
   agedi.api.training.train_from_config


Module Contents
---------------

.. py:data:: _TRAIN_FROM_ATOMS_KEYS

.. py:data:: _TRAINER_KEYS

.. py:function:: create_trainer(*, epochs: int = -1, max_time: Optional[Union[int, Dict, datetime.timedelta]] = 24, accelerator: str = 'auto', devices: int = 1, logger: str = 'tensorboard', log_dir: str = 'logs', project: str = 'agedi', name: str = 'agedi', log_interval: int = 10, gradient_clip_val: float = 10.0, progress_bar: bool = False, print_epoch_interval: int = 10, log_grad_norm: bool = True, repeat: Optional[int] = None, repeat_epoch: Optional[int] = None, hparams: Optional[Dict] = None, extra_callbacks: Optional[List[lightning.pytorch.callbacks.Callback]] = None) -> lightning.Trainer

   Create a Lightning trainer configured for AGeDi.

   :param epochs: Maximum number of training epochs (``-1`` = unlimited).
   :param max_time: Wall-clock time limit for training. Accepts:

      * ``int`` – number of *hours* (e.g. ``24`` ≡ 24 hours).
      * ``dict`` – Lightning-style mapping, e.g.
        ``{"days": 0, "hours": 12, "minutes": 30, "seconds": 0}``.
      * :class:`datetime.timedelta` – a Python timedelta object.
      * ``None`` – no time limit.
   :param accelerator: Hardware accelerator to use (e.g. ``"auto"``,
      ``"gpu"``, ``"cpu"``). Default: ``"auto"``.
   :param devices: Number of devices to train on. Default: ``1``.
   :param logger: Logging backend: ``"tensorboard"`` (default) or ``"wandb"``.
   :param log_dir: Root directory for logs and checkpoints. Default: ``"logs"``.
   :param project: WandB project name (only used when ``logger="wandb"``).
   :param name: Experiment display name used by TensorBoard and WandB as the
      run sub-directory / run name. Default: ``"agedi"``.
   :param log_interval: How often (in steps) to log metrics. Default: ``10``.
   :param gradient_clip_val: Maximum gradient norm for gradient clipping.
      Default: ``10.0``.
   :param progress_bar: Whether to show a Lightning progress bar.
      Default: ``False``.
   :param print_epoch_interval: Interval, in epochs, at which a one-line
      training summary is printed to stdout. Set to ``0`` to disable.
      Default: ``10``.
   :param log_grad_norm: Whether to log the total gradient norm during
      training. Disable for large models where the per-step overhead is
      undesirable. Default: ``True``.
   :param repeat: Number of repetition levels for cell-repeat data
      augmentation. Must be set together with *repeat_epoch*. When ``None``
      (default), no repetition augmentation is applied.
   :param repeat_epoch: How many epochs between repetition-level increases.
      Required when *repeat* is set.
   :param hparams: Hyperparameters dict logged to ``hparams.yaml`` via
      :class:`~agedi.data.callbacks.HParamsMetricLogger`. When ``None``
      (default), no extra hyperparameter logging is performed.
   :param extra_callbacks: Extra Lightning callbacks to append to the default
      callback list. When ``None`` (default), only the built-in callbacks are
      used.

   :returns: A configured :class:`~lightning.Trainer` ready to call
      ``trainer.fit(diffusion, dataset)``.
   :rtype: lightning.Trainer


.. py:function:: train(diffusion: Agedi, dataset: agedi.data.Dataset, trainer: Optional[lightning.Trainer] = None, ckpt_path: Optional[Union[str, pathlib.Path]] = None, **trainer_kwargs) -> lightning.Trainer

   Train a diffusion model and return the trainer used.

   :param diffusion: The diffusion model to train.
   :param dataset: The dataset to train on.
   :param trainer: A pre-configured Lightning :class:`~lightning.Trainer`.
      When ``None``, a new trainer is created from *trainer_kwargs*.
   :param ckpt_path: Path to a Lightning checkpoint (``.ckpt``) to resume
      training from. When provided the full training state (model weights,
      optimiser, LR-scheduler, and epoch counter) is restored before fitting.
      Equivalent to passing ``ckpt_path`` to ``trainer.fit()``.
   :param \*\*trainer_kwargs: Additional keyword arguments forwarded to
      :func:`create_trainer` when *trainer* is ``None``.


.. py:function:: train_from_atoms(data: Sequence[ase.Atoms], *, model: str = 'PaiNN', cutoff: float = 6.0, feature_size: int = 64, n_blocks: int = 4, n_rbf: int = 30, noisers: Sequence[str] = ('CellPositions', ), sde: Union[str, SDE] = 've', conditioning: str = 'none', conditioning_type: str = 'scalar', mask: str = 'none', confinement: Optional[Tuple[float, float]] = None, force_field: bool = False, batch_size: int = 64, train_split: Union[float, int] = 0.9, val_split: Union[float, int] = 0.1, repeat: Optional[int] = None, canonical_cell: bool = False, lr: float = 0.0001, lr_factor: float = 0.95, lr_patience: int = 100, weight_decay: float = 0.0, eps: float = 1e-05, guidance_weight: float = -1.0, data_path: Optional[str] = None, regressor_data: Optional[Sequence[ase.Atoms]] = None, checkpoint: Optional[Union[str, pathlib.Path]] = None, trainer: Optional[lightning.Trainer] = None, n_classes: Optional[int] = None, **trainer_kwargs) -> Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

   Build (or restore), train, and return an AGeDi model from ASE Atoms data.

   When a ``"Types"`` noiser is included and no *checkpoint* is given, the
   unique element types present in *data* are automatically detected and a
   compact type map is built so that the vocabulary size equals the number of
   distinct element types (plus the absorbing state at index 0). The
   ``n_classes`` parameter can be used to restrict the vocabulary to the
   *n_classes* most frequently occurring element types (sorted by atomic
   number).

   :param data: ASE :class:`~ase.Atoms` objects to train on.
   :param model: GNN backbone architecture name. Looked up in the model
      registry; use :func:`register_model` to add custom backends.
      Default: ``"PaiNN"`` (SchNetPack PaiNN).
   :param cutoff: Neighbour-list cutoff radius in Å. Default: ``6.0``.
   :param feature_size: Embedding / feature dimension. Default: ``64``.
   :param n_blocks: Number of interaction blocks in the GNN backbone.
      Default: ``4``.
   :param n_rbf: Number of radial basis functions. Default: ``30``.
   :param noisers: Sequence of noiser identifiers. Recognised string
      identifiers: ``"Positions"``, ``"CellPositions"``,
      ``"ConfinedCellPositions"``, ``"Types"`` (snake_case aliases also
      accepted). Default: ``("CellPositions",)``.
   :param sde: SDE for position noisers. Short aliases: ``"ve"`` (default),
      ``"vp"``. Pass an instantiated :class:`~agedi.diffusion.sdes.SDE` for
      full control.
   :param conditioning: Per-structure property to condition on (read from
      ``atoms.info[conditioning]`` or ``atoms.get_()``), or ``"none"`` for
      time-only conditioning (default).
   :param conditioning_type: Type of the conditioning module: ``"scalar"``
      (default) or ``"integer"``.
   :param mask: Atom-masking strategy: ``"MaskFixed"`` (freeze atoms tagged
      with ASE :class:`~ase.constraints.FixAtoms`) or ``"none"`` (default).
   :param confinement: Z-direction confinement bounds ``(z_min, z_max)`` in Å.
      Required when using the ``"ConfinedCellPositions"`` noiser.
   :param force_field: When ``True``, attach a regressor head (sharing the
      backbone) that predicts per-atom forces and total energy. Enables
      force-field guided sampling via
      :class:`~agedi.diffusion.ForcefieldGuidanceConfig`. The training data
      must contain DFT (or other) forces and energy. Default: ``False``.
   :param batch_size: Mini-batch size used during training. Default: ``64``.
   :param train_split: Fraction or absolute count of structures for the
      training split. Default: ``0.9``.
   :param val_split: Fraction or absolute count of structures for the
      validation split. Default: ``0.1``.
   :param repeat: When given, augment the dataset by repeating each structure
      up to ``repeat`` times along the first two cell vectors. Requires
      ``repeat_epoch`` (passed via ``**trainer_kwargs``) to specify how often
      the repetition level increases.
   :param canonical_cell: Store unit cells in canonical lower-triangular form.
      Default: ``False``.
   :param lr: Learning rate. Default: ``1e-4``.
   :param lr_factor: LR-scheduler reduction factor. Default: ``0.95``.
   :param lr_patience: LR-scheduler patience (epochs). Default: ``100``.
   :param weight_decay: Optimiser weight decay. Default: ``0.0``.
   :param eps: Minimum diffusion time value. Default: ``1e-5``.
   :param guidance_weight: Classifier-free guidance weight.
      Default: ``-1.0`` (disabled).
   :param data_path: String path to the training data file; stored in
      ``hparams.yaml`` for reference only. When ``None``, no path metadata is
      saved.
   :param regressor_data: Optional additional ASE Atoms objects used
      *exclusively* for training the force-field regressor head. Structures
      here are never passed through the diffusion loss. Each structure must
      have an ASE calculator with energy and forces attached.
   :param checkpoint: Path to a previously saved run directory (containing
      ``hparams.yaml``) or directly to a ``.ckpt`` checkpoint file. When
      provided the model architecture and weights are loaded from the
      checkpoint instead of being built from the architecture parameters
      (*model*, *cutoff*, *feature_size*, etc.). The full training state
      (optimiser, LR-scheduler, epoch counter) is also restored so that
      training continues seamlessly. Supply *data* to train on new data, or
      use the original data path to resume on the same dataset.
   :param trainer: A pre-configured Lightning :class:`~lightning.Trainer`.
      When ``None`` (default) a new trainer is built from
      ``**trainer_kwargs``.
   :param n_classes: Number of element-type classes to use for the
      :class:`~agedi.diffusion.noisers.Types` noiser (not counting the
      absorbing state at index 0).
      When ``None`` (default), all distinct element types present in *data*
      are used. Must not exceed the number of distinct types in the training
      data. Ignored when *checkpoint* is provided (the vocabulary is loaded
      from the checkpoint).
   :param \*\*trainer_kwargs: Additional keyword arguments forwarded to
      :func:`create_trainer` when *trainer* is ``None``. Common keys:
      ``epochs``, ``max_time``, ``logger``, ``log_dir``,
      ``gradient_clip_val``, ``repeat_epoch``.

   :returns: The trained diffusion model, the dataset, and the Lightning
      trainer.
   :rtype: Tuple[Agedi, Dataset, Trainer]


.. py:function:: train_from_config(config: Union[str, pathlib.Path, Dict]) -> Tuple[Agedi, agedi.data.Dataset, lightning.Trainer]

   Train an AGeDi model from a YAML configuration file or dictionary.

   This is the *Hydra-style* entry point. The configuration can be provided
   as:

   * a path to a YAML file (``str`` or :class:`~pathlib.Path`),
   * a plain Python ``dict``,
   * a Hydra / OmegaConf ``DictConfig``.

   The function loads the training data from ``config["data_path"]`` (an
   ASE-readable file) and delegates to :func:`train_from_atoms` with the
   remaining configuration values.

   The minimal required key is ``data_path``. All other keys are optional and
   fall back to the same defaults as :func:`train_from_atoms`.

   A ready-to-edit template is shipped with the package at
   ``agedi/conf/train.yaml``.

   :param config: Configuration source – a YAML file path, a ``dict``, or an
      OmegaConf ``DictConfig``.

   :returns: The trained diffusion model, the dataset used, and the Lightning
      trainer.
   :rtype: Tuple[Agedi, Dataset, Trainer]

   .. rubric:: Examples

   Minimal Python usage::

       from agedi import train_from_config
       diffusion, dataset, trainer = train_from_config("conf/train.yaml")

   Programmatic override::

       from agedi import train_from_config
       cfg = {"data_path": "train.traj", "epochs": 50, "feature_size": 128}
       diffusion, _, _ = train_from_config(cfg)
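The way ``train_from_config`` treats its keys (``data_path`` is required, every other key is an optional override) can be illustrated with a small standalone sketch. It does not reproduce the package's implementation. The helper ``resolve_config`` and the ``DOCUMENTED_DEFAULTS`` table are hypothetical names introduced here, and the table copies only a handful of defaults from the ``train_from_atoms`` signature documented above.

```python
from typing import Any, Dict, Mapping, Tuple

# A few defaults copied from the documented train_from_atoms signature.
# Illustrative subset only; the real function accepts many more options.
DOCUMENTED_DEFAULTS: Dict[str, Any] = {
    "model": "PaiNN",
    "cutoff": 6.0,
    "feature_size": 64,
    "batch_size": 64,
    "lr": 1e-4,
}


def resolve_config(config: Mapping[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: split off ``data_path`` and fill in defaults."""
    if "data_path" not in config:
        raise KeyError("the configuration must contain 'data_path'")
    # Every key except data_path is treated as an override.
    overrides = {k: v for k, v in config.items() if k != "data_path"}
    # User-supplied values take precedence over the defaults.
    resolved = {**DOCUMENTED_DEFAULTS, **overrides}
    return str(config["data_path"]), resolved


cfg = {"data_path": "train.traj", "epochs": 50, "feature_size": 128}
data_path, kwargs = resolve_config(cfg)
print(data_path)
print(kwargs["feature_size"], kwargs["cutoff"], kwargs["epochs"])
```

In this sketch ``feature_size`` is overridden by the configuration, ``cutoff`` keeps its documented default, and ``epochs`` passes through unchanged, as a trainer keyword would.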