agedi.diffusion.sdes.ve
=======================

.. py:module:: agedi.diffusion.sdes.ve


Classes
-------

.. autoapisummary::

   agedi.diffusion.sdes.ve.VE


Module Contents
---------------

.. py:class:: VE(sigma_min: float = 0.01, sigma_max: float = 1.0, **kwargs)

   Bases: :py:obj:`agedi.diffusion.sdes.SDE`


   Implements the variance-exploding (VE) SDE.

   :param sigma_min: The minimum value of the sigma parameter.
   :type sigma_min: float
   :param sigma_max: The maximum value of the sigma parameter.
   :type sigma_max: float

   :rtype: VE


   .. py:attribute:: sigma_min
      :value: 0.01


   .. py:attribute:: sigma_max
      :value: 1.0


   .. py:attribute:: noise_schedule


   .. py:method:: get_hparams() -> Dict

      Return the hyperparameters of this VE SDE.


   .. py:method:: drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor

      Implement the VE drift term.

      Defines the drift term of the SDE, :math:`f(x, t)`.

      :param x: The positions of the atoms.
      :type x: torch.Tensor
      :param t: The time at which to calculate the drift term.
      :type t: torch.Tensor

      :returns: **drift** -- The drift term of the SDE.
      :rtype: torch.Tensor


   .. py:method:: diffusion(t: torch.Tensor) -> torch.Tensor

      Implement the VE diffusion term.

      Defines the diffusion term of the SDE, :math:`g(t)`.

      :param t: The time at which to calculate the diffusion term.
      :type t: torch.Tensor

      :returns: **diffusion** -- The diffusion term of the SDE.
      :rtype: torch.Tensor


   .. py:method:: mean(t: torch.Tensor) -> torch.Tensor

      Implement the VE mean term.

      Calculates the mean of the transition kernel at time t, :math:`\mu(t)`.

      :param t: The time at which to calculate the mean.
      :type t: torch.Tensor

      :returns: **mean** -- The mean of the diffusion process.
      :rtype: torch.Tensor


   .. py:method:: var(t: torch.Tensor) -> torch.Tensor

      Implement the VE variance term.

      Calculates the variance of the transition kernel at time t, :math:`\sigma^2(t)`.

      :param t: The time at which to calculate the variance.
      :type t: torch.Tensor

      :returns: **var** -- The variance of the diffusion process.
      :rtype: torch.Tensor


   .. py:method:: sigma(t: torch.Tensor) -> torch.Tensor

      Compute the VE sigma function.

      Calculates the value of sigma at time t, :math:`\sigma(t)`.

      :param t: The time at which to calculate sigma.
      :type t: torch.Tensor

      :returns: **sigma** -- The value of sigma at time t.
      :rtype: torch.Tensor
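

The methods above can be illustrated with a minimal, self-contained sketch. It assumes the geometric noise schedule of the VE SDE from Song et al. (2021), :math:`\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t` for :math:`t \in [0, 1]`; the agedi source does not confirm this schedule (``VE`` exposes a separate ``noise_schedule`` attribute), so treat it as an assumption. ``VESketch`` and ``perturb`` are hypothetical names, not part of the agedi API, and plain Python floats stand in for ``torch.Tensor`` so the sketch runs without PyTorch.

.. code-block:: python

   import math
   import random


   class VESketch:
       """Hypothetical stand-alone VE SDE sketch (not the agedi implementation).

       Assumes the geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t.
       """

       def __init__(self, sigma_min: float = 0.01, sigma_max: float = 1.0) -> None:
           self.sigma_min = sigma_min
           self.sigma_max = sigma_max

       def sigma(self, t: float) -> float:
           # Geometric interpolation: sigma_min at t = 0, sigma_max at t = 1.
           return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

       def drift(self, x: float, t: float) -> float:
           # A VE SDE has zero drift: dx = g(t) dW.
           return 0.0

       def diffusion(self, t: float) -> float:
           # g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min)),
           # chosen so that g(t)**2 equals d/dt sigma(t)**2.
           return self.sigma(t) * math.sqrt(2.0 * math.log(self.sigma_max / self.sigma_min))

       def mean(self, t: float) -> float:
           # Assumption: mean(t) is the factor multiplying x_0; a VE SDE leaves x_0 unscaled.
           return 1.0

       def var(self, t: float) -> float:
           # Variance of the transition kernel, sigma^2(t), as stated in the docs above.
           return self.sigma(t) ** 2

       def perturb(self, x0: float, t: float, rng: random.Random) -> float:
           # Sample x_t = mean(t) * x0 + sqrt(var(t)) * z with z ~ N(0, 1).
           return self.mean(t) * x0 + math.sqrt(self.var(t)) * rng.gauss(0.0, 1.0)

Because the drift is zero, the forward process only adds noise: ``perturb`` returns ``x0 + sigma(t) * z``, and the variance grows from ``sigma_min**2`` at ``t = 0`` to ``sigma_max**2`` at ``t = 1``. Choosing :math:`g^2(t) = \frac{d}{dt}\sigma^2(t)` is what keeps ``diffusion`` consistent with ``var``.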