agedi.diffusion.sdes.noise_schedules

Classes

NoiseSchedule

Abstract base class for diffusion noise schedules.

Linear

Linear noise schedule: f(t) = min + (max - min) * t.

Exponential

Exponential noise schedule: f(t) = min * (max/min)^t.

Cosine

Cosine noise schedule: f(t) = min + (max - min) * (1 - cos(πt)) / 2.

Module Contents

class agedi.diffusion.sdes.noise_schedules.NoiseSchedule(min: float, max: float)

Bases: abc.ABC

Abstract base class for diffusion noise schedules.

A noise schedule defines a function f(t) that controls the noise level during the forward diffusion process, where t ∈ [0, 1].

min
max
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this noise schedule.

The dictionary contains a _target_ key plus min and max. Subclasses can extend it by calling super().get_hparams() and adding their own parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
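The extension pattern described for get_hparams() can be sketched as follows. This is a minimal illustration, not the agedi implementation: ToySchedule, ToyWarmup, the warmup parameter, and the rebuild helper are all hypothetical, and the exact string stored under _target_ (here a dotted class path) is an assumption, since the documentation only states that the key exists.

```python
from typing import Dict, Type


class ToySchedule:
    """Hypothetical stand-in for a schedule base class (not the agedi class)."""

    def __init__(self, min: float, max: float) -> None:
        self.min = min
        self.max = max

    def get_hparams(self) -> Dict:
        cls = type(self)
        # Assumption: _target_ holds a dotted path identifying the class.
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "min": self.min,
            "max": self.max,
        }


class ToyWarmup(ToySchedule):
    """Hypothetical subclass with one extra hyperparameter."""

    def __init__(self, min: float, max: float, warmup: float) -> None:
        super().__init__(min, max)
        self.warmup = warmup

    def get_hparams(self) -> Dict:
        # Start from the base dictionary, then add subclass-specific entries.
        hparams = super().get_hparams()
        hparams["warmup"] = self.warmup
        return hparams


def rebuild(hparams: Dict, registry: Dict[str, Type[ToySchedule]]) -> ToySchedule:
    """Reconstruct an object from its hyperparameters via a name registry."""
    params = dict(hparams)  # copy so the caller's dictionary is left intact
    class_name = params.pop("_target_").rsplit(".", 1)[-1]
    return registry[class_name](**params)


original = ToyWarmup(min=0.1, max=2.0, warmup=0.25)
hparams = original.get_hparams()
clone = rebuild(hparams, {"ToyWarmup": ToyWarmup})
```

The point of the round trip is that every constructor argument appears in the dictionary, so a subclass that adds an argument but forgets to add it in get_hparams() would fail to rebuild.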

abstractmethod f(t: torch.Tensor) → torch.Tensor

Return the noise schedule value at time t.

abstractmethod fprime(t: torch.Tensor) → torch.Tensor

Return the derivative of the noise schedule at time t.

abstractmethod fint(t: torch.Tensor) → torch.Tensor

Return the integral of the noise schedule from 0 to t.

df2dt(t: torch.Tensor) → torch.Tensor

Return the time derivative of f(t)² at time t.

Computed as 2 * f(t) * f'(t).
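The identity behind df2dt is the chain rule applied to f(t)². The sketch below checks it numerically for the linear formula documented on this page. It uses plain Python floats rather than torch.Tensor inputs, and the bounds F_MIN and F_MAX as well as the finite-difference helper are illustrative assumptions, not part of the library.

```python
# Plain-float sketch; the documented methods take torch.Tensor inputs.
F_MIN, F_MAX = 0.1, 2.0  # hypothetical schedule bounds


def f(t: float) -> float:
    """Linear schedule value, f(t) = min + (max - min) * t."""
    return F_MIN + (F_MAX - F_MIN) * t


def fprime(t: float) -> float:
    """Derivative of the linear schedule, constant in t."""
    return F_MAX - F_MIN


def df2dt(t: float) -> float:
    """Chain rule: d/dt [f(t)^2] = 2 * f(t) * f'(t)."""
    return 2.0 * f(t) * fprime(t)


def df2dt_numeric(t: float, h: float = 1e-6) -> float:
    """Central finite difference of f(t)^2, used only as a cross-check."""
    return (f(t + h) ** 2 - f(t - h) ** 2) / (2.0 * h)


t = 0.3
analytic = df2dt(t)
numeric = df2dt_numeric(t)
```

Because df2dt is built only from f and fprime, a concrete schedule gets it for free once those two methods are implemented.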

class agedi.diffusion.sdes.noise_schedules.Linear(min: float, max: float)

Bases: NoiseSchedule

Linear noise schedule: f(t) = min + (max - min) * t.

f(t: torch.Tensor) → torch.Tensor

Evaluate the noise schedule at time t.

fprime(t: torch.Tensor) → torch.Tensor

Return the derivative of the noise schedule at time t.

fint(t: torch.Tensor) → torch.Tensor

Return the integral of the noise schedule from 0 to t.
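For reference, the closed forms implied by the linear formula are f'(t) = max - min and ∫₀ᵗ f(s) ds = min * t + (max - min) * t² / 2. The sketch below writes them out with plain floats and compares the integral against a midpoint-rule sum. The function names, the lo/hi argument names, and the midpoint helper are illustrative assumptions; the library's own methods operate on torch.Tensor and may be implemented differently.

```python
from typing import Callable


def linear_f(t: float, lo: float, hi: float) -> float:
    """f(t) = min + (max - min) * t."""
    return lo + (hi - lo) * t


def linear_fprime(t: float, lo: float, hi: float) -> float:
    """f'(t) = max - min, independent of t."""
    return hi - lo


def linear_fint(t: float, lo: float, hi: float) -> float:
    """Integral of f from 0 to t: min * t + (max - min) * t^2 / 2."""
    return lo * t + 0.5 * (hi - lo) * t ** 2


def midpoint_integral(func: Callable[[float], float], t: float, n: int = 1000) -> float:
    """Midpoint-rule approximation of the integral of func over [0, t]."""
    h = t / n
    return sum(func((i + 0.5) * h) for i in range(n)) * h


lo, hi = 0.1, 2.0  # hypothetical bounds
closed = linear_fint(1.0, lo, hi)
approx = midpoint_integral(lambda s: linear_f(s, lo, hi), 1.0)
```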

class agedi.diffusion.sdes.noise_schedules.Exponential(min: float, max: float)

Bases: NoiseSchedule

Exponential noise schedule: f(t) = min * (max/min)^t.

f(t: torch.Tensor) → torch.Tensor

Evaluate the noise schedule at time t.

fprime(t: torch.Tensor) → torch.Tensor

Return the derivative of the noise schedule at time t.

fint(t: torch.Tensor) → torch.Tensor

Return the integral of the noise schedule from 0 to t.
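Differentiating and integrating the exponential formula gives f'(t) = f(t) * ln(max/min) and ∫₀ᵗ f(s) ds = (f(t) - min) / ln(max/min). The sketch below states these with plain floats and cross-checks them by finite differences. It assumes min > 0 and max ≠ min, since the logarithm is otherwise undefined or zero; how the library handles those cases is not documented here. All function and variable names are illustrative.

```python
import math


def exp_f(t: float, lo: float, hi: float) -> float:
    """f(t) = min * (max / min) ** t; requires min > 0."""
    return lo * (hi / lo) ** t


def exp_fprime(t: float, lo: float, hi: float) -> float:
    """f'(t) = f(t) * ln(max / min)."""
    return exp_f(t, lo, hi) * math.log(hi / lo)


def exp_fint(t: float, lo: float, hi: float) -> float:
    """Integral of f from 0 to t; requires max != min (nonzero logarithm)."""
    return (exp_f(t, lo, hi) - lo) / math.log(hi / lo)


lo, hi = 0.01, 10.0  # hypothetical bounds
t, h = 0.4, 1e-6

# A central difference of f should match fprime.
numeric_slope = (exp_f(t + h, lo, hi) - exp_f(t - h, lo, hi)) / (2.0 * h)

# A central difference of fint should recover f (fundamental theorem of calculus).
numeric_value = (exp_fint(t + h, lo, hi) - exp_fint(t - h, lo, hi)) / (2.0 * h)
```

The schedule interpolates geometrically, so f(0) = min and f(1) = max, with the noise level growing by a constant factor per unit of time.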

class agedi.diffusion.sdes.noise_schedules.Cosine(min: float, max: float)

Bases: NoiseSchedule

Cosine noise schedule: f(t) = min + (max - min) * (1 - cos(πt)) / 2.

f(t: torch.Tensor) → torch.Tensor

Evaluate the noise schedule at time t.

fprime(t: torch.Tensor) → torch.Tensor

Return the derivative of the noise schedule at time t.

fint(t: torch.Tensor) → torch.Tensor

Return the integral of the noise schedule from 0 to t.
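From the cosine formula, f'(t) = (max - min) * π * sin(πt) / 2 and ∫₀ᵗ f(s) ds = min * t + (max - min) * (t - sin(πt) / π) / 2. The sketch below evaluates these with plain floats at the endpoints and the midpoint of [0, 1]. The function names and bounds are illustrative assumptions, and the library's tensor-based implementation may differ.

```python
import math


def cos_f(t: float, lo: float, hi: float) -> float:
    """f(t) = min + (max - min) * (1 - cos(pi * t)) / 2."""
    return lo + (hi - lo) * (1.0 - math.cos(math.pi * t)) / 2.0


def cos_fprime(t: float, lo: float, hi: float) -> float:
    """f'(t) = (max - min) * pi * sin(pi * t) / 2."""
    return (hi - lo) * math.pi * math.sin(math.pi * t) / 2.0


def cos_fint(t: float, lo: float, hi: float) -> float:
    """Integral of f from 0 to t."""
    return lo * t + (hi - lo) * (t - math.sin(math.pi * t) / math.pi) / 2.0


lo, hi = 0.1, 2.0  # hypothetical bounds
start = cos_f(0.0, lo, hi)      # equals min
middle = cos_f(0.5, lo, hi)     # halfway between min and max
end = cos_f(1.0, lo, hi)        # equals max
total = cos_fint(1.0, lo, hi)   # (min + max) / 2 over the full interval
```

Unlike the linear schedule, the slope vanishes at t = 0 and t = 1, so the noise level changes most slowly near the two ends of the interval.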