agedi.diffusion.sdes.noise_schedules
====================================

.. py:module:: agedi.diffusion.sdes.noise_schedules


Classes
-------

.. autoapisummary::

   agedi.diffusion.sdes.noise_schedules.NoiseSchedule
   agedi.diffusion.sdes.noise_schedules.Linear
   agedi.diffusion.sdes.noise_schedules.Exponential
   agedi.diffusion.sdes.noise_schedules.Cosine


Module Contents
---------------

.. py:class:: NoiseSchedule(min: float, max: float)

   Bases: :py:obj:`abc.ABC`


   Abstract base class for diffusion noise schedules.

   A noise schedule defines a function ``f(t)`` that controls the noise level
   during the forward diffusion process, where ``t ∈ [0, 1]``.


   .. py:attribute:: min


   .. py:attribute:: max


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this noise schedule.

      Returns a dictionary with a ``_target_`` key plus ``min`` and ``max``.
      Subclasses can call ``super().get_hparams()`` and add their own
      parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict



   .. py:method:: f(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:


      Return the noise schedule value at time *t*.



   .. py:method:: fprime(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:


      Return the derivative of the noise schedule at time *t*.



   .. py:method:: fint(t: torch.Tensor) -> torch.Tensor
      :abstractmethod:


      Return the integral of the noise schedule from 0 to *t*.



   .. py:method:: df2dt(t: torch.Tensor) -> torch.Tensor

      Return the time derivative of ``f(t)²`` at time *t*.

      Computed as ``2 * f(t) * f'(t)``.



.. py:class:: Linear(min: float, max: float)

   Bases: :py:obj:`NoiseSchedule`


   Linear noise schedule: ``f(t) = min + (max - min) * t``.


   .. py:method:: f(t: torch.Tensor) -> torch.Tensor

      Evaluate the noise schedule at time *t*.



   .. py:method:: fprime(t: torch.Tensor) -> torch.Tensor

      Return the derivative of the noise schedule at time *t*.



   .. py:method:: fint(t: torch.Tensor) -> torch.Tensor

      Return the integral of the noise schedule from 0 to *t*.



.. py:class:: Exponential(min: float, max: float)

   Bases: :py:obj:`NoiseSchedule`


   Exponential noise schedule: ``f(t) = min * (max/min)^t``.


   .. py:method:: f(t: torch.Tensor) -> torch.Tensor

      Evaluate the noise schedule at time *t*.



   .. py:method:: fprime(t: torch.Tensor) -> torch.Tensor

      Return the derivative of the noise schedule at time *t*.



   .. py:method:: fint(t: torch.Tensor) -> torch.Tensor

      Return the integral of the noise schedule from 0 to *t*.



.. py:class:: Cosine(min: float, max: float)

   Bases: :py:obj:`NoiseSchedule`


   Cosine noise schedule: ``f(t) = min + (max - min) * (1 - cos(πt)) / 2``.


   .. py:method:: f(t: torch.Tensor) -> torch.Tensor

      Evaluate the noise schedule at time *t*.



   .. py:method:: fprime(t: torch.Tensor) -> torch.Tensor

      Return the derivative of the noise schedule at time *t*.



   .. py:method:: fint(t: torch.Tensor) -> torch.Tensor

      Return the integral of the noise schedule from 0 to *t*.
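
The closed forms implied by the three documented formulas can be written out by hand. The sketch below is not the ``agedi`` implementation: it uses plain Python floats instead of ``torch.Tensor``, and ``Schedule``, ``linear``, ``exponential``, ``cosine``, ``df2dt`` and ``check`` are hypothetical helper names introduced only for illustration. The derivatives and integrals are derived from the docstring formulas, assuming ``fint`` is the integral of ``f`` from 0 to ``t`` and that ``min > 0`` for the exponential schedule.

```python
import math
from typing import Callable, NamedTuple


class Schedule(NamedTuple):
    """Hypothetical container: f, its derivative, and its integral from 0 to t."""
    f: Callable[[float], float]
    fprime: Callable[[float], float]
    fint: Callable[[float], float]


def linear(lo: float, hi: float) -> Schedule:
    # f(t) = lo + (hi - lo) * t
    return Schedule(
        f=lambda t: lo + (hi - lo) * t,
        fprime=lambda t: hi - lo,
        fint=lambda t: lo * t + 0.5 * (hi - lo) * t * t,
    )


def exponential(lo: float, hi: float) -> Schedule:
    # f(t) = lo * (hi / lo) ** t, which requires lo > 0 and hi != lo
    k = math.log(hi / lo)
    return Schedule(
        f=lambda t: lo * math.exp(k * t),
        fprime=lambda t: k * lo * math.exp(k * t),
        fint=lambda t: lo * (math.exp(k * t) - 1.0) / k,
    )


def cosine(lo: float, hi: float) -> Schedule:
    # f(t) = lo + (hi - lo) * (1 - cos(pi * t)) / 2
    d = hi - lo
    return Schedule(
        f=lambda t: lo + d * (1.0 - math.cos(math.pi * t)) / 2.0,
        fprime=lambda t: d * math.pi * math.sin(math.pi * t) / 2.0,
        fint=lambda t: lo * t + d * (t - math.sin(math.pi * t) / math.pi) / 2.0,
    )


def df2dt(s: Schedule, t: float) -> float:
    # Time derivative of f(t)^2, computed as 2 * f(t) * f'(t)
    return 2.0 * s.f(t) * s.fprime(t)


def check(s: Schedule, t: float = 0.3, h: float = 1e-5) -> float:
    """Return the largest mismatch between closed forms and central differences."""
    num_fprime = (s.f(t + h) - s.f(t - h)) / (2.0 * h)
    num_f = (s.fint(t + h) - s.fint(t - h)) / (2.0 * h)  # d/dt fint(t) should equal f(t)
    return max(abs(num_fprime - s.fprime(t)), abs(num_f - s.f(t)))


if __name__ == "__main__":
    for name, make in [("linear", linear), ("exponential", exponential), ("cosine", cosine)]:
        s = make(0.1, 2.0)
        print(name, s.f(0.0), s.f(1.0), df2dt(s, 0.5), check(s))
```

All three schedules start at ``min`` for ``t = 0`` and reach ``max`` at ``t = 1``; they differ only in how the noise level moves between those endpoints. The finite-difference check is a quick way to confirm that ``fprime`` and ``fint`` agree with ``f`` when adding a new schedule.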