agedi.diffusion.distributions.base

Classes

Distribution

Base class for noise distributions.

Module Contents

class agedi.diffusion.distributions.base.Distribution(key: str | None = None, **kwargs)

Bases: abc.ABC

Base class for noise distributions.

Parameters:

key (str | None) – Key used to identify the property in the batch.

Return type:

Distribution

key = None
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this distribution.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus any constructor arguments stored on the base class. Subclasses should call super().get_hparams() and merge in their own parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
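A minimal sketch of the get_hparams pattern described above. It is written against a stand-in base class, not the real agedi module. GaussianNoise and its sigma argument are hypothetical, and treating key as the only argument stored on the base class is an assumption.

```python
from abc import ABC
from typing import Any, Dict, Optional


class Distribution(ABC):
    """Stand-in for the documented base class (sketch only)."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        # Fully-qualified name of the concrete class, plus base-class arguments.
        cls = type(self)
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}", "key": self.key}


class GaussianNoise(Distribution):
    """Hypothetical subclass with one extra constructor argument."""

    def __init__(self, key: Optional[str] = None, sigma: float = 1.0, **kwargs: Any) -> None:
        super().__init__(key=key, **kwargs)
        self.sigma = sigma

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary, then merge in subclass parameters.
        hparams = super().get_hparams()
        hparams.update({"sigma": self.sigma})
        return hparams
```

Calling GaussianNoise(key="pos", sigma=0.5).get_hparams() returns a dictionary whose _target_ ends in GaussianNoise, together with key and sigma. The _target_ naming resembles the convention used by Hydra's hydra.utils.instantiate. Whether agedi relies on Hydra is not stated here.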

abstractmethod _sample(**kwargs) → torch.Tensor

Sample the distribution.

Return a tensor sampled from the distribution, with the same shape as the batch property identified by self.key.

Parameters:

kwargs (dict) – The parameters of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor
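The sketch below illustrates the abstract-method contract: the base class cannot be instantiated until a subclass overrides _sample. To stay self-contained, it returns a plain Python list instead of a torch.Tensor. UniformNoise and its n, low and high parameters are hypothetical.

```python
import random
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class Distribution(ABC):
    """Stand-in for the documented base class (sketch only)."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    @abstractmethod
    def _sample(self, **kwargs: Any) -> List[float]:
        """Return a sample; concrete subclasses must override this."""


class UniformNoise(Distribution):
    """Hypothetical subclass drawing i.i.d. values uniformly between low and high."""

    def _sample(self, n: int = 1, low: float = 0.0, high: float = 1.0, **kwargs: Any) -> List[float]:
        # n, low and high are illustrative distribution parameters passed as kwargs.
        return [random.uniform(low, high) for _ in range(n)]
```

UniformNoise(key="pos")._sample(n=5, low=-1.0, high=1.0) returns five floats between -1 and 1. Distribution() raises TypeError because _sample is still abstract.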

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution.

Prepare the distribution for sampling values for the given batch.

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None
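The following sketches one plausible use of _setup. The distribution looks up its property in the batch via key and caches the property's shape so that later samples match it. This is an assumption: a plain dict of nested lists stands in for agedi.data.AtomsGraph, and GaussianNoise is hypothetical.

```python
import random
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical stand-in for an AtomsGraph batch: property name -> rows of floats.
Batch = Dict[str, List[List[float]]]


class GaussianNoise:
    """Sketch of a distribution that sizes its samples from the batch."""

    def __init__(self, key: Optional[str] = "pos", sigma: float = 1.0) -> None:
        self.key = key
        self.sigma = sigma
        self._shape: Optional[Tuple[int, int]] = None

    def _setup(self, batch: Batch) -> None:
        # Look up the property via key and cache its (rows, cols) shape.
        values = batch[self.key]
        self._shape = (len(values), len(values[0]))

    def _sample(self, **kwargs: Any) -> List[List[float]]:
        if self._shape is None:
            raise RuntimeError("call _setup(batch) before sampling")
        rows, cols = self._shape
        # Zero-mean Gaussian noise with the same shape as the batch property.
        return [[random.gauss(0.0, self.sigma) for _ in range(cols)] for _ in range(rows)]
```

After d._setup({"pos": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]}), d._sample() returns two rows of three noise values.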

get_callable(batch: agedi.data.AtomsGraph) → Callable

Get a sampling callable.

Return a callable that samples from the distribution for the given batch.

Parameters:

batch (AtomsGraph) – Batch of data

Returns:

Callable that samples from the distribution.

Return type:

Callable
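The following sketches how get_callable might combine _setup and _sample. The source does not say that get_callable calls _setup, so that ordering is an assumption. The dict-based batch and the GaussianNoise class are also hypothetical, and samples are plain lists rather than tensors.

```python
import functools
import random
from typing import Any, Callable, Dict, List, Optional, Tuple

Batch = Dict[str, List[List[float]]]  # hypothetical stand-in for AtomsGraph


class GaussianNoise:
    """Sketch of a _setup -> get_callable flow (not the library's code)."""

    def __init__(self, key: Optional[str] = "pos", sigma: float = 1.0) -> None:
        self.key = key
        self.sigma = sigma
        self._shape: Tuple[int, int] = (0, 0)

    def _setup(self, batch: Batch) -> None:
        # Cache the shape of the property this distribution samples.
        values = batch[self.key]
        self._shape = (len(values), len(values[0]))

    def _sample(self, shape: Tuple[int, int], **kwargs: Any) -> List[List[float]]:
        rows, cols = shape
        return [[random.gauss(0.0, self.sigma) for _ in range(cols)] for _ in range(rows)]

    def get_callable(self, batch: Batch) -> Callable[..., List[List[float]]]:
        # Prepare for this batch, then freeze the batch-dependent shape into the callable.
        self._setup(batch)
        return functools.partial(self._sample, shape=self._shape)
```

Binding the shape with functools.partial means a sampler obtained for one batch keeps that batch's shape, even if the same distribution is later prepared for a different batch. Returning self._sample directly and reading the shape from self at call time would not have that property.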