agedi.diffusion.sdes
Submodules
Classes
Package Contents
- class agedi.diffusion.sdes.SDE(noise_schedule: agedi.diffusion.sdes.noise_schedules.NoiseSchedule = Linear)
Bases:
abc.ABC
Abstract base class for SDEs.
- noise_schedule_cls
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this SDE.
Returns a dictionary with a _target_ key holding the fully-qualified class name. Subclasses should call super().get_hparams() and merge in their own constructor parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
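A minimal sketch of the get_hparams contract described above, in plain Python so it runs without agedi installed. The classes Base and MySDE are hypothetical stand-ins, not agedi's SDE or VE; the sketch only shows a subclass extending its parent's dictionary and that dictionary being enough to rebuild the object. The _target_ key follows the same convention as Hydra's instantiate utility; whether agedi relies on Hydra is not stated on this page.

```python
from abc import ABC
from typing import Any, Dict


class Base(ABC):
    """Hypothetical stand-in for the SDE base class."""

    def get_hparams(self) -> Dict[str, Any]:
        # Store the fully-qualified class name under the _target_ key.
        cls = type(self)
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}"}


class MySDE(Base):
    """Hypothetical subclass with two constructor parameters."""

    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 1.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the parent's dictionary, then merge in this class's parameters.
        hparams = super().get_hparams()
        hparams.update(sigma_min=self.sigma_min, sigma_max=self.sigma_max)
        return hparams


hp = MySDE(sigma_max=2.0).get_hparams()

# Everything except _target_ is a constructor argument, so the object can be rebuilt.
kwargs = {k: v for k, v in hp.items() if k != "_target_"}
rebuilt = MySDE(**kwargs)
print(hp)
```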
- abstractmethod drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Drift term of the SDE.
Must be implemented by subclasses.
Defines the drift term \(f(\mathbf{x}, t)\) of the SDE \(d\mathbf{x} = f(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w}\).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
- abstractmethod diffusion(t: torch.Tensor) → torch.Tensor
Diffusion term of the SDE.
Must be implemented by subclasses.
Defines the diffusion term \(g(t)\) of the SDE.
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
- abstractmethod mean(t: torch.Tensor) → torch.Tensor
Mean of the SDE.
Must be implemented by subclasses.
Calculates \(\mu_t\), the factor that scales \(\mathbf{x}_0\) in the mean of the transition kernel at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
- abstractmethod var(t: torch.Tensor) → torch.Tensor
Variance of the SDE.
Must be implemented by subclasses.
Calculates the variance \(\sigma_t^2\) of the transition kernel at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
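The abstract contract above can be illustrated with a hypothetical scalar mirror of the interface, built on the standard library's abc module. BrownianMotion implements the simplest SDE, dx = dw, whose transition kernel has mean coefficient 1 and variance t; a subclass that leaves any abstract method unimplemented cannot be instantiated. Floats stand in for torch.Tensor, and none of these class names come from agedi.

```python
from abc import ABC, abstractmethod


class SDEBase(ABC):
    """Hypothetical scalar mirror of the abstract SDE interface."""

    @abstractmethod
    def drift(self, x: float, t: float) -> float: ...

    @abstractmethod
    def diffusion(self, t: float) -> float: ...

    @abstractmethod
    def mean(self, t: float) -> float: ...

    @abstractmethod
    def var(self, t: float) -> float: ...


class BrownianMotion(SDEBase):
    """dx = dw: zero drift and unit diffusion, so x_t = x_0 + sqrt(t) * w."""

    def drift(self, x: float, t: float) -> float:
        return 0.0

    def diffusion(self, t: float) -> float:
        return 1.0

    def mean(self, t: float) -> float:
        return 1.0  # coefficient multiplying x_0

    def var(self, t: float) -> float:
        return t  # integral of g(s)^2 = 1 over [0, t]


class Incomplete(SDEBase):
    """Implements only drift, so it stays abstract."""

    def drift(self, x: float, t: float) -> float:
        return 0.0


sde = BrownianMotion()
print(sde.mean(0.5), sde.var(0.5))

try:
    Incomplete()
    incomplete_ok = True
except TypeError:
    incomplete_ok = False  # diffusion, mean and var are still abstract
```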
- transition_kernel(x: torch.Tensor, t: torch.Tensor, w: Callable) → torch.Tensor
Transition kernel of the SDE.
Applies the transition kernel \(p(\mathbf{x}_t \mid \mathbf{x}_0)\) of the diffusion process, mapping the positions at time 0 and a noise term to a noised sample:
\[ \mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w}, \]
with \(\mathbf{w} \sim N(0,1)\).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the transition kernel.
w (torch.Tensor) – The noise term.
- Returns:
transition_kernel – The transition kernel of the diffusion process.
- Return type:
torch.Tensor
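The sampling formula can be checked numerically with a minimal plain-Python sketch. Lists of floats stand in for torch.Tensor, and the coefficients mu_t and var_t are passed in directly instead of being computed from mean(t) and var(t); their values here are arbitrary. The sketch treats w as already-drawn noise values, as the parameter description suggests, and a Monte Carlo check confirms that x_t has mean mu_t * x_0 and variance sigma_t^2 when w ~ N(0,1).

```python
import math
import random
import statistics
from typing import List


def transition_kernel(x0: List[float], mu_t: float, var_t: float, w: List[float]) -> List[float]:
    """Return x_t = mu_t * x_0 + sigma_t * w elementwise, with sigma_t = sqrt(var_t)."""
    sigma_t = math.sqrt(var_t)
    return [mu_t * x + sigma_t * e for x, e in zip(x0, w)]


# One noised sample of three coordinates (hypothetical coefficients).
x0 = [1.0, -2.0, 0.5]
w = [0.5, 0.0, -1.0]
xt = transition_kernel(x0, mu_t=0.8, var_t=0.36, w=w)

# Monte Carlo check: with w ~ N(0, 1), x_t has mean mu_t * x_0 and variance var_t.
rng = random.Random(0)
samples = [transition_kernel([1.0], 0.8, 0.36, [rng.gauss(0.0, 1.0)])[0] for _ in range(20000)]
print(xt, statistics.mean(samples), statistics.pvariance(samples))
```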
- noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) → torch.Tensor
Noise term of the SDE.
Recovers the noise term that maps \(\mathbf{x}_0\) to \(\mathbf{x}_t\) by inverting the transition kernel:
\[ \mathbf{w} = \frac{\mathbf{x}_t - \mu_t \mathbf{x}_0}{\sigma_t} \]
- Parameters:
x0 (torch.Tensor) – x at time 0.
xt (torch.Tensor) – x at time t.
t (torch.Tensor) – The time at which to calculate the noise term.
- Returns:
noise – The noise term of the diffusion process.
- Return type:
torch.Tensor
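Because the kernel is affine in w, noise simply inverts it. A minimal plain-Python sketch (lists stand in for tensors; mu_t, var_t and the noise values are hypothetical inputs) shows the round trip from x_0 and w to x_t and back to w. In denoising score matching this recovered w is commonly the regression target for the network; note that the division is undefined wherever var(t) is zero, so such times should be avoided.

```python
import math
from typing import List


def noise(x0: List[float], xt: List[float], mu_t: float, var_t: float) -> List[float]:
    """Invert x_t = mu_t * x_0 + sigma_t * w for w (requires var_t > 0)."""
    sigma_t = math.sqrt(var_t)
    return [(b - mu_t * a) / sigma_t for a, b in zip(x0, xt)]


# Hypothetical coefficients and noise: forward-noise x0, then recover w.
mu_t, var_t = 0.9, 0.19
x0 = [1.0, -2.0, 0.5]
w = [0.3, -1.2, 0.7]
xt = [mu_t * a + math.sqrt(var_t) * e for a, e in zip(x0, w)]
w_rec = noise(x0, xt, mu_t, var_t)
print(w_rec)
```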
- class agedi.diffusion.sdes.VE(sigma_min: float = 0.01, sigma_max: float = 1.0, **kwargs)
Bases:
agedi.diffusion.sdes.SDE
Implements the variance-exploding (VE) SDE.
- Parameters:
sigma_min (float) – The minimum value of the sigma parameter.
sigma_max (float) – The maximum value of the sigma parameter.
- sigma_min = 0.01
- sigma_max = 1.0
- noise_schedule
- get_hparams() → Dict
Return hyperparameters for this VE SDE.
- drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Implements the VE drift term.
Defines the drift term of the SDE: f(x, t).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
- diffusion(t: torch.Tensor) → torch.Tensor
Implements the VE diffusion term.
Defines the diffusion term of the SDE: g(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
- mean(t: torch.Tensor) → torch.Tensor
Implements the VE mean term.
Calculates the mean of transition kernel at time t: mu(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
- var(t: torch.Tensor) → torch.Tensor
Implements the VE variance term.
Calculates the variance of transition kernel at time t: sigma^2(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
- sigma(t: torch.Tensor) → torch.Tensor
VE sigma function.
Calculates the value of sigma at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate sigma.
- Returns:
sigma – The value of sigma at time t.
- Return type:
torch.Tensor
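This page does not state which sigma schedule VE uses (the base class takes a noise_schedule, with Linear as the default), so the following sketch assumes the geometric schedule of Song et al. (2021), sigma(t) = sigma_min (sigma_max / sigma_min)^t, with the documented defaults. Under that assumption the VE SDE has zero drift and g(t) = sqrt(d sigma^2 / dt), and the marginal variance sigma(t)^2 - sigma_min^2 equals the integral of g(s)^2 from 0 to t, which the Simpson-rule check confirms. VESketch is a hypothetical name, not agedi's class.

```python
import math


class VESketch:
    """Hypothetical VE SDE with an assumed geometric sigma schedule."""

    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 1.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def sigma(self, t: float) -> float:
        # Geometric interpolation: sigma(0) = sigma_min, sigma(1) = sigma_max.
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def drift(self, x: float, t: float) -> float:
        return 0.0  # VE SDEs have no drift: dx = g(t) dw

    def diffusion(self, t: float) -> float:
        # g(t) = sqrt(d sigma^2 / dt) = sigma(t) * sqrt(2 ln(sigma_max / sigma_min))
        return self.sigma(t) * math.sqrt(2.0 * math.log(self.sigma_max / self.sigma_min))

    def mean(self, t: float) -> float:
        return 1.0  # x_0 is not rescaled

    def var(self, t: float) -> float:
        # Exact marginal variance; many codes approximate this by sigma(t)^2.
        return self.sigma(t) ** 2 - self.sigma_min ** 2


def simpson(f, a: float, b: float, n: int = 1000) -> float:
    """Composite Simpson rule on [a, b] (n must be even)."""
    h = (b - a) / n
    total = f(a) + f(b)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * f(a + i * h)
    return total * h / 3


sde = VESketch()
t = 0.7
integral = simpson(lambda s: sde.diffusion(s) ** 2, 0.0, t)
print(sde.var(t), integral)
```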
- class agedi.diffusion.sdes.VP(beta_min: float = 0.01, beta_max: float = 3, **kwargs)
Bases:
agedi.diffusion.sdes.SDE
Implements the variance-preserving (VP) SDE.
- Parameters:
beta_min (float) – The minimum value of the beta parameter.
beta_max (float) – The maximum value of the beta parameter.
- beta_min = 0.01
- beta_max = 3
- noise_schedule
- get_hparams() → Dict
Return hyperparameters for this VP SDE.
- drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Implements the VP drift term.
Defines the drift term of the SDE: f(x, t).
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
- diffusion(t: torch.Tensor) → torch.Tensor
Implements the VP diffusion term.
Defines the diffusion term of the SDE: g(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
- mean(t: torch.Tensor) → torch.Tensor
Implements the VP mean term.
Calculates the mean of transition kernel at time t: mu(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
- var(t: torch.Tensor) → torch.Tensor
Implements the VP variance term.
Calculates the variance of transition kernel at time t: sigma^2(t).
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
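As a consistency check between the four methods, the sketch below assumes the standard VP formulation dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw with a linear beta between the documented defaults; agedi's noise_schedule may differ. For this linear SDE the mean coefficient and variance must satisfy d mu/dt = f(mu, t) and d var/dt = 2 a(t) var + g(t)^2, where a(t) is the drift coefficient, and mu_t^2 + var_t = 1 is what makes the process variance-preserving for unit-variance data. VPSketch is a hypothetical name.

```python
import math


class VPSketch:
    """Hypothetical VP SDE: dx = -0.5 beta(t) x dt + sqrt(beta(t)) dw, linear beta assumed."""

    def __init__(self, beta_min: float = 0.01, beta_max: float = 3.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: float) -> float:
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def alpha(self, t: float) -> float:
        # Closed-form integral of the linear beta over [0, t].
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t ** 2

    def drift(self, x: float, t: float) -> float:
        return -0.5 * self.beta(t) * x

    def diffusion(self, t: float) -> float:
        return math.sqrt(self.beta(t))

    def mean(self, t: float) -> float:
        return math.exp(-0.5 * self.alpha(t))  # coefficient multiplying x_0

    def var(self, t: float) -> float:
        return 1.0 - math.exp(-self.alpha(t))


sde = VPSketch()
t, h = 0.4, 1e-5

# Central finite differences of the closed-form moments.
dmu = (sde.mean(t + h) - sde.mean(t - h)) / (2 * h)
dvar = (sde.var(t + h) - sde.var(t - h)) / (2 * h)

# Moment equations of a linear SDE; drift(1.0, t) is the coefficient a(t).
mu_residual = dmu - sde.drift(sde.mean(t), t)
var_residual = dvar - (2 * sde.drift(1.0, t) * sde.var(t) + sde.diffusion(t) ** 2)

# Variance preservation: unit-variance data keeps unit variance at every t.
total_var = sde.mean(t) ** 2 * 1.0 + sde.var(t)
print(mu_residual, var_residual, total_var)
```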
- beta(t: torch.Tensor) → torch.Tensor
VP beta function.
Calculates the value of beta at time t.
- Parameters:
t (torch.Tensor) – The time at which to calculate beta.
- Returns:
beta – The value of beta at time t.
- Return type:
torch.Tensor
- alpha(t: torch.Tensor) → torch.Tensor
VP alpha function.
Calculates the value of alpha at time t:
\[ \alpha(t) = \int_{0}^{t} \beta(s)\, ds. \]
- Parameters:
t (torch.Tensor) – The time at which to calculate alpha.
- Returns:
alpha – The value of alpha at time t.
- Return type:
torch.Tensor
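For a linear beta schedule (an assumption; this page does not state the schedule) the integral defining alpha has a closed form, alpha(t) = beta_min t + (beta_max - beta_min) t^2 / 2, so alpha(1) = (beta_min + beta_max) / 2 = 1.505 with the documented defaults. The sketch compares the closed form against Simpson's rule, which also handles non-linear schedules, and evaluates exp(-alpha/2) and 1 - exp(-alpha) at t = 1, the mean coefficient and variance under the standard VP formulation. All function names are hypothetical.

```python
import math
from typing import Callable


def linear_beta(t: float, beta_min: float = 0.01, beta_max: float = 3.0) -> float:
    """Assumed linear schedule between the documented defaults."""
    return beta_min + t * (beta_max - beta_min)


def alpha_closed_form(t: float, beta_min: float = 0.01, beta_max: float = 3.0) -> float:
    # Integral of beta_min + s * (beta_max - beta_min) over [0, t].
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2


def alpha_numeric(beta: Callable[[float], float], t: float, n: int = 200) -> float:
    """Composite Simpson rule for an arbitrary beta schedule (n must be even)."""
    h = t / n
    total = beta(0.0) + beta(t)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * beta(i * h)
    return total * h / 3


a1 = alpha_closed_form(1.0)  # equals (beta_min + beta_max) / 2
print(a1, alpha_numeric(linear_beta, 1.0))
print("mean coefficient at t=1:", math.exp(-0.5 * a1))
print("variance at t=1:", 1.0 - math.exp(-a1))
```

Under these assumptions exp(-alpha(1)/2) is about 0.47, so with the documented defaults a noticeable fraction of x_0 would remain at t = 1.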