agedi.diffusion
===============

.. py:module:: agedi.diffusion


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/agedi/diffusion/agedi/index
   /autoapi/agedi/diffusion/diffusion/index
   /autoapi/agedi/diffusion/distributions/index
   /autoapi/agedi/diffusion/guidance/index
   /autoapi/agedi/diffusion/noisers/index
   /autoapi/agedi/diffusion/sdes/index


Classes
-------

.. autoapisummary::

   agedi.diffusion.Agedi
   agedi.diffusion.ForcefieldGuidanceConfig
   agedi.diffusion.Diffusion


Package Contents
----------------

.. py:class:: Agedi(score_model: agedi.models.ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: Optional[torch.nn.Module] = None, regressor_heads: Optional[List] = None, regressor_loss_weight: float = 1.0, optim_config: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, eps: float = 1e-05)

   Bases: :py:obj:`lightning.LightningModule`, :py:obj:`agedi.diffusion.diffusion.Diffusion`


   Full diffusion model: training + sampling.

   Combines the :class:`~agedi.diffusion.Diffusion` sampling pipeline with
   :class:`~lightning.LightningModule` training hooks.

   :param score_model: The score model.
   :type score_model: ScoreModel
   :param noisers: A list of noisers.
   :type noisers: List[Noiser]
   :param regressor_model: An optional regressor model used for force-field guidance during sampling. When present, its loss is added to the diffusion loss during training.
   :type regressor_model: torch.nn.Module, optional
   :param regressor_heads: When provided, a :class:`~agedi.models.regressor.RegressorModel` is built internally using these heads while **sharing** the translator and representation from ``score_model``. Use this parameter (instead of ``regressor_model``) when the backbone should be shared.
   :type regressor_heads: List, optional
   :param regressor_loss_weight: Weight applied to the regressor loss. Defaults to ``1.0``.
   :type regressor_loss_weight: float, optional
   :param optim_config: Keyword arguments forwarded to :class:`torch.optim.AdamW`.
   :type optim_config: dict, optional
   :param scheduler_config: Keyword arguments forwarded to :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`.
   :type scheduler_config: dict, optional
   :param eps: Minimum diffusion time value.
   :type eps: float, optional


   .. py:attribute:: regressor_loss_weight
      :value: 1.0


   .. py:attribute:: optim_config
      :value: None


   .. py:attribute:: scheduler_config
      :value: None


   .. py:attribute:: _regressor_training
      :value: False


   .. py:method:: on_fit_start() -> None

      Write ``hparams.yaml`` to the trainer log directory at training start.


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this diffusion model.

      :returns: Hyperparameter dictionary with ``_target_``, ``score_model``, ``noisers``, ``optim_config``, ``scheduler_config``, ``eps``, and optionally ``regressor_heads`` or ``regressor_model``.
      :rtype: dict


   .. py:method:: setup(stage: str = None) -> None

      Set up the model (puts the score model in training mode).


   .. py:method:: forward(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Forward pass through the score model.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :returns: The output of the score model forward pass.
      :rtype: AtomsGraph


   .. py:method:: loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

      Compute the combined diffusion + regressor loss.

      Always computes the diffusion (denoising) loss on a noised copy of the
      batch. When a regressor model is present and the batch contains force
      labels, the regressor loss is added with weight ``regressor_loss_weight``.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param batch_idx: The index of the batch.
      :type batch_idx: torch.Tensor
      :returns: A dictionary of losses.
      :rtype: dict


   .. py:method:: diffusion_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

      Compute the diffusion (denoising score-matching) loss.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param batch_idx: The index of the batch.
      :type batch_idx: torch.Tensor
      :returns: A dictionary of losses.
      :rtype: dict


   .. py:method:: regressor_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

      Compute the regressor loss on the un-noised batch.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param batch_idx: The index of the batch.
      :type batch_idx: torch.Tensor
      :returns: A dictionary of losses.
      :rtype: dict
      :raises ValueError: If no regressor model is attached.


   .. py:method:: training_step(batch, batch_idx: torch.Tensor) -> torch.Tensor

      Perform a training step.

      Computes the combined diffusion + regressor loss (see :meth:`loss`).

      When the :class:`~agedi.data.Dataset` was set up with a dedicated
      regressor dataset (via :meth:`~agedi.data.Dataset.add_regressor_data`),
      ``batch`` is a dict with two keys:

      * ``"main"`` – a regular training batch used for both the diffusion and
        regressor loss.
      * ``"regressor"`` – a regressor-only batch whose structures are *only*
        forwarded through the regressor loss (not the diffusion loss).

      When no regressor dataset is present, ``batch`` is a plain
      :class:`~agedi.data.AtomsGraph` batch and the step simply computes
      :meth:`loss` on it.

      :param batch: A batch of AtomsGraph data, or a dict with ``"main"`` and ``"regressor"`` keys when a dedicated regressor dataset is used.
      :type batch: AtomsGraph or dict
      :param batch_idx: The index of the batch.
      :type batch_idx: torch.Tensor
      :returns: The combined loss.
      :rtype: torch.Tensor


   .. py:method:: validation_step(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> torch.Tensor

      Perform a validation step.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param batch_idx: The index of the batch.
      :type batch_idx: torch.Tensor
      :returns: The combined loss.
      :rtype: torch.Tensor


   .. py:method:: configure_optimizers() -> Dict

      Configure optimizers and learning-rate schedulers.
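
For illustration, a minimal sketch of such a configuration in plain PyTorch, assuming ``optim_config`` and ``scheduler_config`` are ordinary keyword dictionaries. The helper ``build_optimizer_config`` and the default monitor name ``"val_loss"`` are hypothetical (the metric actually monitored is whatever :meth:`_scheduler_monitor` returns); this is not the agedi implementation.

```python
from itertools import chain
from typing import Dict, Optional

import torch
from torch import nn


def build_optimizer_config(
    score_model: nn.Module,
    regressor_model: Optional[nn.Module] = None,
    optim_config: Optional[Dict] = None,
    scheduler_config: Optional[Dict] = None,
    monitor: str = "val_loss",  # hypothetical metric name
) -> Dict:
    """Hypothetical helper: one AdamW optimizer over the union of parameters."""
    modules = [score_model]
    if regressor_model is not None:
        modules.append(regressor_model)

    # Key by object identity so that parameters shared between the two
    # modules (e.g. a common backbone) are handed to the optimizer only once.
    unique = {id(p): p for p in chain.from_iterable(m.parameters() for m in modules)}

    optimizer = torch.optim.AdamW(list(unique.values()), **(optim_config or {}))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **(scheduler_config or {})
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": monitor}


# Usage: a score model and a regressor that share one backbone layer.
backbone = nn.Linear(4, 4)
score_model = nn.Sequential(backbone, nn.Linear(4, 1))
regressor = nn.Sequential(backbone, nn.Linear(4, 2))
config = build_optimizer_config(
    score_model, regressor, optim_config={"lr": 1e-3}, scheduler_config={"patience": 5}
)
n_params = len(config["optimizer"].param_groups[0]["params"])
```

Deduplicating by ``id`` keeps the shared backbone's weight and bias in the group only once, so the group holds 6 tensors rather than 8; recent PyTorch versions warn when a parameter group contains duplicates.
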

      When a regressor model is present, a single optimizer is built over the
      deduplicated union of ``score_model`` and ``regressor_model`` parameters
      (shared parameters appear only once).

      :returns: A dictionary with ``"optimizer"``, ``"lr_scheduler"``, and ``"monitor"`` keys.
      :rtype: dict


   .. py:method:: _scheduler_monitor() -> str

      Return the metric used by ``ReduceLROnPlateau``.


   .. py:property:: regressor_training
      :type: bool

      Whether the regressor model is in training mode.


.. py:class:: ForcefieldGuidanceConfig

   Configuration for force-field guided sampling.

   :param guidance: Scale of the force-field guidance applied at each reverse step. Set to ``0.0`` (the default) to disable guidance entirely.
   :type guidance: float
   :param zeta: Exponent for the time-dependent weight factor ``(1 - t)**zeta``. Higher values concentrate guidance near the end of the trajectory.
   :type zeta: float
   :param force_threshold: Convergence criterion for the optional post-diffusion relaxation: the maximum per-atom force magnitude (eV/Å) below which relaxation stops.
   :type force_threshold: float
   :param max_extra_steps: Maximum number of additional relaxation steps performed after the main diffusion trajectory when ``guidance > 0``.
   :type max_extra_steps: int


   .. py:attribute:: guidance
      :type: float
      :value: 0.0


   .. py:attribute:: zeta
      :type: float
      :value: 3.0


   .. py:attribute:: force_threshold
      :type: float
      :value: 0.05


   .. py:attribute:: max_extra_steps
      :type: int
      :value: 0


.. py:class:: Diffusion(score_model: ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: Optional[torch.nn.Module] = None, eps: float = 1e-05)

   Pure-Python sampling core for diffusion models.

   Holds the score model, the noisers, and an optional regressor, and provides
   the full forward / reverse / sampling pipeline. This class does **not**
   inherit from :class:`torch.nn.Module` or :class:`lightning.LightningModule`
   and therefore has no training hooks.
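
Because this class is not a :class:`torch.nn.Module`, it cannot lean on module machinery to report where its weights live. The sketch below shows, under stated assumptions, how a plain-Python holder can read the device off the wrapped score model, and how Python's method resolution order lets a base class listed *first* (as :class:`lightning.LightningModule` is in the bases of :class:`~agedi.diffusion.Agedi`) supply its own ``device``. ``SamplingCore``, ``FrameworkBase`` and ``Trainable`` are hypothetical stand-ins, not agedi or Lightning classes.

```python
import torch
from torch import nn


class SamplingCore:
    """Hypothetical plain-Python holder (deliberately not an nn.Module)."""

    def __init__(self, score_model: nn.Module, eps: float = 1e-5) -> None:
        self.score_model = score_model
        self.eps = eps

    @property
    def device(self) -> torch.device:
        # Standalone use: read the device off the first score-model parameter.
        return next(self.score_model.parameters()).device


class FrameworkBase:
    """Hypothetical stand-in for a framework base class with its own device."""

    @property
    def device(self) -> torch.device:
        # Sentinel value, only to make the precedence visible.
        return torch.device("meta")


class Trainable(FrameworkBase, SamplingCore):
    """FrameworkBase is listed first, so its ``device`` wins in the MRO."""


core = SamplingCore(nn.Linear(3, 3))
trainable = Trainable(nn.Linear(3, 3))  # __init__ is resolved to SamplingCore's
```

Here ``core.device`` reports the CPU device of the linear layer, while ``trainable.device`` returns the sentinel from ``FrameworkBase`` even though ``SamplingCore`` also defines the property.
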

   When used through :class:`~agedi.diffusion.Agedi` (which inherits from both
   this class and :class:`lightning.LightningModule`), the Lightning
   infrastructure manages device placement and module registration. When used
   standalone, device information is derived from the score model's parameters
   via the :attr:`device` property.

   :param score_model: The score model.
   :type score_model: ScoreModel
   :param noisers: A list of noisers.
   :type noisers: List[Noiser]
   :param regressor_model: An optional regressor model used for force-field guidance during sampling.
   :type regressor_model: torch.nn.Module, optional
   :param eps: Minimum value for the diffusion time step (used in :meth:`sample_time`).
   :type eps: float, optional


   .. py:attribute:: score_model


   .. py:attribute:: noisers


   .. py:attribute:: regressor_model
      :value: None


   .. py:attribute:: eps
      :value: 1e-05


   .. py:attribute:: lbfgs_step_sizer
      :type: Optional[agedi.diffusion.guidance.BatchedLBFGSStepSizer]
      :value: None


   .. py:attribute:: zeta
      :type: float
      :value: 3.0


   .. py:attribute:: noiser_keys


   .. py:attribute:: score_keys


   .. py:attribute:: _compiled_reverse_step
      :value: None


   .. py:property:: device
      :type: torch.device

      Infer the computation device from the score model's parameters.

      When used through :class:`~agedi.diffusion.Agedi` (which also inherits
      :class:`lightning.LightningModule`), Lightning's own ``device`` property
      takes precedence.


   .. py:method:: sample_time(batch: agedi.data.AtomsGraph) -> None

      Sample a random diffusion time for each graph in *batch*.

      Draws one time per graph uniformly from ``[eps, 1]`` and assigns it to
      ``batch.time`` at atom resolution (every atom carries the time of its
      graph).

      :param batch: A batch of AtomsGraph data; modified in-place.
      :type batch: AtomsGraph


   .. py:method:: forward_step(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Forward diffusion step (corruption).

      Applies each noiser in order to corrupt the batch.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :returns: The corrupted batch.
      :rtype: AtomsGraph


   .. py:method:: reverse_step(batch: agedi.data.AtomsGraph, delta_t: float, force_field_guidance: float, last: bool = False, timings: Optional[SamplingTimings] = None) -> agedi.data.AtomsGraph

      Reverse diffusion step (denoising).

      Evaluates the score model and applies one reverse-SDE step through all
      noisers. If ``force_field_guidance`` is non-zero, a force-field guidance
      step is applied afterwards.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param delta_t: The time step.
      :type delta_t: float
      :param force_field_guidance: Scale of the force-field guidance (``0.0`` disables it).
      :type force_field_guidance: float
      :param last: Whether this is the final denoising step.
      :type last: bool, optional
      :param timings: If provided, timing measurements are accumulated here.
      :type timings: SamplingTimings, optional
      :returns: The denoised batch.
      :rtype: AtomsGraph


   .. py:method:: corrector_step(batch: agedi.data.AtomsGraph, corrector_dt: float) -> agedi.data.AtomsGraph

      Langevin corrector step at constant time.

      Evaluates the score model and applies one Langevin corrector step
      through all noisers (in reverse order).

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param corrector_dt: Step size for the Langevin corrector.
      :type corrector_dt: float
      :returns: The corrected batch.
      :rtype: AtomsGraph


   .. py:method:: force_field_guidance_step(batch: agedi.data.AtomsGraph, scale: float, max_step_size: float = 0.1) -> agedi.data.AtomsGraph

      Apply one force-field guidance step.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param scale: Base scale of the force-field guidance.
      :type scale: float
      :param max_step_size: Maximum allowed step size magnitude.
      :type max_step_size: float, optional
      :returns: Updated batch.
      :rtype: AtomsGraph


   .. py:method:: post_diffusion_relaxation_step(batch: agedi.data.AtomsGraph, scale: float = 0.1) -> agedi.data.AtomsGraph

      Perform a pure force-based relaxation step.

      :param batch: A batch of AtomsGraph data.
      :type batch: AtomsGraph
      :param scale: Step size scaling factor.
      :type scale: float, optional
      :returns: Updated batch.
      :rtype: AtomsGraph


   .. py:method:: _initialize_graph(cutoff: float, **kwargs) -> agedi.data.AtomsGraph

      Initialise a single graph from noiser priors.

      :param cutoff: Cutoff radius for the neighbour list.
      :type cutoff: float
      :param \*\*kwargs: Additional keyword arguments passed to the graph (e.g. ``cell``, ``template``, ``pbc``).
      :returns: The initialised graph.
      :rtype: AtomsGraph


   .. py:method:: _sync_for_timing(device: Optional[torch.device]) -> None
      :staticmethod:


   .. py:method:: _time_sampling_call(device: Optional[torch.device], timings: SamplingTimings, key: str, fn, *args, **kwargs)


   .. py:method:: _format_timing_line(label: str, value: float, count: Optional[int] = None) -> str
      :staticmethod:


   .. py:method:: _print_sampling_timings(timings: SamplingTimings) -> None


   .. py:property:: compiled_reverse_step

      Lazily compile :meth:`reverse_step` with ``torch.compile``.

      The compiled kernel is cached as ``self._compiled_reverse_step`` so that
      compilation happens at most once per model instance. Using a
      per-instance cache (rather than a class-level ``@torch.compile``
      decorator) means that two :class:`Diffusion` objects with different
      architectures each compile their own kernel and never interfere.

      .. note::

         ``timings`` must **not** be passed to the compiled function, because
         ``time.perf_counter`` is not traceable by Dynamo. Instead,
         :meth:`_sample_batch` times the compiled call from outside, using the
         ``is_compiled`` flag.


   .. py:method:: _sample_batch(batch: torch_geometric.data.Batch, steps: int, eps: float, force_field_guidance: float, save_trajectory: bool, progress_bar: bool, force_threshold: float, max_extra_steps: int, corrector_steps: int = 0, corrector_step_size: float = 0.001, timings: Optional[SamplingTimings] = None, reverse_step_fn=None, is_compiled: bool = False) -> List[agedi.data.AtomsGraph]

      Run the reverse-diffusion loop for a pre-built batch.
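
The loop can be pictured as a predictor-corrector scheme on a uniform time grid from ``t = 1`` down to ``eps``. The sketch below is a generic, hypothetical version on plain tensors, not the agedi code: the helper names, the toy variance-preserving SDE with constant rate ``beta = 1``, and its exact score ``-x`` for standard-normal data are assumptions chosen so the result can be checked. Force-field guidance and post-diffusion relaxation are omitted.

```python
from typing import Callable, List, Optional, Tuple

import torch

# A step maps (state, current time, step size) to a new state.
Step = Callable[[torch.Tensor, float, float], torch.Tensor]


def sample_loop(
    x: torch.Tensor,
    steps: int,
    eps: float,
    reverse_step: Step,
    corrector_step: Optional[Step] = None,
    corrector_steps: int = 0,
    corrector_step_size: float = 1e-3,
    save_trajectory: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Hypothetical predictor-corrector loop from t=1 down to t=eps."""
    times = torch.linspace(1.0, eps, steps + 1)
    trajectory = [x.clone()] if save_trajectory else []
    for i in range(steps):
        t, t_next = float(times[i]), float(times[i + 1])
        # Predictor: one reverse-SDE step of size t - t_next.
        x = reverse_step(x, t, t - t_next)
        # Corrector: Langevin passes at the new, fixed time.
        if corrector_step is not None:
            for _ in range(corrector_steps):
                x = corrector_step(x, t_next, corrector_step_size)
        if save_trajectory:
            trajectory.append(x.clone())
    return x, trajectory


def toy_score(x: torch.Tensor, t: float) -> torch.Tensor:
    # For N(0, 1) data under a VP-SDE the marginal stays N(0, 1),
    # so the exact score is -x at every time t.
    return -x


def toy_reverse_step(x: torch.Tensor, t: float, dt: float) -> torch.Tensor:
    beta = 1.0  # assumed constant noise rate
    # Reverse drift f - g^2 * score with f = -0.5 * beta * x and g^2 = beta.
    drift = -0.5 * beta * x - beta * toy_score(x, t)
    # Euler-Maruyama step backwards in time.
    return x - drift * dt + (beta * dt) ** 0.5 * torch.randn_like(x)


def toy_corrector_step(x: torch.Tensor, t: float, step: float) -> torch.Tensor:
    # Unadjusted Langevin: x + step * score + sqrt(2 * step) * noise.
    return x + step * toy_score(x, t) + (2 * step) ** 0.5 * torch.randn_like(x)


torch.manual_seed(0)
prior = torch.randn(20_000)  # samples from the t=1 prior
final, traj = sample_loop(
    prior, steps=100, eps=1e-3, reverse_step=toy_reverse_step,
    corrector_step=toy_corrector_step, corrector_steps=1, save_trajectory=True,
)
```

Because the toy marginal is stationary, the sample standard deviation stays close to 1 along the whole trajectory; with a trained score model the same loop structure would instead carry prior noise toward the data distribution.
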

      :param batch: A batch of :class:`~agedi.data.AtomsGraph` data at ``t=1``.
      :type batch: Batch
      :param steps: Number of reverse-diffusion steps.
      :type steps: int
      :param eps: Minimum time value (end of trajectory).
      :type eps: float
      :param force_field_guidance: Scale of the force-field guidance (``0.0`` disables it).
      :type force_field_guidance: float
      :param save_trajectory: Whether to collect and return all intermediate states.
      :type save_trajectory: bool
      :param progress_bar: Whether to display a tqdm progress bar.
      :type progress_bar: bool
      :param force_threshold: Maximum per-atom force for terminating post-diffusion relaxation.
      :type force_threshold: float
      :param max_extra_steps: Maximum extra relaxation steps after the main trajectory.
      :type max_extra_steps: int
      :param corrector_steps: Number of Langevin corrector passes after each predictor step. ``0`` (default) disables the corrector (standard DDPM/EM sampling).
      :type corrector_steps: int, optional
      :param corrector_step_size: Step size used for each Langevin corrector step. Defaults to ``1e-3``.
      :type corrector_step_size: float, optional
      :param timings: If provided, timing measurements are accumulated here.
      :type timings: SamplingTimings, optional
      :param reverse_step_fn: The reverse step function to use. Defaults to ``self.reverse_step``. Pass a ``torch.compile``-wrapped version to enable compiled sampling.
      :type reverse_step_fn: callable, optional
      :param is_compiled: Whether ``reverse_step_fn`` is a compiled function.
      :type is_compiled: bool, optional
      :returns: Final structures, or (when *save_trajectory* is ``True``) a list of trajectories (one per graph).
      :rtype: List[AtomsGraph]


   .. py:method:: _sample(N: int, steps: int, cutoff: float, eps: float, force_field_guidance: float, force_threshold: float, max_extra_steps: int, progress_bar: bool, save_trajectory: bool, corrector_steps: int = 0, corrector_step_size: float = 0.001, print_timings: bool = False, compile: bool = False, **kwargs) -> List[agedi.data.AtomsGraph]

      Build *N* graphs from priors and run the sampling loop.

      :param N: Number of structures to generate.
      :type N: int
      :param steps: Number of reverse-diffusion steps.
      :type steps: int
      :param cutoff: Cutoff radius for the neighbour list.
      :type cutoff: float
      :param eps: Minimum time value (end of trajectory).
      :type eps: float
      :param force_field_guidance: Scale of the force-field guidance.
      :type force_field_guidance: float
      :param force_threshold: Maximum per-atom force for post-diffusion relaxation.
      :type force_threshold: float
      :param max_extra_steps: Maximum extra relaxation steps.
      :type max_extra_steps: int
      :param progress_bar: Whether to display a tqdm progress bar.
      :type progress_bar: bool
      :param save_trajectory: Whether to collect all intermediate states.
      :type save_trajectory: bool
      :param corrector_steps: Langevin corrector passes per predictor step.
      :type corrector_steps: int, optional
      :param corrector_step_size: Step size for each corrector pass.
      :type corrector_step_size: float, optional
      :param print_timings: Print a timing breakdown after sampling completes.
      :type print_timings: bool, optional
      :param compile: Use ``torch.compile`` on the reverse diffusion step.
      :type compile: bool, optional
      :param \*\*kwargs: Keyword arguments forwarded to :meth:`_initialize_graph`.
      :returns: Sampled structures (or trajectories when *save_trajectory* is ``True``).
      :rtype: List[AtomsGraph]


   .. py:method:: sample(N: int, template=None, batch_size: Optional[int] = 64, steps: Optional[int] = 500, cutoff: Optional[float] = 6.0, eps: Optional[float] = 0.001, n_atoms: Optional[int] = None, atomic_numbers: Optional[List[int]] = None, formula: Optional[str] = None, positions: Optional[numpy.ndarray] = None, cell: Optional[numpy.ndarray] = None, pbc: Optional[numpy.ndarray] = None, confinement: Optional[Tuple[float, float]] = None, compile: bool = False, ff_guidance: Optional[agedi.diffusion.guidance.ForcefieldGuidanceConfig] = None, property: Optional[Dict] = None, progress_bar: Optional[bool] = False, save_trajectory: Optional[bool] = False, print_timings: Optional[bool] = False, corrector_steps: int = 0, corrector_step_size: float = 0.001) -> List[agedi.data.AtomsGraph]

      Sample structures from the diffusion model.

      The minimum required arguments depend on the configured noisers and
      whether a template is provided:

      * ``n_atoms`` -- always required unless derivable from ``atomic_numbers``
        or ``formula``.
      * ``atomic_numbers`` -- required unless a types-noiser is configured
        (key ``"x"``) or the numbers can be derived from ``formula``.
      * ``positions`` -- required when no positions-noiser is configured
        (type-only diffusion).
      * ``cell`` -- required when no ``template`` is given.
      * ``pbc`` -- optional; defaults to ``[True, True, True]``.

      :param N: Number of structures to generate.
      :type N: int
      :param template: Template structure. ``cell`` and ``pbc`` are taken from the template when not explicitly provided.
      :type template: AtomsGraph or ase.Atoms, optional
      :param batch_size: Internal batch size for splitting large *N*.
      :type batch_size: int, optional
      :param steps: Number of reverse-diffusion steps.
      :type steps: int, optional
      :param cutoff: Cutoff radius for the neighbour list.
      :type cutoff: float, optional
      :param eps: Minimum time value at the end of the trajectory.
      :type eps: float, optional
      :param n_atoms: Number of atoms per structure.
      :type n_atoms: int, optional
      :param atomic_numbers: Atomic numbers of the atoms to generate.
      :type atomic_numbers: List[int], optional
      :param formula: Chemical formula (e.g. ``"H2O"``).
      :type formula: str, optional
      :param positions: Fixed atom positions (shape ``(n_atoms, 3)``).
      :type positions: np.ndarray, optional
      :param cell: Unit-cell matrix (3x3).
      :type cell: np.ndarray, optional
      :param pbc: Periodic boundary conditions.
      :type pbc: np.ndarray, optional
      :param confinement: Z-directional confinement ``(z_min, z_max)``.
      :type confinement: Tuple[float, float], optional
      :param compile: When ``True``, use ``torch.compile`` on the reverse diffusion step for improved throughput on CUDA hardware.
      :type compile: bool, optional
      :param ff_guidance: Force-field guidance configuration.
      :type ff_guidance: ForcefieldGuidanceConfig, optional
      :param property: Conditioning properties (key -> scalar tensor).
      :type property: dict, optional
      :param progress_bar: Show a tqdm progress bar.
      :type progress_bar: bool, optional
      :param save_trajectory: Return full trajectories instead of final structures.
      :type save_trajectory: bool, optional
      :param print_timings: Print a timing breakdown after sampling completes.
      :type print_timings: bool, optional
      :param corrector_steps: Number of Langevin corrector passes after each predictor step. ``0`` (default) gives standard (predictor-only) sampling.
      :type corrector_steps: int, optional
      :param corrector_step_size: Step size for each corrector pass. Defaults to ``1e-3``.
      :type corrector_step_size: float, optional
      :returns: Sampled structures, or trajectories when *save_trajectory* is ``True``.
      :rtype: List[AtomsGraph]
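
The argument rules listed for ``sample`` can be summarised as a small resolution function. This is a hypothetical sketch, not the agedi implementation: ``resolve_sample_inputs``, its boolean flags for which noisers are configured, and the deliberately tiny element table are assumptions for illustration (a complete implementation would need a full periodic table, for example from ASE).

```python
import re
from typing import Dict, List, Optional, Sequence

# Deliberately partial element table (assumption, for illustration only).
ATOMIC_NUMBERS: Dict[str, int] = {"H": 1, "C": 6, "N": 7, "O": 8, "Si": 14, "Cu": 29}


def formula_to_numbers(formula: str) -> List[int]:
    """Expand a simple formula such as 'H2O' into [1, 1, 8].

    Raises KeyError for elements missing from the partial table above.
    """
    numbers: List[int] = []
    for symbol, count in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        numbers.extend([ATOMIC_NUMBERS[symbol]] * int(count or 1))
    return numbers


def resolve_sample_inputs(
    has_types_noiser: bool,
    has_positions_noiser: bool,
    n_atoms: Optional[int] = None,
    atomic_numbers: Optional[Sequence[int]] = None,
    formula: Optional[str] = None,
    positions=None,
    cell=None,
    pbc=None,
    template=None,
) -> Dict:
    """Hypothetical check of the argument rules documented for ``sample``."""
    if atomic_numbers is None and formula is not None:
        atomic_numbers = formula_to_numbers(formula)
    if n_atoms is None:
        if atomic_numbers is None:
            raise ValueError("n_atoms is required (or give atomic_numbers/formula)")
        n_atoms = len(atomic_numbers)
    if atomic_numbers is None and not has_types_noiser:
        raise ValueError("atomic_numbers (or formula) is required without a types-noiser")
    if positions is None and not has_positions_noiser:
        raise ValueError("positions are required without a positions-noiser")
    if cell is None and template is None:
        raise ValueError("cell is required when no template is given")
    if pbc is None and template is None:
        # Without a template, fall back to fully periodic boundaries;
        # with a template, None means "take pbc from the template".
        pbc = [True, True, True]
    return {
        "n_atoms": n_atoms,
        "atomic_numbers": atomic_numbers,
        "positions": positions,
        "cell": cell,
        "pbc": pbc,
    }
```

For example, ``resolve_sample_inputs(False, True, formula="H2O", cell=...)`` yields ``n_atoms == 3`` and atomic numbers ``[1, 1, 8]``, whereas omitting both ``cell`` and ``template`` raises ``ValueError``.
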