agedi.diffusion.sdes.ve

Classes

VE

Implements variance-exploding (VE) SDE.

Module Contents

class agedi.diffusion.sdes.ve.VE(sigma_min: float = 0.01, sigma_max: float = 1.0, **kwargs)

Bases: agedi.diffusion.sdes.SDE

Implements variance-exploding (VE) SDE.

Parameters:
  • sigma_min (float) – The minimum value of the sigma parameter.

  • sigma_max (float) – The maximum value of the sigma parameter.

Return type:

VE
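For reference, the VE SDE as commonly defined in score-based diffusion (Song et al., 2021) has zero drift. Its diffusion coefficient is tied to how fast the noise scale grows. This page does not confirm that this class uses exactly the geometric schedule below; treat it as an assumption. In the notation of the methods documented here, f(x, t) is the drift, g(t) the diffusion, and sigma^2(t) the variance:

```latex
\mathrm{d}\mathbf{x} = \sqrt{\frac{\mathrm{d}\,[\sigma^2(t)]}{\mathrm{d}t}}\,\mathrm{d}\mathbf{w},
\qquad
\sigma(t) = \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{t},\quad t \in [0, 1]

f(\mathbf{x}, t) = \mathbf{0},
\qquad
g(t) = \sigma(t)\sqrt{2\ln\frac{\sigma_{\max}}{\sigma_{\min}}},
\qquad
p(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\!\left(\mathbf{x}_0,\ \left[\sigma^2(t) - \sigma^2(0)\right]\mathbf{I}\right)
```

The kernel variance is often approximated by sigma^2(t), because sigma^2(0) = sigma_min^2 is small.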

sigma_min = 0.01
sigma_max = 1.0
noise_schedule
get_hparams() Dict

Return hyperparameters for this VE SDE.
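Below is a minimal sketch of what get_hparams() plausibly returns. It assumes the dictionary simply mirrors the two documented constructor arguments. The class name VESketch and the dictionary keys are hypothetical and are not confirmed by this page:

```python
from typing import Dict


class VESketch:
    """Hypothetical stand-in for VE, holding only the two documented parameters."""

    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 1.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_hparams(self) -> Dict:
        # Assumption: the hyperparameters are exactly the constructor arguments.
        return {"sigma_min": self.sigma_min, "sigma_max": self.sigma_max}


sde = VESketch(sigma_max=50.0)
print(sde.get_hparams())  # {'sigma_min': 0.01, 'sigma_max': 50.0}
```

A dictionary in this shape is convenient for logging. It can also rebuild an equivalent object, e.g. VESketch(**sde.get_hparams()).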

drift(x: torch.Tensor, t: torch.Tensor) torch.Tensor

Implement the VE drift term.

Defines the drift term of the SDE: f(x, t).

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor
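Under the standard VE formulation the drift is identically zero. In that case f(x, t) returns zeros with the same shape as x; in torch this would be torch.zeros_like(x). The sketch below is dependency-free and assumes that formulation. The name drift_ve is hypothetical, and nested lists stand in for the (n_atoms, 3) position tensor:

```python
from typing import List


def drift_ve(x: List[List[float]], t: float) -> List[List[float]]:
    """Zero drift for the VE SDE: f(x, t) = 0, shaped like x (n_atoms x 3)."""
    # t is unused: the VE drift does not depend on time (or on x).
    return [[0.0 for _ in row] for row in x]


positions = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print(drift_ve(positions, t=0.5))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```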

diffusion(t: torch.Tensor) torch.Tensor

Implement the VE diffusion term.

Defines the diffusion term of the SDE: g(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
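Assume the geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)^t. The diffusion coefficient is then g(t) = sigma(t) * sqrt(2 ln(sigma_max / sigma_min)), which satisfies g(t)^2 = d[sigma^2(t)]/dt. The sketch below computes g(t) and checks that identity with a central finite difference. The helper names are hypothetical, and it uses floats instead of tensors:

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 1.0  # the documented defaults


def sigma(t: float) -> float:
    # Assumed geometric schedule: sigma(0) = SIGMA_MIN, sigma(1) = SIGMA_MAX.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def diffusion(t: float) -> float:
    # g(t) = sigma(t) * sqrt(2 * ln(sigma_max / sigma_min))
    return sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))


t, h = 0.3, 1e-5
# Central finite-difference estimate of d[sigma^2(t)]/dt.
dvar_dt = (sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2 * h)
print(math.isclose(diffusion(t) ** 2, dvar_dt, rel_tol=1e-6))  # True
```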

mean(t: torch.Tensor) torch.Tensor

Implement the VE mean term.

Calculates the mean of the transition kernel at time t: mu(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor
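Because the VE drift is zero, the transition kernel p(x_t | x_0) stays centred on x_0, so mu(t) = x_0 at every t. mean() takes only t, so one plausible reading is that it returns the multiplicative coefficient on x_0, which is 1 for VE. That reading is an assumption. Below is a sketch with hypothetical names, using floats:

```python
from typing import List


def mean_coeff_ve(t: float) -> float:
    # Assumption: mean() returns the factor m(t) in mu(t) = m(t) * x_0.
    # With zero drift, x_0 is never rescaled, so m(t) = 1 for all t.
    return 1.0


def kernel_mean(x0: List[float], t: float) -> List[float]:
    # mu(t) = m(t) * x_0, applied coordinate-wise.
    return [mean_coeff_ve(t) * xi for xi in x0]


x0 = [0.5, -1.0, 2.0]
print(kernel_mean(x0, t=0.9))  # [0.5, -1.0, 2.0]
```

This is the main contrast with a VP SDE, whose mean coefficient decays towards zero as t grows.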

var(t: torch.Tensor) torch.Tensor

Implement the VE variance term.

Calculates the variance of the transition kernel at time t: sigma^2(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor
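Integrating g(s)^2 over [0, t] gives the exact kernel variance sigma^2(t) - sigma^2(0). Many implementations use sigma^2(t) directly, because sigma_min is small. This page does not say which form var() returns. The sketch below computes both and uses the variance to perturb a position. It uses hypothetical names and assumes the geometric schedule:

```python
import math
import random
from typing import List

SIGMA_MIN, SIGMA_MAX = 0.01, 1.0


def sigma(t: float) -> float:
    # Assumed geometric schedule.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def var_approx(t: float) -> float:
    # Common approximation: sigma^2(t).
    return sigma(t) ** 2


def var_exact(t: float) -> float:
    # Exact: integral of g(s)^2 over [0, t] = sigma^2(t) - sigma^2(0).
    return sigma(t) ** 2 - SIGMA_MIN ** 2


def perturb(x0: List[float], t: float, rng: random.Random) -> List[float]:
    # Sample x_t = x_0 + sqrt(var(t)) * z with z ~ N(0, I), using the approximation.
    std = math.sqrt(var_approx(t))
    return [xi + std * rng.gauss(0.0, 1.0) for xi in x0]


xt = perturb([0.0, 0.0, 0.0], t=1.0, rng=random.Random(0))
print(var_approx(1.0), var_exact(1.0), xt)
```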

sigma(t: torch.Tensor) torch.Tensor

VE sigma function.

Calculates the value of sigma at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate sigma.

Returns:

sigma – The value of sigma at time t.

Return type:

torch.Tensor
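This page does not give the formula for sigma(t). The sketch below assumes the geometric interpolation that is standard for VE SDEs and uses a hypothetical function name. It evaluates the schedule in log-space, which is algebraically identical to sigma_min * (sigma_max / sigma_min) ** t:

```python
import math


def sigma_ve(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    """Assumed geometric schedule: sigma(t) = sigma_min * (sigma_max / sigma_min) ** t."""
    # Interpolate linearly in log-space between log(sigma_min) and log(sigma_max).
    log_sigma = (1.0 - t) * math.log(sigma_min) + t * math.log(sigma_max)
    return math.exp(log_sigma)


for t in (0.0, 0.5, 1.0):
    print(t, sigma_ve(t))
```

At t = 0.5 the sketch returns the geometric mean sqrt(sigma_min * sigma_max). For the defaults that is 0.1.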