agedi.diffusion.guidance
========================

.. py:module:: agedi.diffusion.guidance

.. autoapi-nested-parse::

   Force-field guidance utilities for diffusion sampling.

   This module provides:

   - :class:`ForcefieldGuidanceConfig` – configuration dataclass.
   - :class:`LBFGSStepSizer` – per-graph L-BFGS step-size adapter.
   - :class:`BatchedLBFGSStepSizer` – batched wrapper around :class:`LBFGSStepSizer`.
   - :func:`force_field_guidance_step` – one guidance step (module-level).
   - :func:`post_diffusion_relaxation_step` – post-diffusion relaxation (module-level).


Classes
-------

.. autoapisummary::

   agedi.diffusion.guidance.ForcefieldGuidanceConfig
   agedi.diffusion.guidance.LBFGSStepSizer
   agedi.diffusion.guidance.BatchedLBFGSStepSizer


Functions
---------

.. autoapisummary::

   agedi.diffusion.guidance.force_field_guidance_step
   agedi.diffusion.guidance.post_diffusion_relaxation_step


Module Contents
---------------

.. py:class:: ForcefieldGuidanceConfig

   Configuration for force-field-guided sampling.

   :param guidance: Scale of the force-field guidance applied at each reverse step. Set to ``0.0`` (the default) to disable guidance entirely.
   :type guidance: float
   :param zeta: Exponent for the time-dependent weight factor ``(1 - t)**zeta``. Higher values concentrate guidance near the end of the trajectory.
   :type zeta: float
   :param force_threshold: Convergence criterion for the optional post-diffusion relaxation: the maximum per-atom force magnitude (eV/Å) below which relaxation stops.
   :type force_threshold: float
   :param max_extra_steps: Maximum number of additional relaxation steps performed after the main diffusion trajectory when ``guidance > 0``.
   :type max_extra_steps: int


   .. py:attribute:: guidance
      :type: float
      :value: 0.0


   .. py:attribute:: zeta
      :type: float
      :value: 3.0


   .. py:attribute:: force_threshold
      :type: float
      :value: 0.05


   .. py:attribute:: max_extra_steps
      :type: int
      :value: 0


.. py:class:: LBFGSStepSizer(memory_size: int = 10, initial_step: float = 0.1, device: str = 'cuda')

   Per-graph L-BFGS step-size adapter for force-field guidance.


   .. py:attribute:: memory_size
      :value: 10


   .. py:attribute:: initial_step
      :value: 0.1


   .. py:attribute:: device
      :value: 'cuda'


   .. py:attribute:: s_list


   .. py:attribute:: y_list


   .. py:attribute:: rho_list


   .. py:attribute:: prev_pos
      :value: None


   .. py:attribute:: prev_forces
      :value: None


   .. py:attribute:: H0_scaling
      :value: 1.0


   .. py:method:: compute_step(pos: torch.Tensor, forces: torch.Tensor) -> torch.Tensor

      Compute a step from the L-BFGS approximation of the inverse Hessian.

      :param pos: Current atomic positions (B×N×3 tensor).
      :type pos: torch.Tensor
      :param forces: Current forces (B×N×3 tensor).
      :type forces: torch.Tensor

      :returns: L-BFGS step vector (B×N×3 tensor).
      :rtype: torch.Tensor


   .. py:method:: reset() -> None

      Reset the L-BFGS memory.


.. py:class:: BatchedLBFGSStepSizer(batch_size: int, memory_size: int = 10, initial_step: float = 0.1)

   Batched wrapper around :class:`LBFGSStepSizer` for use with batched graphs.

   Maintains one :class:`LBFGSStepSizer` per graph in a batch and dispatches the step computation to the appropriate instance based on batch indices.


   .. py:attribute:: step_sizers


   .. py:method:: compute_step(pos: torch.Tensor, forces: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor

      Compute per-graph steps for a batch of graphs.

      :param pos: Current atomic positions.
      :type pos: torch.Tensor
      :param forces: Current forces acting on the atoms.
      :type forces: torch.Tensor
      :param batch_idx: Index tensor mapping each atom to its graph in the batch.
      :type batch_idx: torch.Tensor

      :returns: Combined step tensor with the same shape as *pos*.
      :rtype: torch.Tensor


   .. py:method:: reset() -> None

      Reset the L-BFGS memory for all step sizers in the batch.


.. py:function:: force_field_guidance_step(batch: agedi.data.AtomsGraph, regressor_model: torch.nn.Module, lbfgs_step_sizer: BatchedLBFGSStepSizer, scale: float, zeta: float = 3.0, max_step_size: float = 0.1) -> agedi.data.AtomsGraph

   Apply one force-field guidance step with batched L-BFGS step-size adaptation.

   :param batch: A batch of AtomsGraph data.
   :type batch: AtomsGraph
   :param regressor_model: The regressor model used to compute forces.
   :type regressor_model: torch.nn.Module
   :param lbfgs_step_sizer: The batched L-BFGS step sizer, holding one step sizer per graph in the batch.
   :type lbfgs_step_sizer: BatchedLBFGSStepSizer
   :param scale: Base scale of the force-field guidance.
   :type scale: float
   :param zeta: Exponent for the time-dependent weight ``(1 - t)**zeta``. Defaults to ``3.0``.
   :type zeta: float, optional
   :param max_step_size: Maximum allowed step magnitude. Defaults to ``0.1``.
   :type max_step_size: float, optional

   :returns: Updated batch after applying the guidance step.
   :rtype: AtomsGraph


.. py:function:: post_diffusion_relaxation_step(batch: agedi.data.AtomsGraph, regressor_model: torch.nn.Module, lbfgs_step_sizer: Optional[BatchedLBFGSStepSizer], scale: float = 0.1) -> agedi.data.AtomsGraph

   Perform one purely force-based relaxation step after diffusion is complete.

   :param batch: A batch of AtomsGraph data.
   :type batch: AtomsGraph
   :param regressor_model: The regressor model used to compute forces.
   :type regressor_model: torch.nn.Module
   :param lbfgs_step_sizer: The L-BFGS step sizer. Initialised from ``batch`` if ``None``.
   :type lbfgs_step_sizer: BatchedLBFGSStepSizer or None
   :param scale: Step-size scaling factor for the relaxation step. Defaults to ``0.1``.
   :type scale: float, optional

   :returns: Updated batch after the relaxation step.
   :rtype: AtomsGraph
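
The fields of :class:`ForcefieldGuidanceConfig` can be read as a small control policy. The sketch below is a minimal, framework-free illustration that rests on two assumptions the docstrings do not confirm. First, the effective guidance at reverse time ``t`` (running from 1 toward 0) is taken to be ``guidance * (1 - t)**zeta``. Second, post-diffusion relaxation is assumed to repeat while the largest per-atom force exceeds ``force_threshold``, up to ``max_extra_steps`` times. ``GuidanceSettings``, ``guidance_weight``, ``max_force`` and ``run_extra_relaxation`` are hypothetical names and are not part of ``agedi``.

.. code-block:: python

   import math
   from dataclasses import dataclass
   from typing import Callable, List, Tuple

   Vec3 = Tuple[float, float, float]


   @dataclass
   class GuidanceSettings:
       """Hypothetical stand-in with the same fields and defaults as ForcefieldGuidanceConfig."""

       guidance: float = 0.0
       zeta: float = 3.0
       force_threshold: float = 0.05
       max_extra_steps: int = 0


   def guidance_weight(cfg: GuidanceSettings, t: float) -> float:
       """Assumed effective guidance scale at diffusion time t in [0, 1]."""
       if not 0.0 <= t <= 1.0:
           raise ValueError("t must lie in [0, 1]")
       # A larger zeta suppresses guidance at large t and concentrates it near t = 0.
       return cfg.guidance * (1.0 - t) ** cfg.zeta


   def max_force(forces: List[Vec3]) -> float:
       """Largest per-atom force magnitude, the quantity compared with force_threshold."""
       return max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in forces)


   def run_extra_relaxation(
       cfg: GuidanceSettings,
       compute_forces: Callable[[], List[Vec3]],
       relax_step: Callable[[], None],
   ) -> int:
       """Run post-diffusion relaxation steps and return how many were taken."""
       if cfg.guidance <= 0.0:
           return 0  # extra relaxation only applies when guidance is enabled
       steps = 0
       while steps < cfg.max_extra_steps and max_force(compute_forces()) > cfg.force_threshold:
           relax_step()
           steps += 1
       return steps

With the default ``max_extra_steps = 0``, this policy performs no extra relaxation even when guidance is enabled. Relaxation has to be requested explicitly.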
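
The quasi-Newton step behind :class:`LBFGSStepSizer` can be illustrated with the textbook L-BFGS two-loop recursion. The sketch below is not the module's implementation and makes several simplifying assumptions:

- It works on a flattened list of coordinates rather than a B×N×3 tensor.
- It treats the forces as the negative energy gradient.
- It falls back to ``initial_step * forces`` until a curvature pair is available.
- It uses the common scaling ``s·y / y·y`` for the initial inverse Hessian.

``TwoLoopLBFGS`` is a hypothetical name. Only its attribute names mirror those listed above.

.. code-block:: python

   from typing import List, Optional


   def _dot(a: List[float], b: List[float]) -> float:
       return sum(x * y for x, y in zip(a, b))


   class TwoLoopLBFGS:
       """Hypothetical L-BFGS step sizer on flattened coordinates."""

       def __init__(self, memory_size: int = 10, initial_step: float = 0.1) -> None:
           self.memory_size = memory_size
           self.initial_step = initial_step
           self.reset()

       def reset(self) -> None:
           # Clear the stored curvature pairs and the previous iterate.
           self.s_list: List[List[float]] = []
           self.y_list: List[List[float]] = []
           self.rho_list: List[float] = []
           self.prev_pos: Optional[List[float]] = None
           self.prev_forces: Optional[List[float]] = None
           self.H0_scaling = 1.0

       def compute_step(self, pos: List[float], forces: List[float]) -> List[float]:
           if self.prev_pos is not None and self.prev_forces is not None:
               s = [p - q for p, q in zip(pos, self.prev_pos)]
               # Gradient difference g_k - g_{k-1}, with g = -forces.
               y = [fp - f for fp, f in zip(self.prev_forces, forces)]
               sy = _dot(s, y)
               if sy > 1e-12:  # keep the pair only if the curvature condition holds
                   self.s_list.append(s)
                   self.y_list.append(y)
                   self.rho_list.append(1.0 / sy)
                   self.H0_scaling = sy / _dot(y, y)
                   if len(self.s_list) > self.memory_size:  # drop the oldest pair
                       self.s_list.pop(0)
                       self.y_list.pop(0)
                       self.rho_list.pop(0)
           self.prev_pos = list(pos)
           self.prev_forces = list(forces)

           if not self.s_list:
               # No curvature information yet: scaled steepest descent along the forces.
               return [self.initial_step * f for f in forces]

           q = [-f for f in forces]  # current gradient
           alphas = []
           for s, y, rho in zip(reversed(self.s_list), reversed(self.y_list), reversed(self.rho_list)):
               a = rho * _dot(s, q)
               alphas.append(a)
               q = [qi - a * yi for qi, yi in zip(q, y)]
           r = [self.H0_scaling * qi for qi in q]
           for s, y, rho, a in zip(self.s_list, self.y_list, self.rho_list, reversed(alphas)):
               b = rho * _dot(y, r)
               r = [ri + (a - b) * si for ri, si in zip(r, s)]
           # r approximates H^{-1} g, so the step moves against the gradient.
           return [-ri for ri in r]

On a one-dimensional quadratic energy, the second call already lands on the minimum. A single curvature pair recovers the exact inverse Hessian in that case.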
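
The per-graph dispatch performed by :class:`BatchedLBFGSStepSizer` can be sketched in plain Python. The same sketch shows one way :func:`force_field_guidance_step` and :func:`post_diffusion_relaxation_step` might combine the dispatched step with ``scale``, ``zeta`` and ``max_step_size``. Several details are assumptions that the reference above does not state:

- The per-graph sizer is a hypothetical steepest-descent stand-in rather than L-BFGS.
- ``max_step_size`` caps each atom's displacement norm before the time weight ``scale * (1 - t)**zeta`` is applied.
- ``t`` is a per-graph time in [0, 1].
- The relaxation step builds its sizer from the number of graphs when none is passed.

All names in the block are hypothetical.

.. code-block:: python

   import math
   from typing import List, Optional, Sequence, Tuple

   Vec3 = Tuple[float, float, float]


   class SteepestDescentSizer:
       """Hypothetical per-graph stand-in for LBFGSStepSizer: step = initial_step * forces."""

       def __init__(self, initial_step: float = 0.1) -> None:
           self.initial_step = initial_step

       def compute_step(self, pos: List[Vec3], forces: List[Vec3]) -> List[Vec3]:
           return [tuple(self.initial_step * c for c in f) for f in forces]

       def reset(self) -> None:
           pass  # stateless, nothing to clear


   class BatchedSizer:
       """Keeps one sizer per graph and routes atoms to it by batch index."""

       def __init__(self, batch_size: int, initial_step: float = 0.1) -> None:
           self.step_sizers = [SteepestDescentSizer(initial_step) for _ in range(batch_size)]

       def compute_step(self, pos: List[Vec3], forces: List[Vec3], batch_idx: Sequence[int]) -> List[Vec3]:
           step: List[Vec3] = [(0.0, 0.0, 0.0)] * len(pos)
           for g, sizer in enumerate(self.step_sizers):
               atoms = [i for i, b in enumerate(batch_idx) if b == g]
               if not atoms:
                   continue
               sub = sizer.compute_step([pos[i] for i in atoms], [forces[i] for i in atoms])
               for i, s in zip(atoms, sub):  # scatter this graph's steps back into place
                   step[i] = s
           return step

       def reset(self) -> None:
           for sizer in self.step_sizers:
               sizer.reset()


   def clip_per_atom(step: List[Vec3], max_step_size: float) -> List[Vec3]:
       """Rescale any atom whose displacement norm exceeds max_step_size."""
       out = []
       for v in step:
           norm = math.sqrt(sum(c * c for c in v))
           factor = 1.0 if norm <= max_step_size else max_step_size / norm
           out.append(tuple(factor * c for c in v))
       return out


   def guidance_update(pos: List[Vec3], forces: List[Vec3], batch_idx: Sequence[int],
                       t: Sequence[float], sizer: BatchedSizer, scale: float,
                       zeta: float = 3.0, max_step_size: float = 0.1) -> List[Vec3]:
       """One guided update: clipped step weighted by scale * (1 - t[graph])**zeta."""
       step = clip_per_atom(sizer.compute_step(pos, forces, batch_idx), max_step_size)
       return [
           tuple(p + scale * (1.0 - t[b]) ** zeta * s for p, s in zip(pv, sv))
           for pv, sv, b in zip(pos, step, batch_idx)
       ]


   def relaxation_update(pos: List[Vec3], forces: List[Vec3], batch_idx: Sequence[int],
                         sizer: Optional[BatchedSizer] = None, scale: float = 0.1) -> List[Vec3]:
       """Force-only update; builds a sizer from the number of graphs if none is given."""
       if sizer is None:
           sizer = BatchedSizer(batch_size=max(batch_idx) + 1)
       step = sizer.compute_step(pos, forces, batch_idx)
       return [tuple(p + scale * s for p, s in zip(pv, sv)) for pv, sv in zip(pos, step)]

Routing by ``batch_idx`` keeps each graph's optimiser state separate. This matters for L-BFGS, because curvature pairs collected on one structure carry no information about another.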