agedi.diffusion.sdes.vp

Classes

VP

Implements the variance-preserving (VP) SDE.

Module Contents

class agedi.diffusion.sdes.vp.VP(beta_min: float = 0.01, beta_max: float = 3, **kwargs)

Bases: agedi.diffusion.sdes.SDE

Implements the variance-preserving (VP) SDE.

Parameters:
  • beta_min (float) – The minimum value of the noise rate beta(t).

  • beta_max (float) – The maximum value of the noise rate beta(t).

Return type:

VP

beta_min = 0.01
beta_max = 3
noise_schedule
get_hparams() → Dict

Return hyperparameters for this VP SDE.

drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor

Implement the VP drift term.

Defines the drift term of the SDE: f(x, t).

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor
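A minimal sketch of the drift, assuming agedi follows the standard VP SDE used in score-based diffusion models, where f(x, t) = -1/2 beta(t) x, and assuming a linear schedule beta(t) = beta_min + t (beta_max - beta_min) on t in [0, 1] with this class's default parameters. The helpers `linear_beta` and `vp_drift` are hypothetical illustrations, not agedi functions; the real schedule is whatever `noise_schedule` defines.

```python
import torch

# Assumed defaults, taken from the VP signature above.
BETA_MIN, BETA_MAX = 0.01, 3.0


def linear_beta(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical linear schedule on t in [0, 1]; agedi's noise_schedule may differ.
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # f(x, t) = -1/2 * beta(t) * x.
    # t must broadcast against x, e.g. shape (N, 1) for positions of shape (N, 3).
    return -0.5 * linear_beta(t) * x


x = torch.ones(4, 3)         # four atoms with 3D positions
t = torch.full((4, 1), 0.5)  # one time value per atom
print(vp_drift(x, t))        # every entry is about -0.5 * 1.505 = -0.7525
```

The drift pulls positions linearly toward the origin, faster as beta(t) grows.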

diffusion(t: torch.Tensor) → torch.Tensor

Implement the VP diffusion term.

Defines the diffusion term of the SDE: g(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
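To show how the drift f(x, t) and diffusion g(t) combine in dx = f(x, t) dt + g(t) dW, here is a sketch of the diffusion term g(t) = sqrt(beta(t)) together with one Euler-Maruyama step of the forward SDE. It assumes the standard VP form and a linear beta(t) with the default parameters; `linear_beta`, `vp_diffusion` and `euler_maruyama_step` are hypothetical names, not part of agedi.

```python
import torch

BETA_MIN, BETA_MAX = 0.01, 3.0  # assumed defaults from the VP signature


def linear_beta(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical linear schedule; agedi's noise_schedule may differ.
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_diffusion(t: torch.Tensor) -> torch.Tensor:
    # g(t) = sqrt(beta(t)); for the VP SDE it does not depend on x.
    return torch.sqrt(linear_beta(t))


def euler_maruyama_step(x, t, dt, noise=None):
    # One forward step of dx = f(x, t) dt + g(t) dW with f(x, t) = -1/2 beta(t) x.
    if noise is None:
        noise = torch.randn_like(x)
    drift = -0.5 * linear_beta(t) * x
    return x + drift * dt + vp_diffusion(t) * (dt ** 0.5) * noise


# Simulate the forward process from x(0) = 0 up to t = 1.
torch.manual_seed(0)
x = torch.zeros(2000, 3)
n_steps = 500
dt = 1.0 / n_steps
for i in range(n_steps):
    t = torch.full((x.shape[0], 1), i * dt)
    x = euler_maruyama_step(x, t, dt)
print(x.var().item())  # should be close to 1 - exp(-1.505), roughly 0.78
```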

mean(t: torch.Tensor) → torch.Tensor

Implement the VP mean term.

Calculates the mean of the transition kernel at time t: mu(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor
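Because `mean` takes only t, it presumably returns the factor that multiplies the clean positions x_0 in the transition kernel; that reading is an assumption. For the standard VP SDE the factor is mu(t) = exp(-alpha(t) / 2), with alpha(t) the integral of beta from 0 to t as defined under `alpha` on this page. The sketch assumes a linear beta(t) with the default parameters; `linear_alpha` and `vp_mean_coeff` are hypothetical names.

```python
import torch

BETA_MIN, BETA_MAX = 0.01, 3.0  # assumed defaults from the VP signature


def linear_alpha(t: torch.Tensor) -> torch.Tensor:
    # alpha(t) = integral_0^t beta(s) ds for the assumed linear beta.
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def vp_mean_coeff(t: torch.Tensor) -> torch.Tensor:
    # mu(t) = exp(-alpha(t) / 2): the expected x_t is mu(t) * x_0.
    return torch.exp(-0.5 * linear_alpha(t))


t = torch.tensor([0.0, 1.0])
print(vp_mean_coeff(t))  # 1.0 at t = 0, about 0.4712 at t = 1
```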

var(t: torch.Tensor) → torch.Tensor

Implement the VP variance term.

Calculates the variance of the transition kernel at time t: sigma^2(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor
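Under the standard VP SDE, the transition kernel is N(mu(t) x_0, sigma^2(t) I) with sigma^2(t) = 1 - exp(-alpha(t)). The sketch below, assuming a linear beta(t) with the default parameters, implements this variance and uses it to sample a noised x_t directly from x_0. `linear_alpha`, `vp_var` and `perturb` are hypothetical helpers, not agedi API.

```python
import torch

BETA_MIN, BETA_MAX = 0.01, 3.0  # assumed defaults from the VP signature


def linear_alpha(t: torch.Tensor) -> torch.Tensor:
    # alpha(t) = integral_0^t beta(s) ds for the assumed linear beta.
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def vp_var(t: torch.Tensor) -> torch.Tensor:
    # sigma^2(t) = 1 - exp(-alpha(t)).
    return 1.0 - torch.exp(-linear_alpha(t))


def perturb(x0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Sample x_t ~ N(mu(t) x0, sigma^2(t) I) in one shot; t must broadcast against x0.
    a = linear_alpha(t)
    mu = torch.exp(-0.5 * a)
    sigma = torch.sqrt(1.0 - torch.exp(-a))
    return mu * x0 + sigma * torch.randn_like(x0)


t = torch.linspace(0.0, 1.0, 5)
# mu(t)^2 + sigma^2(t) = 1 for every t: this is what "variance-preserving" means.
print(torch.exp(-linear_alpha(t)) + vp_var(t))
```

If x_0 has unit variance, x_t keeps unit variance at every t, since mu(t)^2 = exp(-alpha(t)).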

beta(t: torch.Tensor) → torch.Tensor

VP Beta function

Calculates the value of beta at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate beta.

Returns:

beta – The value of beta at time t.

Return type:

torch.Tensor
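This page does not state the functional form of beta(t). A common choice for VP SDEs is a linear ramp from beta_min at t = 0 to beta_max at t = 1, sketched below with this class's defaults. Treat the linear form, the [0, 1] time range and the name `vp_beta` as assumptions; agedi's `noise_schedule` attribute may implement something else.

```python
import torch


def vp_beta(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Assumed linear interpolation: beta(0) = beta_min, beta(1) = beta_max.
    return beta_min + t * (beta_max - beta_min)


print(vp_beta(torch.tensor([0.0, 0.5, 1.0])))  # approximately [0.01, 1.505, 3.0]
```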

alpha(t: torch.Tensor) → torch.Tensor

VP Alpha function

Calculates the value of alpha at time t as $\alpha(t) = \int_{0}^{t} \beta(s)\,ds$.

Parameters:

t (torch.Tensor) – The time at which to calculate alpha.

Returns:

alpha – The value of alpha at time t.

Return type:

torch.Tensor
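If beta(t) is the linear ramp beta_min + t (beta_max - beta_min) (an assumption; the page does not confirm it), the integral has the closed form alpha(t) = beta_min t + 1/2 (beta_max - beta_min) t^2. The sketch below implements that closed form and checks it against numerical integration with `torch.trapezoid`; `vp_beta` and `vp_alpha` are hypothetical names.

```python
import torch

BETA_MIN, BETA_MAX = 0.01, 3.0  # assumed defaults from the VP signature


def vp_beta(t: torch.Tensor) -> torch.Tensor:
    # Assumed linear schedule.
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_alpha(t: torch.Tensor) -> torch.Tensor:
    # Closed form of integral_0^t (beta_min + s * (beta_max - beta_min)) ds.
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


# Trapezoidal rule on a fine grid over [0, 1]; exact (up to rounding) for a linear integrand.
s = torch.linspace(0.0, 1.0, 1001, dtype=torch.float64)
numeric = torch.trapezoid(vp_beta(s), s)
closed = vp_alpha(torch.tensor(1.0, dtype=torch.float64))
print(numeric.item(), closed.item())  # both approximately 1.505
```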