agedi.diffusion.noisers.sde

Classes

SDENoiser

Implements an SDE base class that other noisers can inherit from.

Module Contents

class agedi.diffusion.noisers.sde.SDENoiser(sde_class: agedi.diffusion.sdes.SDE, sde_kwargs: Dict | None, distribution: agedi.diffusion.distributions.Distribution, prior: agedi.diffusion.distributions.Distribution, sde: agedi.diffusion.sdes.SDE | None = None, **kwargs)

Bases: agedi.diffusion.noisers.Noiser, abc.ABC

Implements an SDE base class that other noisers can inherit from.

Parameters:
  • sde_class (SDE) – The class of the SDE to be used for the noising.

  • sde_kwargs (Dict) – The keyword arguments to be passed to the SDE class.

  • distribution (Distribution) – The distribution to be used for the noise.

  • prior (Distribution) – The prior distribution to be used for the noise.

  • sde (SDE, optional) – An already-instantiated SDE object. When provided, sde_class and sde_kwargs are ignored.

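  • key (str) – The key of the batch quantity to be noised. The added noise and the predicted score are stored under key + "_noise" and key + "_score", respectively.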

  • **kwargs – Additional keyword arguments to be passed to the Noiser class.

Returns:

The noiser for the atoms' positions in Cartesian coordinates.

Return type:

Noiser
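The precedence between sde and sde_class/sde_kwargs described above can be sketched as follows. This is a minimal illustration, not agedi's code: resolve_sde and ToySDE are hypothetical stand-ins, and treating sde_kwargs=None as "no keyword arguments" is an assumption based on the Dict | None annotation.

```python
from typing import Any, Callable, Dict, Optional


def resolve_sde(
    sde_class: Callable[..., Any],
    sde_kwargs: Optional[Dict[str, Any]],
    sde: Optional[Any] = None,
) -> Any:
    """Return the SDE object a noiser would use (hypothetical helper).

    An already-instantiated ``sde`` takes precedence; otherwise the SDE is
    built from ``sde_class`` and ``sde_kwargs`` (None is assumed to mean
    "no keyword arguments").
    """
    if sde is not None:
        return sde
    return sde_class(**(sde_kwargs or {}))


class ToySDE:
    """Hypothetical stand-in for agedi.diffusion.sdes.SDE."""

    def __init__(self, sigma_max: float = 1.0) -> None:
        self.sigma_max = sigma_max


built = resolve_sde(ToySDE, {"sigma_max": 5.0})   # built from class + kwargs
given = ToySDE(sigma_max=2.0)
reused = resolve_sde(ToySDE, {"sigma_max": 5.0}, sde=given)  # kwargs ignored
print(built.sigma_max, reused is given)  # 5.0 True
```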

_key = None
get_hparams() → Dict

Return hyperparameters for this SDE noiser.

abstractmethod postprocess_score(score: torch.Tensor) → torch.Tensor

Post-process the predicted score before computing the loss.

Parameters:

score (torch.Tensor) – Raw predicted score tensor.

Returns:

Post-processed score tensor.

Return type:

torch.Tensor

abstractmethod postprocess_noise(noise: torch.Tensor) → torch.Tensor

Post-process the noise tensor before computing the loss.

Parameters:

noise (torch.Tensor) – Raw noise tensor.

Returns:

Post-processed noise tensor.

Return type:

torch.Tensor
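Because postprocess_score and postprocess_noise are abstract, a concrete noiser must implement both before it can be instantiated. The sketch below illustrates that contract with Python's abc module, using plain lists in place of torch.Tensor so it runs without PyTorch. ToyNoiserBase and CenteredNoiser are hypothetical, and mean removal is only an illustrative choice of post-processing, not the one agedi uses.

```python
import abc
from typing import List


class ToyNoiserBase(abc.ABC):
    """Hypothetical stand-in mirroring only SDENoiser's two abstract hooks."""

    @abc.abstractmethod
    def postprocess_score(self, score: List[float]) -> List[float]:
        ...

    @abc.abstractmethod
    def postprocess_noise(self, noise: List[float]) -> List[float]:
        ...


class CenteredNoiser(ToyNoiserBase):
    """Example subclass that removes the mean from both inputs (illustrative)."""

    @staticmethod
    def _center(values: List[float]) -> List[float]:
        mean = sum(values) / len(values)
        return [v - mean for v in values]

    def postprocess_score(self, score: List[float]) -> List[float]:
        return self._center(score)

    def postprocess_noise(self, noise: List[float]) -> List[float]:
        return self._center(noise)


try:
    ToyNoiserBase()  # fails: abstract methods are not implemented
except TypeError:
    print("cannot instantiate abstract base")

noiser = CenteredNoiser()
print(noiser.postprocess_noise([1.0, 2.0, 3.0]))  # [-1.0, 0.0, 1.0]
```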

_noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph

Adds noise to the atomistic structure.

The added noise is stored in the batch under self.key + "_noise".

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised.

Returns:

The noised atomistic structure (or a batch thereof).

Return type:

AtomsGraph
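The noising step and the key + "_noise" storage convention can be sketched as below. Several assumptions are made and labeled in the code: a plain dict stands in for AtomsGraph, the noise scale follows a hypothetical variance-exploding schedule sigma(t) = sigma_min (sigma_max / sigma_min)^t, and the stored noise is the unit-variance sample w (so that the sigma_t w_i term of the loss applies). Whether agedi stores w or sigma_t w is not stated here.

```python
import random
from typing import Dict, List


def sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 1.0) -> float:
    """Hypothetical variance-exploding noise scale, geometric in t on [0, 1]."""
    return sigma_min * (sigma_max / sigma_min) ** t


def noise_batch(
    batch: Dict[str, List[float]], key: str, t: float, rng: random.Random
) -> Dict[str, List[float]]:
    """Noise batch[key] in place and store the added noise under key + "_noise".

    A plain dict stands in for AtomsGraph (assumption).
    """
    x0 = batch[key]
    s = sigma(t)
    # Unit-variance Gaussian sample w, one entry per coordinate.
    w = [rng.gauss(0.0, 1.0) for _ in x0]
    # Perturbed coordinates x_t = x_0 + sigma(t) * w.
    batch[key] = [x + s * wi for x, wi in zip(x0, w)]
    batch[key + "_noise"] = w
    return batch


rng = random.Random(0)
batch = {"positions": [0.0, 1.0, 2.0]}
noise_batch(batch, "positions", t=0.5, rng=rng)
print(sorted(batch))  # ['positions', 'positions_noise']
```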

_denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph

Denoises the positions of the atomistic structure.

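The denoising follows the Euler-Maruyama scheme:

  R_{i+1} = R_i + \Delta t \left( f(R_i, t) + g(t)^2 s(R_i, t) \right) + \sqrt{\Delta t} \, g(t) \, w

where f is the drift and g the diffusion coefficient of the SDE, s is the predicted score, and w is a standard Gaussian sample.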

The score is expected to be stored in the batch under self.key + "_score".

Parameters:
  • batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be denoised.

  • delta_t (float) – The time step for the denoising.

  • last (bool) – Whether this is the last step of the denoising.

Returns:

The denoised atomistic structure (or a batch thereof).

Return type:

AtomsGraph
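A single step of the Euler-Maruyama update can be sketched for the special case of zero drift (f = 0, as in a variance-exploding SDE), where it reduces to R_{i+1} = R_i + Δt g(t)² s + sqrt(Δt) g(t) w with a positive Δt. Plain lists replace tensors, euler_maruyama_step is a hypothetical helper, and dropping the stochastic term when last is True is an assumption about what the last flag controls.

```python
import math
import random
from typing import List


def euler_maruyama_step(
    x: List[float],
    score: List[float],
    g: float,
    delta_t: float,
    rng: random.Random,
    last: bool = False,
) -> List[float]:
    """One step R_{i+1} = R_i + dt * g^2 * s + sqrt(dt) * g * w, assuming f = 0.

    Skipping the noise term on the last step is an assumption.
    """
    # Deterministic score term: delta_t * g(t)^2 * s(R_i, t).
    drift = [delta_t * g ** 2 * s for s in score]
    if last:
        return [xi + d for xi, d in zip(x, drift)]
    # Stochastic term: sqrt(delta_t) * g(t) * w with w ~ N(0, 1).
    noise = [math.sqrt(delta_t) * g * rng.gauss(0.0, 1.0) for _ in x]
    return [xi + d + n for xi, d, n in zip(x, drift, noise)]


rng = random.Random(0)
x = [1.0, -2.0]
score = [-1.0, 2.0]
print(euler_maruyama_step(x, score, g=1.0, delta_t=0.5, rng=rng, last=True))  # [0.5, -1.0]
```

The deterministic last step makes the effect of the score term easy to check by hand; every earlier step adds fresh Gaussian noise scaled by sqrt(Δt) g(t).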

_loss(batch: agedi.data.AtomsGraph) → torch.Tensor

Compute the noiser loss.

Computes the loss of the SDE noiser of the diffusion model.

Expects the total added noise to be stored under self.key + "_noise" and the predicted score under self.key + "_score".

The loss is computed as

  L = \sum_i \lVert \sigma_t w_i + \sigma_t^2 s(R_i) \rVert^2

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised and denoised.

Returns:

The loss of the noised and denoised atomistic structure.

Return type:

torch.Tensor
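The loss formula can be evaluated directly, as in the minimal sketch below. It assumes a single noise scale sigma_t shared by all rows (in a batch whose structures are at different times, sigma_t would vary per structure), uses nested lists instead of tensors, and sde_noiser_loss is a hypothetical name. For x_t = x_0 + sigma_t w, the exact score of the perturbation is -w / sigma_t, and that choice of s makes the loss vanish.

```python
from typing import Sequence


def sde_noiser_loss(
    noise: Sequence[Sequence[float]],
    score: Sequence[Sequence[float]],
    sigma_t: float,
) -> float:
    """L = sum_i || sigma_t * w_i + sigma_t**2 * s_i ||^2, summed over rows i."""
    total = 0.0
    for w_i, s_i in zip(noise, score):
        # Squared Euclidean norm of the residual for row (atom) i.
        total += sum((sigma_t * w + sigma_t ** 2 * s) ** 2 for w, s in zip(w_i, s_i))
    return total


print(sde_noiser_loss([[1.0, 2.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], 1.0))  # 6.0
# A score equal to -w / sigma_t makes every residual vanish.
print(sde_noiser_loss([[1.0, -2.0]], [[-2.0, 4.0]], 0.5))  # 0.0
```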