agedi.diffusion.sdes.vp
Classes
VP – Implements variance-preserving (VP) SDE.
Module Contents
- class agedi.diffusion.sdes.vp.VP(beta_min: float = 0.01, beta_max: float = 3, **kwargs)
Bases: agedi.diffusion.sdes.SDE
Implements variance-preserving (VP) SDE.
- Parameters:
beta_min (float) – The minimum value of beta(t), the time-dependent noise rate returned by beta().
beta_max (float) – The maximum value of beta(t), the time-dependent noise rate returned by beta().
- Return type:
None
- beta_min = 0.01
- beta_max = 3
- noise_schedule
- get_hparams() → Dict
Return hyperparameters for this VP SDE.
- drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Implement VP drift term.
Defines the drift term of the SDE: f(x, t).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
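In the standard VP SDE, dx = -½ beta(t) x dt + sqrt(beta(t)) dw, the drift is f(x, t) = -½ beta(t) x. The sketch below assumes agedi follows that form and uses a linear schedule beta(t) = beta_min + t (beta_max - beta_min) on t in [0, 1]; both the schedule and the helper names `linear_beta` and `vp_drift` are hypothetical, not taken from agedi. It also assumes t is shaped to broadcast against x.

```python
import torch


def linear_beta(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Hypothetical linear schedule: beta(0) = beta_min, beta(1) = beta_max.
    return beta_min + t * (beta_max - beta_min)


def vp_drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Assumed VP drift f(x, t) = -0.5 * beta(t) * x.
    # t must broadcast against x, e.g. shape (N, 1) for positions of shape (N, 3).
    return -0.5 * linear_beta(t) * x


x = torch.ones(2, 3)               # two atoms, three coordinates each
t = torch.tensor([[0.0], [1.0]])   # one time per atom
drift = vp_drift(x, t)
```

The drift pulls positions toward the origin, more strongly as beta(t) grows.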
- diffusion(t: torch.Tensor) → torch.Tensor
Implement VP diffusion term.
Defines the diffusion term of the SDE: g(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
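For the standard VP SDE the diffusion coefficient is g(t) = sqrt(beta(t)), independent of x. A minimal sketch, assuming that form and the same hypothetical linear schedule (`linear_beta` and `vp_diffusion` are illustrative names, not agedi functions):

```python
import torch


def linear_beta(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Hypothetical linear schedule between beta_min (t = 0) and beta_max (t = 1).
    return beta_min + t * (beta_max - beta_min)


def vp_diffusion(t: torch.Tensor) -> torch.Tensor:
    # Assumed VP diffusion coefficient g(t) = sqrt(beta(t)).
    return torch.sqrt(linear_beta(t))


t = torch.tensor([0.0, 1.0])
g = vp_diffusion(t)
```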
- mean(t: torch.Tensor) → torch.Tensor
Implement VP mean term.
Calculates the mean of the transition kernel at time t: mu(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
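Because mean() takes only t, mu(t) is presumably the factor that scales the initial positions, E[x_t | x_0] = mu(t) x_0. For the standard VP SDE that factor is mu(t) = exp(-½ alpha(t)), with alpha(t) the integral of beta defined under alpha() below. A sketch under that assumption and a hypothetical linear schedule, for which alpha(t) = beta_min t + ½ (beta_max - beta_min) t²:

```python
import torch


def linear_alpha(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Closed-form integral of the hypothetical linear beta(s) from 0 to t.
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2


def vp_mean(t: torch.Tensor) -> torch.Tensor:
    # Assumed mean scaling mu(t) = exp(-0.5 * alpha(t)); multiply by x_0 to get the mean.
    return torch.exp(-0.5 * linear_alpha(t))


t = torch.tensor([0.0, 1.0])
scale = vp_mean(t)
```

At t = 0 the scale is 1, so the kernel starts centred on x_0 and the mean then decays toward the origin.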
- var(t: torch.Tensor) → torch.Tensor
Implement VP variance term.
Calculates the variance of the transition kernel at time t: sigma^2(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
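In the standard VP SDE the kernel variance is sigma^2(t) = 1 - exp(-alpha(t)), which makes mu(t)^2 + sigma^2(t) = 1: data with unit variance keeps unit variance at every t, hence "variance preserving". The sketch below assumes those formulas and a hypothetical linear schedule, checks that identity, and draws x_t from the kernel as x_t = mu(t) x_0 + sigma(t) eps with eps ~ N(0, I). `linear_alpha`, `vp_var`, and `sample_xt` are illustrative names, not agedi functions.

```python
import torch


def linear_alpha(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Closed-form integral of the hypothetical linear beta(s) from 0 to t.
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2


def vp_var(t: torch.Tensor) -> torch.Tensor:
    # Assumed VP kernel variance sigma^2(t) = 1 - exp(-alpha(t)).
    return 1.0 - torch.exp(-linear_alpha(t))


def sample_xt(x0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Draw x_t ~ N(mu(t) x_0, sigma^2(t) I) using the reparameterisation trick.
    mu = torch.exp(-0.5 * linear_alpha(t))
    return mu * x0 + torch.sqrt(vp_var(t)) * torch.randn_like(x0)


t = torch.linspace(0.0, 1.0, 5)
mu = torch.exp(-0.5 * linear_alpha(t))
total = mu ** 2 + vp_var(t)   # equals 1 for every t

torch.manual_seed(0)
xt = sample_xt(torch.zeros(4, 3), torch.tensor(0.5))
```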
- beta(t: torch.Tensor) → torch.Tensor
VP Beta function
Calculates the value of beta at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate beta.
- Returns:
beta – The value of beta at time t.
- Return type:
torch.Tensor
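The page does not state how beta(t) is built from beta_min and beta_max. A common choice for VP SDEs is linear interpolation over t in [0, 1]; the sketch below assumes that choice, and `linear_beta` is a hypothetical helper rather than agedi's implementation.

```python
import torch


def linear_beta(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Hypothetical linear schedule: beta_min at t = 0, beta_max at t = 1.
    return beta_min + t * (beta_max - beta_min)


t = torch.tensor([0.0, 0.5, 1.0])
betas = linear_beta(t)
```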
- alpha(t: torch.Tensor) → torch.Tensor
VP Alpha function
Calculates the value of alpha at time t, defined as the integral of beta:
$$\alpha(t) = \int_{0}^{t} \beta(s)\,ds.$$
- Parameters:
t (torch.Tensor) – The time at which to calculate alpha.
- Returns:
alpha – The value of alpha at time t.
- Return type:
torch.Tensor
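For a linear schedule (an assumption; the page does not specify the form of beta) the integral has the closed form alpha(t) = beta_min t + ½ (beta_max - beta_min) t². The sketch compares that closed form with a numerical integral of beta computed by `torch.trapezoid`; the helper names are hypothetical.

```python
import torch


def linear_beta(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # Hypothetical linear schedule between beta_min (t = 0) and beta_max (t = 1).
    return beta_min + t * (beta_max - beta_min)


def linear_alpha(t: torch.Tensor, beta_min: float = 0.01, beta_max: float = 3.0) -> torch.Tensor:
    # alpha(t) = int_0^t beta(s) ds for the linear schedule above.
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2


t_end = 0.7
s = torch.linspace(0.0, t_end, 101)               # integration grid on [0, t_end]
numeric = torch.trapezoid(linear_beta(s), s)      # trapezoid rule is exact for a linear integrand
closed = linear_alpha(torch.tensor(t_end))
```

A closed form avoids numerical integration at every time step; a non-linear schedule would need its own antiderivative.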