agedi.diffusion.sdes.vp
=======================

.. py:module:: agedi.diffusion.sdes.vp


Classes
-------

.. autoapisummary::

   agedi.diffusion.sdes.vp.VP


Module Contents
---------------

.. py:class:: VP(beta_min: float = 0.01, beta_max: float = 3, **kwargs)

   Bases: :py:obj:`agedi.diffusion.sdes.SDE`


   Implements the variance-preserving (VP) SDE.

   :param beta_min: The minimum value of the beta parameter.
   :type beta_min: float
   :param beta_max: The maximum value of the beta parameter.
   :type beta_max: float

   :rtype: VP


   .. py:attribute:: beta_min
      :value: 0.01


   .. py:attribute:: beta_max
      :value: 3


   .. py:attribute:: noise_schedule


   .. py:method:: get_hparams() -> Dict

      Return the hyperparameters of this VP SDE.



   .. py:method:: drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor

      Implement the VP drift term.

      Defines the drift term of the SDE: f(x, t).

      :param x: The positions of the atoms.
      :type x: torch.Tensor
      :param t: The time at which to calculate the drift term.
      :type t: torch.Tensor

      :returns: **drift** -- The drift term of the SDE.
      :rtype: torch.Tensor



   .. py:method:: diffusion(t: torch.Tensor) -> torch.Tensor

      Implement the VP diffusion term.

      Defines the diffusion term of the SDE: g(t).

      :param t: The time at which to calculate the diffusion term.
      :type t: torch.Tensor

      :returns: **diffusion** -- The diffusion term of the SDE.
      :rtype: torch.Tensor



   .. py:method:: mean(t: torch.Tensor) -> torch.Tensor

      Implement the VP mean term.

      Calculates the mean of the transition kernel at time t: mu(t).

      :param t: The time at which to calculate the mean.
      :type t: torch.Tensor

      :returns: **mean** -- The mean of the diffusion process.
      :rtype: torch.Tensor



   .. py:method:: var(t: torch.Tensor) -> torch.Tensor

      Implement the VP variance term.

      Calculates the variance of the transition kernel at time t: sigma^2(t).

      :param t: The time at which to calculate the variance.
      :type t: torch.Tensor

      :returns: **var** -- The variance of the diffusion process.
      :rtype: torch.Tensor



   .. py:method:: beta(t: torch.Tensor) -> torch.Tensor

      VP beta function.

      Calculates the value of beta at time t.

      :param t: The time at which to calculate beta.
      :type t: torch.Tensor

      :returns: **beta** -- The value of beta at time t.
      :rtype: torch.Tensor



   .. py:method:: alpha(t: torch.Tensor) -> torch.Tensor

      VP alpha function.

      Calculates the value of alpha at time t with

      .. math::

         \alpha(t) = \int_{0}^{t} \beta(s) \, ds.

      :param t: The time at which to calculate alpha.
      :type t: torch.Tensor

      :returns: **alpha** -- The value of alpha at time t.
      :rtype: torch.Tensor
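

The methods documented above can be related to one another with a small scalar sketch. It is not the ``agedi`` implementation: the class name ``LinearVPSketch`` is hypothetical, it works on plain floats rather than ``torch.Tensor``, and it assumes a linear schedule for beta on ``t`` in ``[0, 1]``, which this page does not confirm (the actual schedule is held in ``noise_schedule``). The drift, diffusion, mean, and variance expressions are the textbook VP forms, given here only as an illustration of how they follow from ``beta`` and ``alpha``.

```python
import math


class LinearVPSketch:
    """Scalar sketch of VP SDE quantities (hypothetical, not the agedi class)."""

    def __init__(self, beta_min: float = 0.01, beta_max: float = 3.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: float) -> float:
        # Assumption: beta rises linearly from beta_min at t=0 to beta_max at t=1.
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def alpha(self, t: float) -> float:
        # alpha(t) = integral of beta(s) from 0 to t, in closed form
        # for the assumed linear schedule.
        return self.beta_min * t + 0.5 * t**2 * (self.beta_max - self.beta_min)

    def drift(self, x: float, t: float) -> float:
        # Textbook VP drift: f(x, t) = -0.5 * beta(t) * x.
        return -0.5 * self.beta(t) * x

    def diffusion(self, t: float) -> float:
        # Textbook VP diffusion: g(t) = sqrt(beta(t)).
        return math.sqrt(self.beta(t))

    def mean(self, t: float) -> float:
        # Scaling applied to the initial state: mu(t) = exp(-0.5 * alpha(t)).
        return math.exp(-0.5 * self.alpha(t))

    def var(self, t: float) -> float:
        # Kernel variance: sigma^2(t) = 1 - exp(-alpha(t)).
        return 1.0 - math.exp(-self.alpha(t))


if __name__ == "__main__":
    sde = LinearVPSketch()
    for t in (0.0, 0.5, 1.0):
        # mean(t)**2 + var(t) stays at 1, which is the "variance-preserving" property.
        print(t, sde.beta(t), sde.alpha(t), sde.mean(t) ** 2 + sde.var(t))
```

Under these assumptions ``mean(t)**2 + var(t)`` equals 1 for every ``t``, so data with unit variance keeps unit variance throughout the forward process.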