agedi.diffusion.noisers.sde
Classes
SDENoiser – Implements an SDE base class that can be inherited by other classes.
Module Contents
- class agedi.diffusion.noisers.sde.SDENoiser(sde_class: agedi.diffusion.sdes.SDE, sde_kwargs: Dict | None, distribution: agedi.diffusion.distributions.Distribution, prior: agedi.diffusion.distributions.Distribution, sde: agedi.diffusion.sdes.SDE | None = None, **kwargs)
Bases: agedi.diffusion.noisers.Noiser, abc.ABC
Implements an SDE base class that can be inherited by other classes.
- Parameters:
sde_class (SDE) – The class of the SDE to be used for the noising.
sde_kwargs (Dict) – The keyword arguments to be passed to the SDE class.
distribution (Distribution) – The distribution to be used for the noise.
prior (Distribution) – The prior distribution to be used for the noise.
sde (SDE, optional) – An already-instantiated SDE object. When provided, sde_class and sde_kwargs are ignored.
key (str) – The key to be used for the noising.
**kwargs – Additional keyword arguments to be passed to the Noiser class.
- Returns:
The noiser for the atomic positions in Cartesian coordinates.
- Return type:
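The precedence between sde and sde_class/sde_kwargs can be illustrated with a minimal sketch. It assumes the constructor simply uses an explicit sde when one is given and otherwise instantiates sde_class with sde_kwargs, treating None as "no keyword arguments". resolve_sde and ToySDE are hypothetical helpers for illustration, not part of agedi.

```python
from typing import Any, Dict, Optional, Type


class ToySDE:
    """Hypothetical stand-in for an agedi SDE class (not the real API)."""

    def __init__(self, sigma_max: float = 1.0) -> None:
        self.sigma_max = sigma_max


def resolve_sde(
    sde_class: Type[Any],
    sde_kwargs: Optional[Dict[str, Any]],
    sde: Optional[Any] = None,
) -> Any:
    """Return the SDE instance a noiser would use (assumed precedence rule).

    An already-instantiated ``sde`` wins; otherwise ``sde_class`` is
    instantiated with ``sde_kwargs`` (``None`` is treated as no kwargs).
    """
    if sde is not None:
        return sde
    return sde_class(**(sde_kwargs or {}))


# Built from the class and its keyword arguments.
built = resolve_sde(ToySDE, {"sigma_max": 2.0})

# An explicit instance takes precedence; sde_class/sde_kwargs are ignored.
given = ToySDE(sigma_max=5.0)
chosen = resolve_sde(ToySDE, {"sigma_max": 2.0}, sde=given)
```

Passing the instance directly is convenient when several noisers should share one SDE object.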
- _key = None
- get_hparams() → Dict
Return hyperparameters for this SDE noiser.
- abstractmethod postprocess_score(score: torch.Tensor) → torch.Tensor
Post-process the predicted score before computing the loss.
- Parameters:
score (torch.Tensor) – Raw predicted score tensor.
- Returns:
Post-processed score tensor.
- Return type:
torch.Tensor
- abstractmethod postprocess_noise(noise: torch.Tensor) → torch.Tensor
Post-process the noise tensor before computing the loss.
- Parameters:
noise (torch.Tensor) – Raw noise tensor.
- Returns:
Post-processed noise tensor.
- Return type:
torch.Tensor
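Because both post-processing hooks are abstract, a subclass must implement them before it can be instantiated. The sketch below mirrors that contract with Python's abc module, using NumPy arrays instead of torch.Tensor. PostprocessingNoiser, IdentityNoiser and ZeroMeanNoiser are hypothetical names, and removing the mean over atoms is only one plausible post-processing choice, not necessarily what agedi's subclasses do.

```python
import abc

import numpy as np


class PostprocessingNoiser(abc.ABC):
    """Hypothetical mini-base class mirroring the two abstract hooks."""

    @abc.abstractmethod
    def postprocess_score(self, score: np.ndarray) -> np.ndarray:
        """Post-process the predicted score before computing the loss."""

    @abc.abstractmethod
    def postprocess_noise(self, noise: np.ndarray) -> np.ndarray:
        """Post-process the noise before computing the loss."""


class IdentityNoiser(PostprocessingNoiser):
    """Leaves both arrays unchanged."""

    def postprocess_score(self, score: np.ndarray) -> np.ndarray:
        return score

    def postprocess_noise(self, noise: np.ndarray) -> np.ndarray:
        return noise


class ZeroMeanNoiser(PostprocessingNoiser):
    """Subtracts the mean over atoms (rows); assumes a single structure."""

    def postprocess_score(self, score: np.ndarray) -> np.ndarray:
        return score - score.mean(axis=0, keepdims=True)

    def postprocess_noise(self, noise: np.ndarray) -> np.ndarray:
        return noise - noise.mean(axis=0, keepdims=True)


x = np.array([[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
centered = ZeroMeanNoiser().postprocess_noise(x)  # rows become -1 and +1 along x
```

Applying the same transform to score and noise keeps the two terms of the loss comparable.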
- _noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Adds noise to the atomistic structure.
The added noise is stored under self.key + "_noise".
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be noised.
- Returns:
The noised atomistic structure (or batch thereof).
- Return type:
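The noising step can be sketched for a single structure, assuming a variance-exploding SDE whose marginal is x_t = x_0 + sigma(t) w with w ~ N(0, I), and a plain dict standing in for AtomsGraph. The function names, the "positions" key and the geometric schedule are assumptions. The docstring does not say whether w or sigma(t) w is stored under key + "_noise"; this sketch stores w.

```python
from typing import Callable, Dict

import numpy as np


def add_noise(
    batch: Dict[str, np.ndarray],
    key: str,
    t: float,
    sigma: Callable[[float], float],
    rng: np.random.Generator,
) -> Dict[str, np.ndarray]:
    """Perturb batch[key] with Gaussian noise (assumed VE-SDE marginal).

    x_t = x_0 + sigma(t) * w with w ~ N(0, I); w is stored under key + "_noise".
    """
    x0 = batch[key]
    w = rng.standard_normal(x0.shape)  # standard-normal sample, same shape as x0
    out = dict(batch)  # shallow copy so the caller's dict is not mutated
    out[key] = x0 + sigma(t) * w
    out[key + "_noise"] = w
    return out


def geometric_sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    """Hypothetical geometric schedule: sigma(0) = sigma_min, sigma(1) = sigma_max."""
    return sigma_min * (sigma_max / sigma_min) ** t


rng = np.random.default_rng(0)
batch = {"positions": np.zeros((4, 3))}
noised = add_noise(batch, "positions", t=1.0, sigma=geometric_sigma, rng=rng)
```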
- _denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph
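Denoises the positions of the atomistic structure.
The denoising follows the Euler-Maruyama scheme:
$$ R_{i+1} = R_i + \Delta t \left( f(R_i, t) + g(t)^2 \, s(R_i, t) \right) + \sqrt{\Delta t} \, g(t) \, w $$
where f is the drift, g the diffusion coefficient, s the predicted score and $w \sim \mathcal{N}(0, I)$.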
The score used is expected to be stored under self.key + "_score".
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be denoised.
delta_t (float) – The time step for the denoising.
last (bool) – Whether this is the last step of the denoising.
- Returns:
The denoised atomistic structure (or batch thereof).
- Return type:
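A single Euler-Maruyama update, written term by term from the equation above, might look like the following NumPy sketch. The callables f and g, a positive delta_t, and dropping the noise term when last is True are assumptions made for illustration; euler_maruyama_step is a hypothetical name.

```python
from typing import Callable

import numpy as np


def euler_maruyama_step(
    r: np.ndarray,
    t: float,
    delta_t: float,
    f: Callable[[np.ndarray, float], np.ndarray],
    g: Callable[[float], float],
    score: np.ndarray,
    rng: np.random.Generator,
    last: bool = False,
) -> np.ndarray:
    """One update R_{i+1} = R_i + dt (f + g^2 s) + sqrt(dt) g w.

    ``score`` is s(R_i, t), e.g. read from batch[key + "_score"].
    Assumption: on the last step the stochastic term is dropped.
    """
    drift = delta_t * (f(r, t) + g(t) ** 2 * score)
    if last:
        return r + drift
    w = rng.standard_normal(r.shape)  # w ~ N(0, I)
    return r + drift + np.sqrt(delta_t) * g(t) * w


# Toy check: zero drift f, constant g = 1, and the score of N(0, I), s(R) = -R.
rng = np.random.default_rng(0)
r = np.ones((2, 3))
r_next = euler_maruyama_step(
    r, t=0.5, delta_t=0.1,
    f=lambda x, t: np.zeros_like(x),
    g=lambda t: 1.0,
    score=-r,
    rng=rng,
    last=True,
)  # every entry is 1 + 0.1 * (-1) = 0.9
```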
- _loss(batch: agedi.data.AtomsGraph) → torch.Tensor
Compute the loss of the diffusion model's SDE noiser.
Expects the total added noise to be stored under self.key + "_noise" and the predicted score under self.key + "_score".
The loss is computed as
$$ L = \sum_i \left\| \sigma_t w_i + \sigma_t^2 \, s(R_i) \right\|^2 $$
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be noised and denoised.
- Returns:
The loss of the noised and denoised atomistic structure.
- Return type:
torch.Tensor
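The loss formula translates directly into a few lines. This NumPy sketch assumes noise holds the standard-normal samples w_i and score holds s(R_i), one row per atom; sde_noiser_loss is a hypothetical name, and it returns a Python float where the method returns a torch.Tensor.

```python
import numpy as np


def sde_noiser_loss(noise: np.ndarray, score: np.ndarray, sigma_t: float) -> float:
    """L = sum_i || sigma_t * w_i + sigma_t**2 * s(R_i) ||^2.

    noise: w, shape (n_atoms, 3), e.g. read from key + "_noise".
    score: s(R), same shape, e.g. read from key + "_score".
    """
    residual = sigma_t * noise + sigma_t**2 * score
    return float(np.sum(residual**2))  # squared norm per atom, summed over atoms


w = np.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]])
perfect = sde_noiser_loss(w, -w / 0.5, sigma_t=0.5)  # exact VE score -w / sigma_t
untrained = sde_noiser_loss(w, np.zeros_like(w), sigma_t=0.5)  # = 0.25 * sum(w**2)
```

For a variance-exploding marginal x_t = x_0 + sigma_t w, the exact conditional score is -w / sigma_t, which makes every residual, and hence the loss, zero. This is consistent with pairing sigma_t w_i with sigma_t^2 s(R_i) in the formula.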