agedi.diffusion.noisers.sde
===========================

.. py:module:: agedi.diffusion.noisers.sde


Classes
-------

.. autoapisummary::

   agedi.diffusion.noisers.sde.SDENoiser


Module Contents
---------------

.. py:class:: SDENoiser(sde_class: agedi.diffusion.sdes.SDE, sde_kwargs: Optional[Dict], distribution: agedi.diffusion.distributions.Distribution, prior: agedi.diffusion.distributions.Distribution, sde: Optional[agedi.diffusion.sdes.SDE] = None, **kwargs)

   Bases: :py:obj:`agedi.diffusion.noisers.Noiser`, :py:obj:`abc.ABC`


   Base class for SDE-based noisers, intended to be inherited by concrete noisers.

   :param sde_class: The class of the SDE used for the noising.
   :type sde_class: SDE
   :param sde_kwargs: Keyword arguments passed to *sde_class* when the SDE is instantiated.
   :type sde_kwargs: Dict, optional
   :param distribution: The distribution used for the noise.
   :type distribution: Distribution
   :param prior: The prior distribution used for the noise.
   :type prior: Distribution
   :param sde: An already-instantiated SDE object. When provided, *sde_class* and *sde_kwargs* are ignored.
   :type sde: SDE, optional
   :param key: The key used for the noising. It is also the prefix of the ``key + "_noise"`` and ``key + "_score"`` entries described below.
   :type key: str
   :param \*\*kwargs: Additional keyword arguments passed to the :py:obj:`~agedi.diffusion.noisers.Noiser` base class.
   :returns: The noiser for the atomic positions in Cartesian coordinates.
   :rtype: Noiser


   .. py:attribute:: _key
      :value: None


   .. py:method:: get_hparams() -> Dict

      Return the hyperparameters of this SDE noiser.


   .. py:method:: postprocess_score(score: torch.Tensor) -> torch.Tensor
      :abstractmethod:


      Post-process the predicted score before computing the loss.

      :param score: Raw predicted score tensor.
      :type score: torch.Tensor
      :returns: Post-processed score tensor.
      :rtype: torch.Tensor


   .. py:method:: postprocess_noise(noise: torch.Tensor) -> torch.Tensor
      :abstractmethod:


      Post-process the noise tensor before computing the loss.

      :param noise: Raw noise tensor.
      :type noise: torch.Tensor
      :returns: Post-processed noise tensor.
      :rtype: torch.Tensor


   .. py:method:: _noise(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Add noise to the atomistic structure.

      The added noise is stored under ``self.key + "_noise"``.

      :param batch: The atomistic structure (or batch thereof) to be noised.
      :type batch: AtomsGraph
      :returns: The noised atomistic structure (or batch thereof).
      :rtype: AtomsGraph


   .. py:method:: _denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) -> agedi.data.AtomsGraph

      Denoise the positions of the atomistic structure.

      The denoising follows the Euler-Maruyama scheme:

      .. math::

         R_{i+1} = R_i + \Delta t \left( f(R_i, t) + g(t)^2 \, s(R_i, t) \right) + \sqrt{\Delta t} \, g(t) \, w

      where :math:`f` is the drift and :math:`g` the diffusion coefficient of the SDE, :math:`s` is the score, and :math:`w \sim \mathcal{N}(0, I)`.

      The score is expected to be stored under ``self.key + "_score"``.

      :param batch: The atomistic structure (or batch thereof) to be denoised.
      :type batch: AtomsGraph
      :param delta_t: The time step for the denoising.
      :type delta_t: float
      :param last: Whether this is the last step of the denoising.
      :type last: bool
      :returns: The denoised atomistic structure (or batch thereof).
      :rtype: AtomsGraph


   .. py:method:: _loss(batch: agedi.data.AtomsGraph) -> torch.Tensor

      Compute the loss of the SDE noiser.

      Expects the total added noise to be stored under ``self.key + "_noise"`` and the predicted score under ``self.key + "_score"``. The loss is computed as

      .. math::

         L = \sum_i \left\| \sigma_t w_i + \sigma_t^2 \, s(R_i) \right\|^2

      where :math:`\sigma_t` is the noise scale at time :math:`t`, :math:`w_i` the sampled noise and :math:`s(R_i)` the predicted score.

      :param batch: The atomistic structure (or batch thereof) to be noised and denoised.
      :type batch: AtomsGraph
      :returns: The loss of the noised and denoised atomistic structure.
      :rtype: torch.Tensor
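
   The Euler-Maruyama update given for ``_denoise`` can be sketched in isolation as follows. This is a minimal sketch, not the agedi implementation: the helper ``euler_maruyama_step`` and its ``drift`` and ``diffusion`` callables are hypothetical stand-ins for whatever interface :py:obj:`agedi.diffusion.sdes.SDE` actually exposes, the update transcribes the ``_denoise`` formula term by term, and skipping the noise term when ``last`` is true is an assumption.

   .. code-block:: python

      import math
      from typing import Callable, Optional

      import torch


      def euler_maruyama_step(
          r: torch.Tensor,
          t: float,
          delta_t: float,
          drift: Callable[[torch.Tensor, float], torch.Tensor],  # hypothetical f(R, t)
          diffusion: Callable[[float], float],  # hypothetical g(t)
          score: torch.Tensor,  # predicted s(R, t), same shape as r
          last: bool = False,
          generator: Optional[torch.Generator] = None,
      ) -> torch.Tensor:
          """R_{i+1} = R_i + dt * (f(R_i, t) + g(t)**2 * s(R_i, t)) + sqrt(dt) * g(t) * w."""
          g = diffusion(t)
          # Deterministic part of the update.
          mean = r + delta_t * (drift(r, t) + g**2 * score)
          if last:
              # Assumption: the final step returns the mean without adding fresh noise.
              return mean
          # w ~ N(0, I), drawn with the same shape, dtype and device as the positions.
          w = torch.randn(r.shape, generator=generator, dtype=r.dtype, device=r.device)
          return mean + math.sqrt(delta_t) * g * w


      # Usage: zero drift, unit diffusion and score -R give a noise-free step of (1 - dt) * R.
      positions = torch.randn(4, 3)
      stepped = euler_maruyama_step(
          positions,
          t=0.5,
          delta_t=0.1,
          drift=lambda x, t: torch.zeros_like(x),
          diffusion=lambda t: 1.0,
          score=-positions,
          last=True,
      )  # close to 0.9 * positions

   Keeping the step free of any batch object makes the arithmetic easy to check by hand; a real noiser would read the positions and the score from the batch instead.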
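
   The form of the ``_loss`` objective can be checked with a minimal sketch under one explicit assumption: the forward noising is taken to be the variance-exploding perturbation :math:`R_t = R_0 + \sigma_t w` with :math:`w \sim \mathcal{N}(0, I)`, which ``_noise`` is not documented to use. Under that assumption the exact conditional score is :math:`-w / \sigma_t`, so the residual :math:`\sigma_t w + \sigma_t^2 s` vanishes for a perfect score model. The helpers ``perturb`` and ``sde_noiser_loss`` are hypothetical names, not agedi functions.

   .. code-block:: python

      from typing import Optional, Tuple

      import torch


      def perturb(
          r0: torch.Tensor, sigma_t: float, generator: Optional[torch.Generator] = None
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Assumed forward noising R_t = R_0 + sigma_t * w; returns (R_t, w)."""
          w = torch.randn(r0.shape, generator=generator, dtype=r0.dtype, device=r0.device)
          return r0 + sigma_t * w, w


      def sde_noiser_loss(w: torch.Tensor, score: torch.Tensor, sigma_t: float) -> torch.Tensor:
          """L = sum_i || sigma_t * w_i + sigma_t**2 * s(R_i) ||^2, summed over all entries."""
          residual = sigma_t * w + sigma_t**2 * score
          return residual.pow(2).sum()


      gen = torch.Generator().manual_seed(0)
      r0 = torch.zeros(5, 3)
      sigma = 0.5
      r_t, w = perturb(r0, sigma, generator=gen)

      # The exact score of N(R_0, sigma^2 I) at R_t is -w / sigma, so this loss is ~0.
      exact = sde_noiser_loss(w, -w / sigma, sigma)
      # A model that always predicts a zero score pays sigma^2 * ||w||^2.
      naive = sde_noiser_loss(w, torch.zeros_like(w), sigma)

   Summing over every coordinate mirrors the sum over :math:`i` in the formula; whether the actual implementation sums or averages over a batch is not stated in this section.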