agedi.diffusion.sdes.base
=========================

.. py:module:: agedi.diffusion.sdes.base


Classes
-------

.. autoapisummary::

   agedi.diffusion.sdes.base.SDE


Module Contents
---------------

.. py:class:: SDE(noise_schedule: agedi.diffusion.sdes.noise_schedules.NoiseSchedule = Linear)

   Bases: :py:obj:`abc.ABC`

   Abstract base class for stochastic differential equations (SDEs).


   .. py:attribute:: noise_schedule_cls


   .. py:method:: get_hparams() -> Dict

      Return the hyperparameters needed to reconstruct this SDE.

      Returns a dictionary with a ``_target_`` key holding the fully qualified class name. Subclasses should call ``super().get_hparams()`` and merge in their own constructor parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Drift term of the SDE. Must be implemented by subclasses.

      Defines the drift term of the SDE:

      .. math::

         f(\mathbf{x}, t) = \ldots

      :param x: The positions of the atoms.
      :type x: torch.Tensor
      :param t: The time at which to calculate the drift term.
      :type t: torch.Tensor

      :returns: **drift** -- The drift term of the SDE.
      :rtype: torch.Tensor


   .. py:method:: diffusion(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Diffusion term of the SDE. Must be implemented by subclasses.

      Defines the diffusion term of the SDE:

      .. math::

         g(t) = \ldots

      :param t: The time at which to calculate the diffusion term.
      :type t: torch.Tensor

      :returns: **diffusion** -- The diffusion term of the SDE.
      :rtype: torch.Tensor


   .. py:method:: mean(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Mean of the SDE. Must be implemented by subclasses.

      Calculates the mean :math:`\mu_t` of the transition kernel at time :math:`t`:

      .. math::

         \mu_t = \ldots

      :param t: The time at which to calculate the mean.
      :type t: torch.Tensor

      :returns: **mean** -- The mean of the diffusion process.
      :rtype: torch.Tensor


   .. py:method:: var(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Variance of the SDE. Must be implemented by subclasses.

      Calculates the variance :math:`\sigma_t^2` of the transition kernel at time :math:`t`:

      .. math::

         \sigma_t^2 = \ldots

      :param t: The time at which to calculate the variance.
      :type t: torch.Tensor

      :returns: **var** -- The variance of the diffusion process.
      :rtype: torch.Tensor


   .. py:method:: transition_kernel(x: torch.Tensor, t: torch.Tensor, w: Callable) -> torch.Tensor

      Transition kernel of the SDE.

      Draws a sample from the Gaussian transition kernel :math:`p(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\mu_t \mathbf{x}_0, \sigma_t^2 \mathbf{I})` of the diffusion process:

      .. math::

         \mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w},

      with :math:`\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})`.

      :param x: The positions of the atoms, :math:`\mathbf{x}_0`.
      :type x: torch.Tensor
      :param t: The time at which to calculate the transition kernel.
      :type t: torch.Tensor
      :param w: The noise term.
      :type w: torch.Tensor

      :returns: **transition_kernel** -- A sample :math:`\mathbf{x}_t` from the transition kernel of the diffusion process.
      :rtype: torch.Tensor


   .. py:method:: noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor

      Noise term of the SDE.

      Calculates the noise term of the SDE by inverting the transition kernel:

      .. math::

         \mathbf{w} = \frac{\mathbf{x}_t - \mu_t \mathbf{x}_0}{\sigma_t}

      :param x0: The value of :math:`\mathbf{x}` at time 0.
      :type x0: torch.Tensor
      :param xt: The value of :math:`\mathbf{x}` at time :math:`t`.
      :type xt: torch.Tensor
      :param t: The time at which to calculate the noise term.
      :type t: torch.Tensor

      :returns: **noise** -- The noise term of the diffusion process.
      :rtype: torch.Tensor
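
The sketch below shows one way a concrete subclass might fill in the abstract methods, assuming a variance-preserving SDE with a linear :math:`\beta(t)` schedule (``beta_min=0.1`` and ``beta_max=20.0`` are illustrative values, not taken from agedi). ``SDESketch`` and ``VPSDE`` are hypothetical stand-ins that mirror the documented interface instead of importing agedi, and ``w`` is assumed to be a pre-drawn Gaussian tensor. The usage lines at the end check that ``noise`` inverts ``transition_kernel`` and that ``get_hparams`` merges subclass parameters into the ``_target_`` dictionary.

.. code-block:: python

   import abc

   import torch


   class SDESketch(abc.ABC):
       """Hypothetical stand-in for the documented SDE interface (not agedi's class)."""

       def get_hparams(self) -> dict:
           # ``_target_`` holds the fully qualified class name.
           cls = type(self)
           return {"_target_": f"{cls.__module__}.{cls.__qualname__}"}

       @abc.abstractmethod
       def drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: ...

       @abc.abstractmethod
       def diffusion(self, t: torch.Tensor) -> torch.Tensor: ...

       @abc.abstractmethod
       def mean(self, t: torch.Tensor) -> torch.Tensor: ...

       @abc.abstractmethod
       def var(self, t: torch.Tensor) -> torch.Tensor: ...

       def transition_kernel(self, x: torch.Tensor, t: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
           # Assumption: w is a pre-drawn N(0, I) tensor, so x_t = mu_t * x_0 + sigma_t * w.
           return self.mean(t) * x + torch.sqrt(self.var(t)) * w

       def noise(self, x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
           # Invert the kernel: w = (x_t - mu_t * x_0) / sigma_t (requires var(t) > 0).
           return (xt - self.mean(t) * x0) / torch.sqrt(self.var(t))


   class VPSDE(SDESketch):
       """Variance-preserving SDE with a linear beta(t) schedule (illustrative choice)."""

       def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
           self.beta_min = beta_min
           self.beta_max = beta_max

       def get_hparams(self) -> dict:
           # Merge constructor parameters into the base-class dictionary.
           hparams = super().get_hparams()
           hparams.update(beta_min=self.beta_min, beta_max=self.beta_max)
           return hparams

       def _beta(self, t: torch.Tensor) -> torch.Tensor:
           return self.beta_min + t * (self.beta_max - self.beta_min)

       def _int_beta(self, t: torch.Tensor) -> torch.Tensor:
           # Closed-form integral of beta(s) from 0 to t.
           return self.beta_min * t + 0.5 * t**2 * (self.beta_max - self.beta_min)

       def drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
           return -0.5 * self._beta(t) * x  # f(x, t)

       def diffusion(self, t: torch.Tensor) -> torch.Tensor:
           return torch.sqrt(self._beta(t))  # g(t)

       def mean(self, t: torch.Tensor) -> torch.Tensor:
           return torch.exp(-0.5 * self._int_beta(t))  # mu_t

       def var(self, t: torch.Tensor) -> torch.Tensor:
           return 1.0 - torch.exp(-self._int_beta(t))  # sigma_t^2 = 1 - mu_t^2


   torch.manual_seed(0)
   sde = VPSDE()
   x0 = torch.randn(5, 3)  # 5 atoms with 3D positions
   t = torch.full((5, 1), 0.5)  # one time per atom, broadcast over x, y, z
   w = torch.randn_like(x0)
   xt = sde.transition_kernel(x0, t, w)
   print(torch.allclose(sde.noise(x0, xt, t), w, atol=1e-5))  # True
   print(sde.get_hparams())

For this schedule :math:`\sigma_t^2 = 0` at :math:`t = 0`, so ``noise`` would divide by zero there; drawing :math:`t` from a small positive minimum avoids the singularity.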