agedi.diffusion.sdes.base
Classes
- SDE: SDE base class.
Module Contents
- class agedi.diffusion.sdes.base.SDE(noise_schedule: agedi.diffusion.sdes.noise_schedules.NoiseSchedule = Linear)
Bases: abc.ABC
SDE base class.
- noise_schedule_cls
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this SDE.
Returns a dictionary with a _target_ key holding the fully-qualified class name. Subclasses should call super().get_hparams() and merge in their own constructor parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
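As an illustration of this contract, here is a minimal sketch in which a subclass extends the parent's dictionary instead of replacing it. The classes SDEBase and ScaledSDE and the parameters beta_min and beta_max are hypothetical stand-ins, not agedi's actual code; only the _target_ key and the call to super().get_hparams() come from the description above.

```python
from abc import ABC
from typing import Any, Dict


class SDEBase(ABC):
    """Hypothetical stand-in for the SDE base class (not agedi's actual code)."""

    def get_hparams(self) -> Dict[str, Any]:
        # Record the fully-qualified class name under the `_target_` key.
        cls = type(self)
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}"}


class ScaledSDE(SDEBase):
    """Hypothetical subclass with two constructor parameters."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the parent's dictionary (which holds `_target_`),
        # then merge in this class's own constructor parameters.
        hparams = super().get_hparams()
        hparams.update(beta_min=self.beta_min, beta_max=self.beta_max)
        return hparams


hparams = ScaledSDE(beta_max=10.0).get_hparams()
```

Because each subclass extends rather than replaces the parent's dictionary, the _target_ entry always survives, and a loader can pop it to locate the class and pass the remaining keys to the constructor. The _target_ layout matches the convention that Hydra's hydra.utils.instantiate expects.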
- abstractmethod drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor
Drift term of the SDE.
Must be implemented by subclasses.
Defines the drift term of the SDE:
\[ f(x, t) = \dots \]
- Parameters:
x (torch.Tensor) – The positions of the atoms.
t (torch.Tensor) – The time at which to calculate the drift term.
- Returns:
drift – The drift term of the SDE.
- Return type:
torch.Tensor
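A concrete drift might look like the following minimal sketch. It assumes a variance-preserving (VP) SDE with a linear schedule beta(t) = beta_min + t (beta_max - beta_min); the endpoints 0.1 and 20.0, the helper beta, and the convention that t holds one time per atom are all assumptions for illustration, not agedi's implementation.

```python
import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t: torch.Tensor) -> torch.Tensor:
    # Linear schedule beta(t) = beta_min + t * (beta_max - beta_min) for t in [0, 1].
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # VP drift f(x, t) = -0.5 * beta(t) * x.
    # Assumes x has shape (n_atoms, 3) and t has shape (n_atoms,),
    # so t gets a trailing axis to broadcast over the coordinates.
    return -0.5 * beta(t).unsqueeze(-1) * x


x = torch.randn(4, 3)      # positions of 4 atoms
t = torch.full((4,), 0.5)  # same time for every atom
f = drift(x, t)            # shape (4, 3)
```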
- abstractmethod diffusion(t: torch.Tensor) → torch.Tensor
Diffusion term of the SDE.
Must be implemented by subclasses.
Defines the diffusion term of the SDE:
\[ g(t) = \dots \]
- Parameters:
t (torch.Tensor) – The time at which to calculate the diffusion term.
- Returns:
diffusion – The diffusion term of the SDE.
- Return type:
torch.Tensor
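The drift and diffusion terms combine in the forward SDE dx = f(x, t) dt + g(t) dw. For an assumed variance-preserving (VP) SDE with a linear beta(t) (hypothetical endpoints 0.1 and 20.0), g(t) = sqrt(beta(t)), and a single Euler-Maruyama step of the forward process can be sketched as below. The helpers beta and euler_maruyama_step are illustrative, not part of the documented class.

```python
import math

import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t: torch.Tensor) -> torch.Tensor:
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def drift(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # VP drift f(x, t) = -0.5 * beta(t) * x; t has shape (n_atoms,), x (n_atoms, 3).
    return -0.5 * beta(t).unsqueeze(-1) * x


def diffusion(t: torch.Tensor) -> torch.Tensor:
    # VP diffusion coefficient g(t) = sqrt(beta(t)); depends on t only.
    return torch.sqrt(beta(t))


def euler_maruyama_step(x: torch.Tensor, t: torch.Tensor, dt: float) -> torch.Tensor:
    # One forward step of dx = f(x, t) dt + g(t) dW, with dW ~ N(0, dt).
    dw = math.sqrt(dt) * torch.randn_like(x)
    return x + drift(x, t) * dt + diffusion(t).unsqueeze(-1) * dw


x = torch.randn(4, 3)
x_next = euler_maruyama_step(x, torch.zeros(4), dt=1e-3)
```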
- abstractmethod mean(t: torch.Tensor) → torch.Tensor
Mean of the SDE.
Must be implemented by subclasses.
Calculates the mean of the transition kernel at time t:
\[ \mu_t = \dots \]
- Parameters:
t (torch.Tensor) – The time at which to calculate the mean.
- Returns:
mean – The mean of the diffusion process.
- Return type:
torch.Tensor
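Since mean takes only t, it presumably returns the coefficient mu_t that multiplies x_0 rather than a mean position. For an assumed variance-preserving (VP) SDE with a linear beta(t) (hypothetical endpoints 0.1 and 20.0), mu_t = exp(-1/2 * integral from 0 to t of beta(s) ds) has a closed form, sketched below; integrated_beta is an illustrative helper.

```python
import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t: torch.Tensor) -> torch.Tensor:
    # Closed form of int_0^t beta(s) ds for beta(s) = beta_min + s * (beta_max - beta_min).
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def mean(t: torch.Tensor) -> torch.Tensor:
    # VP mean coefficient mu_t = exp(-0.5 * int_0^t beta(s) ds).
    # The kernel mean itself is mu_t * x_0.
    return torch.exp(-0.5 * integrated_beta(t))


mu = mean(torch.tensor([0.0, 0.5, 1.0]))  # decays from 1 toward 0
```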
- abstractmethod var(t: torch.Tensor) → torch.Tensor
Variance of the SDE.
Must be implemented by subclasses.
Calculates the variance of the transition kernel at time t:
\[ \sigma_t^2 = \dots \]
- Parameters:
t (torch.Tensor) – The time at which to calculate the variance.
- Returns:
var – The variance of the diffusion process.
- Return type:
torch.Tensor
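For an assumed variance-preserving (VP) SDE with a linear beta(t) (hypothetical endpoints 0.1 and 20.0), sigma_t^2 = 1 - exp(-integral from 0 to t of beta(s) ds), so mu_t^2 + sigma_t^2 = 1 at every t, which is what "variance preserving" refers to. The minimal sketch below uses torch.expm1 to stay accurate for small t; integrated_beta is an illustrative helper.

```python
import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t: torch.Tensor) -> torch.Tensor:
    # Closed form of int_0^t beta(s) ds for the linear schedule.
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def var(t: torch.Tensor) -> torch.Tensor:
    # VP variance sigma_t^2 = 1 - exp(-int_0^t beta(s) ds).
    # -expm1(-a) equals 1 - exp(-a) but keeps precision when a is tiny.
    return -torch.expm1(-integrated_beta(t))


t = torch.tensor([0.0, 0.5, 1.0])
sigma2 = var(t)  # grows from 0 toward 1
```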
- transition_kernel(x: torch.Tensor, t: torch.Tensor, w: Callable) → torch.Tensor
Transition kernel of the SDE.
Draws \(\mathbf{x}_t\) from the transition kernel \(p(\mathbf{x}_t \mid \mathbf{x}_0) = N(\mu_t \mathbf{x}_0, \sigma_t^2 I)\) of the diffusion process by computing
\[ \mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w}, \]
with \(\mathbf{w} \sim N(0,1)\).
- Parameters:
x (torch.Tensor) – The positions of the atoms, \(\mathbf{x}_0\).
t (torch.Tensor) – The time at which to calculate the transition kernel.
w (torch.Tensor) – The noise term.
- Returns:
transition_kernel – The transition kernel of the diffusion process.
- Return type:
torch.Tensor
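The sampling formula can be sketched as follows, assuming a variance-preserving (VP) SDE with a linear beta(t) (hypothetical endpoints 0.1 and 20.0) for mu_t and sigma_t^2. Following the parameter description rather than the Callable annotation, this sketch treats w as a tensor of standard normal draws with the same shape as x; the helpers are illustrative, not agedi's code.

```python
import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t: torch.Tensor) -> torch.Tensor:
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def mean(t: torch.Tensor) -> torch.Tensor:
    return torch.exp(-0.5 * integrated_beta(t))


def var(t: torch.Tensor) -> torch.Tensor:
    return -torch.expm1(-integrated_beta(t))


def transition_kernel(x: torch.Tensor, t: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x_t = mu_t * x_0 + sigma_t * w, with sigma_t = sqrt(var(t)).
    # t has shape (n_atoms,); a trailing axis lets it broadcast over coordinates.
    mu = mean(t).unsqueeze(-1)
    sigma = var(t).sqrt().unsqueeze(-1)
    return mu * x + sigma * w


x0 = torch.randn(4, 3)
t = torch.rand(4)         # one diffusion time per atom
w = torch.randn_like(x0)  # w ~ N(0, 1)
xt = transition_kernel(x0, t, w)
```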
- noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) → torch.Tensor
Noise term of the SDE.
Calculates the noise term of the SDE by inverting the transition kernel:
\[ \mathbf{w} = \frac{\mathbf{x}_t - \mu_t \mathbf{x}_0}{\sigma_t} \]
- Parameters:
x0 (torch.Tensor) – x at time 0.
xt (torch.Tensor) – x at time t.
t (torch.Tensor) – The time at which to calculate the noise term.
- Returns:
noise – The noise term of the diffusion process.
- Return type:
torch.Tensor
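Because noise inverts the transition kernel, applying it to a sample drawn with a known w should recover w up to floating-point error. The sketch below checks this round trip under an assumed variance-preserving (VP) SDE with a linear beta(t) (hypothetical endpoints 0.1 and 20.0); the helpers are illustrative, and the division is undefined at t = 0, where sigma_t = 0.

```python
import torch

# Hypothetical linear-schedule endpoints (assumed, not agedi defaults).
BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t: torch.Tensor) -> torch.Tensor:
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def mean(t: torch.Tensor) -> torch.Tensor:
    return torch.exp(-0.5 * integrated_beta(t))


def var(t: torch.Tensor) -> torch.Tensor:
    return -torch.expm1(-integrated_beta(t))


def noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Invert x_t = mu_t * x_0 + sigma_t * w:  w = (x_t - mu_t * x_0) / sigma_t.
    mu = mean(t).unsqueeze(-1)
    sigma = var(t).sqrt().unsqueeze(-1)
    return (xt - mu * x0) / sigma


x0 = torch.randn(4, 3)
w = torch.randn_like(x0)
t = torch.full((4,), 0.5)  # t > 0, so sigma_t > 0
xt = mean(t).unsqueeze(-1) * x0 + var(t).sqrt().unsqueeze(-1) * w
w_rec = noise(x0, xt, t)   # should match w
```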