agedi.diffusion.sdes.ve
Classes
VE – Implements variance-exploding (VE) SDE.
Module Contents
- class agedi.diffusion.sdes.ve.VE(sigma_min: float = 0.01, sigma_max: float = 1.0, **kwargs)
Bases: agedi.diffusion.sdes.SDE
Implements variance-exploding (VE) SDE.
- Parameters:
sigma_min (float) – The minimum value of the sigma parameter.
sigma_max (float) – The maximum value of the sigma parameter.
- sigma_min = 0.01
- sigma_max = 1.0
- noise_schedule
- get_hparams() → Dict
Return hyperparameters for this VE SDE.
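A minimal sketch of the idea, assuming `get_hparams()` returns the constructor arguments as a plain dictionary. The exact keys agedi uses, and whether base-class `**kwargs` are included, are not stated on this page; `VESketch` is a hypothetical stand-in, not the agedi class.

```python
from typing import Dict


class VESketch:
    """Hypothetical stand-in for agedi's VE; it only stores its two hyperparameters."""

    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 1.0) -> None:
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_hparams(self) -> Dict[str, float]:
        # Assumed layout: the keys mirror the constructor arguments.
        return {"sigma_min": self.sigma_min, "sigma_max": self.sigma_max}


sde = VESketch(sigma_min=0.02, sigma_max=5.0)
hparams = sde.get_hparams()
# Because the keys match the constructor, the object can be rebuilt from the dict.
clone = VESketch(**hparams)
print(hparams)  # {'sigma_min': 0.02, 'sigma_max': 5.0}
```

Keeping the keys identical to the constructor arguments is what makes the `VESketch(**hparams)` round trip work, which is convenient when saving and restoring a model configuration.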
- drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Implement the VE drift term.
Defines the drift term of the SDE: f(x, t).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
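In the standard VE SDE, dx = g(t) dw, there is no deterministic term, so f(x, t) is zero everywhere. The sketch below assumes agedi follows that convention. It uses nested Python lists in place of `torch.Tensor` so it runs without PyTorch, and `ve_drift` is a hypothetical name.

```python
from typing import List

Positions = List[List[float]]  # one [x, y, z] row per atom


def ve_drift(x: Positions, t: float) -> Positions:
    """Hypothetical VE drift f(x, t): zero everywhere, same shape as x."""
    # t is unused: the VE SDE has no deterministic pull at any time.
    return [[0.0 for _ in row] for row in x]


positions = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print(ve_drift(positions, t=0.5))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

With real tensors the same result is `torch.zeros_like(x)`, which also preserves the dtype and device of the positions.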
- diffusion(t: torch.Tensor) → torch.Tensor
Implement the VE diffusion term.
Defines the diffusion term of the SDE: g(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
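For a VE SDE the diffusion term is tied to the noise level by g(t) = sqrt(d[sigma^2(t)]/dt). This page does not state the form of `noise_schedule`, so the sketch below assumes the common geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)^t on t in [0, 1]. Scalars stand in for `torch.Tensor`, and `ve_sigma` / `ve_diffusion` are hypothetical names.

```python
import math


def ve_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    # Assumed geometric schedule: sigma_min at t = 0, sigma_max at t = 1.
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_diffusion(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    """Hypothetical VE diffusion term g(t) = sqrt(d[sigma^2(t)]/dt)."""
    # For the geometric schedule, d/dt sigma^2(t) = 2 * sigma^2(t) * ln(sigma_max / sigma_min).
    rate = math.sqrt(2.0 * math.log(sigma_max / sigma_min))
    return ve_sigma(t, sigma_min, sigma_max) * rate


print(round(ve_diffusion(0.5), 4))  # 0.3035
```

Because g(t) is proportional to sigma(t), most of the injected noise arrives late in the forward process, where sigma(t) is largest.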
- mean(t: torch.Tensor) → torch.Tensor
Implement the VE mean term.
Calculates the mean of the transition kernel at time t: mu(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
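Since `mean` takes only t, a plausible reading is that it returns the factor mu(t) that multiplies the clean positions x_0; that interpretation is an assumption. With zero drift, the VE transition kernel is centred on x_0 itself, so the factor is 1 at every t. Plain floats stand in for tensors, and both function names are hypothetical.

```python
from typing import List


def ve_mean_coeff(t: float) -> float:
    """Hypothetical VE mean factor: with zero drift, E[x_t | x_0] = x_0."""
    return 1.0


def kernel_mean(x0: List[float], t: float) -> List[float]:
    # Mean of p(x_t | x_0) = mu(t) * x_0; for VE this leaves x_0 unchanged.
    return [ve_mean_coeff(t) * xi for xi in x0]


print(kernel_mean([0.5, -1.0, 2.0], t=0.8))  # [0.5, -1.0, 2.0]
```

This is the main contrast with a variance-preserving (VP) SDE, whose mean factor decays toward zero as t grows; in the VE case only the variance changes.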
- var(t: torch.Tensor) → torch.Tensor
Implement the VE variance term.
Calculates the variance of the transition kernel at time t: sigma^2(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
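The variance combines with a unit mean factor to give the VE perturbation x_t = x_0 + sigma(t) * z with z ~ N(0, I). The sketch below assumes the geometric schedule from the standard VE formulation (this page does not confirm it), uses plain floats instead of tensors, and checks the variance with a seeded Monte Carlo estimate. All function names are hypothetical.

```python
import math
import random
from typing import List


def ve_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    # Assumed geometric schedule (not confirmed by this page).
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_var(t: float) -> float:
    """Hypothetical VE transition-kernel variance sigma^2(t)."""
    return ve_sigma(t) ** 2


def perturb(x0: List[float], t: float, rng: random.Random) -> List[float]:
    # Sample x_t ~ N(x_0, sigma^2(t) I): mean factor 1, standard deviation sqrt(var).
    std = math.sqrt(ve_var(t))
    return [xi + std * rng.gauss(0.0, 1.0) for xi in x0]


rng = random.Random(0)
x_t = perturb([0.0] * 10_000, t=1.0, rng=rng)
empirical_var = sum(v * v for v in x_t) / len(x_t)
print(ve_var(1.0), empirical_var)  # exact variance vs. a Monte Carlo estimate near it
```

Some VE formulations use sigma^2(t) - sigma^2(0) as the kernel variance so that it is exactly zero at t = 0; the sketch follows the sigma^2(t) form stated on this page.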
- sigma(t: torch.Tensor) → torch.Tensor
VE sigma function.
Calculates the value of sigma at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate sigma.
- Returns:
sigma – The value of sigma at time t.
- Return type:
torch.Tensor
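A minimal sketch of a sigma function, assuming the geometric interpolation between `sigma_min` and `sigma_max` that is common for VE SDEs; the actual form of agedi's `noise_schedule` is not documented here. `ve_sigma` is a hypothetical name and operates on floats rather than tensors.

```python
def ve_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    """Hypothetical geometric VE schedule: log(sigma) is linear in t on [0, 1]."""
    return sigma_min * (sigma_max / sigma_min) ** t


for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}  sigma={ve_sigma(t):.4f}")
# sigma: 0.0100, 0.0316, 0.1000, 0.3162, 1.0000
```

Interpolating in log space puts the midpoint at the geometric mean (0.1 for the defaults) rather than the arithmetic mean (0.505), so a larger share of the time interval is spent at small noise levels than with a linear ramp.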