agedi.diffusion.noisers.base

Classes

Noiser

Noiser Base class

Module Contents

class agedi.diffusion.noisers.base.Noiser(distribution: agedi.diffusion.distributions.Distribution, prior: agedi.diffusion.distributions.Distribution, loss_scaling: float = 1.0, key: str | None = None, **kwargs)

Bases: abc.ABC, torch.nn.Module

Noiser Base class

Implements a noiser that can noise and denoise an attribute of an atomistic structure.

Parameters:
  • distribution (Distribution) – The distribution to be used for the noising.

  • prior (Distribution) – The prior to be used for the denoising.

  • loss_scaling (float) – Scaling factor applied to this noiser’s loss contribution.

  • key (str, optional) – Override the class-level _key for the attribute to be noised and denoised. Useful for reusing a noiser class on a different attribute without subclassing purely to change _key.
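The interplay between the class-level _key and the key constructor argument can be sketched with a stand-in class. This is a minimal sketch that assumes the override is stored on the instance and that the key property falls back to _key when no override is given. SketchNoiser and PositionsNoiser are hypothetical names, not part of agedi.

```python
from typing import Optional


class SketchNoiser:
    """Hypothetical stand-in for Noiser, showing only the key-override logic."""

    _key: str = "x"  # class-level default attribute name

    def __init__(self, key: Optional[str] = None) -> None:
        # Assumption: the override lives on the instance; the class-level _key is untouched.
        self._key_override = key

    @property
    def key(self) -> str:
        # Fall back to the class-level _key when no override was passed.
        return self._key_override if self._key_override is not None else self._key


class PositionsNoiser(SketchNoiser):
    _key = "positions"


default = PositionsNoiser()           # uses the class-level _key
reused = PositionsNoiser(key="cell")  # same class, different attribute, no subclassing
```

Keeping the override on the instance means one subclass can serve several attributes at once without any instance affecting the others.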

Return type:

Noiser

_key: str
_registry: ClassVar[Dict[str, Callable[..., Noiser]]]
distribution
prior
loss_scaling = 1.0
classmethod register(name: str, factory: Callable[..., Noiser]) → None

Register a noiser factory callable under name.

The factory is called with sde as a keyword argument containing the resolved SDE instance. Noisers that do not use an SDE can safely ignore it via **kwargs.

Parameters:
  • name (str) – Alias string used to reference the noiser in create_diffusion().

  • factory (Callable) – A callable that accepts sde as a keyword argument and returns a Noiser instance.

Examples

Register a custom noiser so it can be referenced by its alias:

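from agedi.diffusion.noisers import Noiser

class MyNoiser(Noiser):
    # Implement _noise, _denoise and _loss here, and accept the
    # sde keyword argument that the factory passes in.
    ...

# The factory receives the resolved SDE as the sde keyword argument.
Noiser.register("my_noiser", lambda sde: MyNoiser(sde=sde))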
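The registry mechanism can be sketched without agedi. This minimal sketch assumes that resolving an alias simply looks up the factory in _registry and calls it with sde= as a keyword argument. build is a hypothetical helper standing in for whatever create_diffusion() does internally, and the SDE is replaced by a plain string.

```python
from typing import Any, Callable, ClassVar, Dict


class SketchNoiser:
    """Hypothetical stand-in for Noiser with an alias registry."""

    _registry: ClassVar[Dict[str, Callable[..., "SketchNoiser"]]] = {}

    @classmethod
    def register(cls, name: str, factory: Callable[..., "SketchNoiser"]) -> None:
        cls._registry[name] = factory


class UsesSDE(SketchNoiser):
    def __init__(self, sde: Any) -> None:
        self.sde = sde


class NoSDE(SketchNoiser):
    def __init__(self, **kwargs: Any) -> None:
        # This noiser has no use for an SDE; **kwargs silently absorbs the sde keyword.
        self.kwargs = kwargs


def build(alias: str, sde: Any) -> SketchNoiser:
    # Hypothetical resolution step: look up the factory and pass sde by keyword.
    return SketchNoiser._registry[alias](sde=sde)


SketchNoiser.register("uses_sde", lambda sde: UsesSDE(sde=sde))
SketchNoiser.register("no_sde", NoSDE)  # a class is itself a valid factory

a = build("uses_sde", sde="vp-sde")
b = build("no_sde", sde="vp-sde")
```

Passing sde by keyword is what lets a factory that ignores it (here NoSDE) coexist with one that needs it.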
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this noiser.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus distribution, prior, and loss_scaling entries taken from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
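The subclass convention for get_hparams() (call super().get_hparams(), then merge in your own constructor arguments) can be sketched as follows. The stand-in base class, the ScaledNoiser subclass and its sigma_max parameter are hypothetical. The distributions are plain strings here rather than Distribution objects, because how agedi serializes them is not specified.

```python
from typing import Any, Dict


class SketchNoiser:
    def __init__(self, distribution: Any, prior: Any, loss_scaling: float = 1.0) -> None:
        self.distribution = distribution
        self.prior = prior
        self.loss_scaling = loss_scaling

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        return {
            # Fully-qualified class name of the concrete subclass.
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "distribution": self.distribution,
            "prior": self.prior,
            "loss_scaling": self.loss_scaling,
        }


class ScaledNoiser(SketchNoiser):
    def __init__(self, distribution: Any, prior: Any, loss_scaling: float = 1.0,
                 sigma_max: float = 10.0) -> None:
        super().__init__(distribution, prior, loss_scaling)
        self.sigma_max = sigma_max

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class entries, then add this subclass's own argument.
        hparams = super().get_hparams()
        hparams["sigma_max"] = self.sigma_max
        return hparams


hp = ScaledNoiser("normal", "normal", loss_scaling=0.5, sigma_max=5.0).get_hparams()
```

Because type(self) is evaluated at call time, the inherited _target_ entry names the subclass rather than the base class.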

property key: str

The key of the attribute to be noised and denoised.

abstractmethod _noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph

Noises the attribute of the atomistic structure.

Must be implemented by the subclass.

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised.

Returns:

The noised atomistic structure (or a batch thereof).

Return type:

AtomsGraph
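As an illustration of what a subclass's _noise might do, here is a minimal sketch of variance-exploding Gaussian perturbation, x_t = x_0 + σ(t)·ε with ε ~ N(0, I). Everything here is an assumption for illustration: the batch is a plain dict instead of an AtomsGraph, the attribute is a list of floats, and the names sigma and noise_attribute, the "_noise" suffix and the "t" entry are hypothetical.

```python
import random
from typing import Any, Dict, List


def sigma(t: float, sigma_min: float = 0.01, sigma_max: float = 10.0) -> float:
    # Geometric noise scale of a variance-exploding SDE (assumed schedule).
    return sigma_min * (sigma_max / sigma_min) ** t


def noise_attribute(batch: Dict[str, Any], key: str, t: float,
                    rng: random.Random) -> Dict[str, Any]:
    """Hypothetical _noise: perturb batch[key] as x_t = x_0 + sigma(t) * eps."""
    x0: List[float] = batch[key]
    s = sigma(t)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    noised = dict(batch)  # shallow copy, so the input batch is left untouched
    noised[key] = [xi + s * ei for xi, ei in zip(x0, eps)]
    noised[key + "_noise"] = eps  # kept so a loss can build the target score
    noised["t"] = t
    return noised
```

Storing ε alongside the noised attribute is one way to give _loss() what it needs to form a denoising target.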

abstractmethod _denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph

Denoises the attribute of the atomistic structure.

Must be implemented by the subclass.

Parameters:
  • batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be denoised.

  • delta_t (float) – The time step to be used for the denoising.

  • last (bool) – Whether this is the final step of the denoising process.

Returns:

The denoised atomistic structure (or a batch thereof).

Return type:

AtomsGraph
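The delta_t and last arguments suggest a discretized reverse-time SDE. The sketch below shows one common choice, an Euler-Maruyama step of the reverse-time variance-exploding SDE, x ← x + g(t)²·s·Δt + g(t)·√Δt·z, where s is the stored score and the noise term is dropped when last is true. This is an assumption about what a subclass might do, not agedi's implementation. The dict batch, the "t" entry and the helper names are hypothetical; the key + "_score" convention is taken from langevin_step().

```python
import math
import random
from typing import Any, Dict, List


def g(t: float, sigma_min: float = 0.01, sigma_max: float = 10.0) -> float:
    # Diffusion coefficient for sigma(t) = sigma_min * (sigma_max / sigma_min) ** t:
    # g(t) = sqrt(d sigma(t)^2 / dt) = sigma(t) * sqrt(2 * ln(sigma_max / sigma_min)).
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t
    return sigma_t * math.sqrt(2.0 * math.log(sigma_max / sigma_min))


def denoise_attribute(batch: Dict[str, Any], key: str, delta_t: float, last: bool,
                      rng: random.Random) -> Dict[str, Any]:
    """Hypothetical _denoise: one reverse-time Euler-Maruyama step from t to t - delta_t."""
    t = batch["t"]
    score: List[float] = batch[key + "_score"]  # filled by the score model beforehand
    g_t = g(t)
    new_x = []
    for xi, si in zip(batch[key], score):
        # Drift: the VE SDE has zero forward drift, so only g^2 * score remains.
        step = xi + g_t ** 2 * si * delta_t
        if not last:
            # Omit the stochastic term on the final step so the output is noise-free.
            step += g_t * math.sqrt(delta_t) * rng.gauss(0.0, 1.0)
        new_x.append(step)
    out = dict(batch)
    out[key] = new_x
    out["t"] = t - delta_t
    return out
```

Skipping the noise on the last step is a common convention in score-based samplers; whether agedi's last flag means exactly this is not stated in the source.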

abstractmethod _loss(batch: agedi.data.AtomsGraph) → float

Computes the training loss.

Must be implemented by the subclass.

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised and denoised.

Returns:

The loss of the noised and denoised atomistic structure.

Return type:

float
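One standard objective a subclass could return from _loss is denoising score matching. For a Gaussian perturbation x_t = x_0 + σ·ε, the target score is −ε/σ, and weighting the squared error by σ² gives the simplified form mean((σ·s + ε)²). The sketch below assumes that form and operates on plain lists; dsm_loss is a hypothetical name, and nothing in the source confirms which objective agedi uses.

```python
from typing import Sequence


def dsm_loss(score: Sequence[float], eps: Sequence[float], sigma: float) -> float:
    """Hypothetical _loss: denoising score matching with lambda(sigma) = sigma**2."""
    # Target score for x_t = x_0 + sigma * eps is -eps / sigma, so
    # sigma**2 * (s - (-eps / sigma))**2 simplifies to (sigma * s + eps)**2.
    terms = [(sigma * s + e) ** 2 for s, e in zip(score, eps)]
    return sum(terms) / len(terms)


perfect = dsm_loss([-0.25, 0.5], [0.5, -1.0], sigma=2.0)  # score equals -eps / sigma
untrained = dsm_loss([0.0, 0.0], [0.5, -1.0], sigma=2.0)  # a zero score model
```

The σ² weighting keeps the loss on a comparable scale across noise levels, which matters when times are sampled uniformly during training.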

noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph

Noises the attribute of the atomistic structure.

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised.

Returns:

The noised atomistic structure (or a batch thereof).

Return type:

AtomsGraph

denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph

Denoises the attribute of the atomistic structure.

Parameters:
  • batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be denoised.

  • delta_t (float) – The time step to be used for the denoising.

  • last (bool) – Whether this is the final step of the denoising process.

Returns:

The denoised atomistic structure (or a batch thereof).

Return type:

AtomsGraph

loss(batch: agedi.data.AtomsGraph) → float

Computes the training loss.

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be noised and denoised.

Returns:

The loss of the noised and denoised atomistic structure.

Return type:

float
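Because loss() is documented separately from _loss(), the base class appears to follow a template-method pattern, with the public methods wrapping the abstract hooks. The sketch below assumes that loss() is where loss_scaling is applied and that a model with several noisers sums their scaled losses; both are assumptions, and ConstantLossNoiser is a hypothetical toy subclass.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SketchNoiser(ABC):
    """Hypothetical stand-in: the public loss() wraps the abstract _loss() hook."""

    def __init__(self, loss_scaling: float = 1.0) -> None:
        self.loss_scaling = loss_scaling

    @abstractmethod
    def _loss(self, batch: Dict[str, Any]) -> float: ...

    def loss(self, batch: Dict[str, Any]) -> float:
        # Assumption: the scaling factor is applied in the public wrapper.
        return self.loss_scaling * self._loss(batch)


class ConstantLossNoiser(SketchNoiser):
    def __init__(self, value: float, loss_scaling: float = 1.0) -> None:
        super().__init__(loss_scaling)
        self.value = value

    def _loss(self, batch: Dict[str, Any]) -> float:
        return self.value


noisers: List[SketchNoiser] = [ConstantLossNoiser(2.0), ConstantLossNoiser(4.0, loss_scaling=0.25)]
total = sum(n.loss({}) for n in noisers)  # 2.0 + 0.25 * 4.0
```

Centralizing the scaling in the wrapper means subclasses only implement the raw loss and cannot forget to apply loss_scaling.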

langevin_step(batch: agedi.data.AtomsGraph, step_size: float | torch.Tensor = 0.01) → agedi.data.AtomsGraph

Perform a Langevin corrector step at the current (constant) time.

Applies a small score-corrected Langevin update without advancing the diffusion time. The score must already be stored in batch[key + "_score"] (i.e. the score model must have been called before invoking this method).

The default implementation delegates to _denoise() with last=False and a fixed step_size. Subclasses may override this for a more specialised corrector.

Parameters:
  • batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be corrected.

  • step_size (float or torch.Tensor, optional) – Size of the Langevin corrector step. Passing a pre-created torch.Tensor avoids repeated tensor allocation when this method is called in a tight loop. Defaults to 0.01.

Returns:

The corrected atomistic structure.

Return type:

AtomsGraph
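A Langevin corrector step at fixed time is commonly written x ← x + ε·s(x) + √(2ε)·z with z ~ N(0, I), where ε is the step size and s the stored score. The sketch below implements that textbook unadjusted Langevin update on a dict-based batch. It is an illustration only: the default langevin_step() delegates to _denoise() rather than using this formula, and langevin_update is a hypothetical name.

```python
import math
import random
from typing import Any, Dict, List


def langevin_update(batch: Dict[str, Any], key: str, step_size: float,
                    rng: random.Random) -> Dict[str, Any]:
    """Hypothetical corrector: one unadjusted Langevin step at fixed time."""
    score: List[float] = batch[key + "_score"]  # must be filled by the score model first
    noise_scale = math.sqrt(2.0 * step_size)
    out = dict(batch)  # any time entry is copied unchanged: time does not advance
    out[key] = [xi + step_size * si + noise_scale * rng.gauss(0.0, 1.0)
                for xi, si in zip(batch[key], score)]
    return out
```

With the score of a standard normal, s(x) = −x, repeated updates drive samples toward N(0, 1), up to an O(ε) bias from the missing Metropolis correction.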

initialize_graph(batch: agedi.data.AtomsGraph) → None

Initializes the graph with the prior distribution.

Can be overridden by subclasses for specific initializations.

Parameters:

batch (AtomsGraph) – The atomistic structure (or a batch thereof) to be initialized.
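Taken together, initialize_graph() and denoise() suggest a sampling loop of the following shape: initialize from the prior in place (the method returns None), then take fixed-size steps, passing last=True only on the final one. The uniform time grid from t = 1 to t = 0, the sample function and RecordingNoiser are assumptions for illustration, and the score-model call is only indicated by a comment.

```python
from typing import Any, Dict, List


class RecordingNoiser:
    """Hypothetical stand-in that records how the sampler drives it."""

    def __init__(self) -> None:
        self.calls: List[Any] = []

    def initialize_graph(self, batch: Dict[str, Any]) -> None:
        batch["x"] = [0.0]  # placeholder for a draw from the prior
        self.calls.append("init")

    def denoise(self, batch: Dict[str, Any], delta_t: float, last: bool) -> Dict[str, Any]:
        self.calls.append(("denoise", delta_t, last))
        return batch


def sample(noiser: RecordingNoiser, batch: Dict[str, Any], num_steps: int) -> Dict[str, Any]:
    """Hypothetical reverse-diffusion loop on a uniform grid from t = 1 to t = 0."""
    noiser.initialize_graph(batch)  # works in place and returns None
    delta_t = 1.0 / num_steps
    for step in range(num_steps):
        # A score model would populate batch[key + "_score"] here before each step.
        batch = noiser.denoise(batch, delta_t, last=(step == num_steps - 1))
    return batch
```

Reassigning batch from the return value of denoise() keeps the loop correct whether a subclass modifies the graph in place or returns a new one.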