agedi.diffusion.distributions

Classes

Distribution

Base Class for noise distributions

StandardNormal

Standard Normal Distribution

Normal

Normal Distribution

TruncatedNormal

Truncated Normal Distribution

Uniform

Uniform Distribution

UniformCell

Uniform Prior Distribution for cell parameters

UniformCellConfined

Uniform Prior Distribution for cell parameters with Z-directional confinement

Constant

Constant Integer Distribution

Categorical

Categorical Distribution

Package Contents

class agedi.diffusion.distributions.Distribution(key: str | None = None, **kwargs)

Bases: abc.ABC

Base Class for noise distributions

Parameters:

key (str | None) – Key to identify the property from the batch

Return type:

Distribution

key = None
get_hparams() Dict

Return hyperparameters sufficient to reconstruct this distribution.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus any constructor arguments stored on the base class. Subclasses should call super().get_hparams() and merge in their own parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
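The get_hparams() contract described above (a _target_ entry holding the fully-qualified class name, plus constructor arguments, with subclasses merging their own entries into super().get_hparams()) can be sketched with plain Python. BaseDist, UniformSketch and instantiate are hypothetical stand-ins written for illustration. They are not part of agedi, and the reconstruction helper only mimics the _target_ convention popularised by Hydra.

```python
from __future__ import annotations

import importlib
from typing import Any, Dict


class BaseDist:
    """Hypothetical stand-in for Distribution, used only to illustrate get_hparams()."""

    def __init__(self, key: str | None = None) -> None:
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        # Fully-qualified class name, so the object can be rebuilt later.
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}", "key": self.key}


class UniformSketch(BaseDist):
    """Hypothetical subclass that adds its own constructor arguments."""

    def __init__(self, low: float = 0.0, high: float = 1.0, key: str | None = "x") -> None:
        super().__init__(key=key)
        self.low = low
        self.high = high

    def get_hparams(self) -> Dict[str, Any]:
        # Merge the base-class entries with this subclass's own parameters.
        return {**super().get_hparams(), "low": self.low, "high": self.high}


def instantiate(hparams: Dict[str, Any]) -> Any:
    """Rebuild an object from a dict that carries a `_target_` key."""
    params = dict(hparams)  # copy so the caller's dict is left untouched
    module_name, _, class_name = params.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**params)
```

When the sketch is run as a script, instantiate(UniformSketch(low=-1.0).get_hparams()) returns a new UniformSketch with the same bounds, because every constructor argument travels alongside the class path.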

abstractmethod _sample(**kwargs) torch.Tensor

Sample distribution

Sample from the distribution and return a tensor shaped like the property identified by self.key

Parameters:

kwargs (dict) – The parameters of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor

_setup(batch: agedi.data.AtomsGraph) None

Prepare distribution

Prepare the distribution for sampling of the batch

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None

get_callable(batch: agedi.data.AtomsGraph) Callable

Get callable function

Return a callable function that samples from the distribution

Parameters:

batch (AtomsGraph) – Batch of data

Returns:

Callable function that samples from the distribution

Return type:

Callable
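The interplay of _setup(), _sample() and get_callable() can be illustrated with a simplified mirror of the interface. This is a sketch under assumptions, not the package's code: it assumes get_callable() runs _setup() once for the batch and then returns a closure around _sample(). DistributionSketch and ZerosSketch are hypothetical, and a SimpleNamespace stands in for an AtomsGraph batch.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Callable

import torch


class DistributionSketch(ABC):
    """Hypothetical, simplified mirror of the Distribution interface."""

    def __init__(self, key: str | None = None) -> None:
        self.key = key
        self.shape: torch.Size | None = None

    def _setup(self, batch) -> None:
        # Record the shape of the targeted attribute so samples can match it.
        self.shape = getattr(batch, self.key).shape

    @abstractmethod
    def _sample(self, **kwargs) -> torch.Tensor: ...

    def get_callable(self, batch) -> Callable[..., torch.Tensor]:
        # Prepare once for this batch, then hand back a sampler (assumed flow).
        self._setup(batch)

        def sampler(**kwargs) -> torch.Tensor:
            return self._sample(**kwargs)

        return sampler


class ZerosSketch(DistributionSketch):
    """Trivial concrete subclass used only to exercise the flow."""

    def _sample(self, **kwargs) -> torch.Tensor:
        return torch.zeros(self.shape)


# A SimpleNamespace stands in for an AtomsGraph batch (assumption).
batch = SimpleNamespace(pos=torch.randn(5, 3))
sample = ZerosSketch(key="pos").get_callable(batch)
```

Separating the batch-dependent preparation from the draw means the returned callable can be invoked repeatedly during sampling without re-inspecting the batch.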

class agedi.diffusion.distributions.StandardNormal

Bases: agedi.diffusion.distributions.Distribution

Standard Normal Distribution

_setup(batch: agedi.data.AtomsGraph) None

Prepare the distribution for sampling from batch.

Sets self.shape to (n_atoms, *trailing), where n_atoms is read from batch.n_atoms and the trailing dimensions come from the existing attribute. Using n_atoms instead of the attribute’s leading dimension avoids a shape mismatch when this method is called during graph initialisation (via initialize_graph()), where the attribute tensor may still be empty even though n_atoms has already been set.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.

_sample(shape: torch.Size | None = None, **kwargs) torch.Tensor

Sample from the standard normal distribution

Parameters:

shape (torch.Size | None) – The shape of the sample

Returns:

Sampled tensor

Return type:

torch.Tensor
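The shape rule used by StandardNormal._setup() (leading dimension from the atom count, trailing dimensions from the attribute) and the subsequent draw can be sketched as follows. The helper names are hypothetical, and the sketch assumes batch.n_atoms holds per-structure atom counts that are summed to get the total. A SimpleNamespace stands in for an AtomsGraph.

```python
from __future__ import annotations

from types import SimpleNamespace

import torch


def standard_normal_shape(batch, key: str) -> torch.Size:
    # Leading dimension from the atom count, not from the attribute itself,
    # which may still be empty while the graph is being initialised.
    n_atoms = int(batch.n_atoms.sum())  # assumes per-structure counts
    trailing = getattr(batch, key).shape[1:]
    return torch.Size((n_atoms, *trailing))


def sample_standard_normal(shape: torch.Size) -> torch.Tensor:
    # Independent N(0, 1) draws of the requested shape.
    return torch.randn(shape)


# Two structures with 2 and 3 atoms; positions not yet filled in.
batch = SimpleNamespace(n_atoms=torch.tensor([2, 3]), pos=torch.empty(0, 3))
shape = standard_normal_shape(batch, "pos")
noise = sample_standard_normal(shape)
```

Here the empty pos tensor still contributes its trailing dimension of 3, while the leading dimension of 5 comes from n_atoms, which is exactly the mismatch the docstring warns about.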

class agedi.diffusion.distributions.Normal

Bases: agedi.diffusion.distributions.Distribution

Normal Distribution

_sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) torch.Tensor

Sample from the normal distribution

Parameters:
  • mu (torch.Tensor) – Mean of the distribution

  • sigma (torch.Tensor) – Standard deviation of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor
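A normal draw with mean mu and standard deviation sigma is usually written in reparameterised form, x = mu + sigma * eps with eps ~ N(0, 1). The sketch below shows that standard technique; sample_normal is a hypothetical helper, not the package's _sample(), and it assumes sigma broadcasts against mu (for example a scalar noise level).

```python
import torch


def sample_normal(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Reparameterised draw: x = mu + sigma * eps with eps ~ N(0, 1).
    # randn_like matches the shape, dtype and device of mu; sigma is
    # expected to broadcast against mu.
    eps = torch.randn_like(mu)
    return mu + sigma * eps


mu = torch.zeros(4, 3)
sigma = torch.tensor(0.5)
x = sample_normal(mu, sigma)
```

Writing the draw this way keeps it differentiable with respect to mu and sigma, and a zero sigma simply returns mu.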

class agedi.diffusion.distributions.TruncatedNormal(index: int = 2, **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Truncated Normal Distribution

Parameters:

index (int) – The index of the property to truncate

index = 2
get_hparams() Dict

Return hyperparameters for this distribution.

_setup(batch: agedi.data.AtomsGraph) None

Set up the distribution

Prepare the distribution for sampling of the batch

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None

_sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) torch.Tensor

Sample from the truncated normal distribution

Parameters:
  • mu (torch.Tensor) – Mean of the distribution

  • sigma (torch.Tensor) – Standard deviation of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor
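One common way to truncate a single component of a normal draw is inverse-CDF sampling: draw u uniformly between Φ(a) and Φ(b) and map it back through Φ⁻¹. The docstring does not state the truncation bounds, so low and high below are hypothetical parameters (defaulting to [0, 1]), and index is assumed to select a column along the last dimension (index=2 would be the z column of positions). The function name is also hypothetical.

```python
import torch


def sample_truncated_normal(
    mu: torch.Tensor,
    sigma: torch.Tensor,
    index: int = 2,
    low: float = 0.0,   # hypothetical lower truncation bound
    high: float = 1.0,  # hypothetical upper truncation bound
) -> torch.Tensor:
    sigma = torch.broadcast_to(sigma, mu.shape)
    # Ordinary normal draw for every column.
    x = mu + sigma * torch.randn_like(mu)
    # Inverse-CDF sampling restricted to [low, high] for column `index` only.
    m = mu[..., index]
    s = sigma[..., index]
    cdf_low = torch.special.ndtr((low - m) / s)
    cdf_high = torch.special.ndtr((high - m) / s)
    u = cdf_low + (cdf_high - cdf_low) * torch.rand_like(m)
    # ndtri is the inverse standard-normal CDF; clamp guards against round-off.
    x[..., index] = (m + s * torch.special.ndtri(u)).clamp(low, high)
    return x
```

Unlike rejection sampling, this needs exactly one uniform draw per entry, regardless of how much probability mass lies outside the bounds.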

class agedi.diffusion.distributions.Uniform(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Uniform Distribution

Parameters:
  • low (float) – The lower bound of the distribution

  • high (float) – The upper bound of the distribution

  • key (str) – Key to identify the property from the batch

low = 0.0
high = 1.0
get_hparams() Dict

Return hyperparameters for this distribution.

_setup(batch: agedi.data.AtomsGraph) None

Prepare the distribution for sampling from batch.

Sets self.shape to the shape of the target attribute in the batch.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.

_sample(shape: torch.Size | None = None, **kwargs) torch.Tensor

Sample from the uniform distribution

Parameters:

shape (torch.Size) – The shape of the sample

Returns:

Sampled tensor

Return type:

torch.Tensor
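Uniform sampling on [low, high) is typically an affine map of a standard uniform draw. The helper below is a hypothetical sketch of that idea using the low and high parameters documented above, not the class's actual _sample().

```python
import torch


def sample_uniform(shape, low: float = 0.0, high: float = 1.0) -> torch.Tensor:
    # torch.rand draws from [0, 1); the affine map moves it to [low, high).
    return low + (high - low) * torch.rand(shape)


x = sample_uniform((100, 3), low=-2.0, high=3.0)
```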

class agedi.diffusion.distributions.UniformCell(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: Uniform

Uniform Prior Distribution for cell parameters

_setup(batch: agedi.data.AtomsGraph) None

Prepare the distribution for sampling of the batch

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None

_sample(**kwargs) torch.Tensor

Sample from the uniform distribution

Parameters:

kwargs (dict) – Additional keyword arguments

Returns:

Sampled tensor

Return type:

torch.Tensor
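The docstring does not say precisely what UniformCell draws. One plausible reading, assumed here purely for illustration, is that atomic positions are placed uniformly inside the unit cell by drawing fractional coordinates in [0, 1) and mapping them through the cell matrix. The helper name is hypothetical, and the sketch assumes cell rows are lattice vectors (the ASE convention), so Cartesian positions are frac @ cell.

```python
import torch


def sample_uniform_in_cell(n_atoms: int, cell: torch.Tensor) -> torch.Tensor:
    # Fractional coordinates uniform in [0, 1) along each lattice vector.
    frac = torch.rand(n_atoms, 3, dtype=cell.dtype)
    # Rows of `cell` are lattice vectors (assumed), so r = f @ cell.
    return frac @ cell


cell = torch.diag(torch.tensor([2.0, 3.0, 4.0]))
pos = sample_uniform_in_cell(50, cell)
```

Sampling in fractional coordinates gives a uniform density over the cell volume for any cell shape, including non-orthogonal cells.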

class agedi.diffusion.distributions.UniformCellConfined(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: UniformCell

Uniform Prior Distribution for cell parameters with Z-directional confinement

_setup(batch: agedi.data.AtomsGraph) None

Prepare the distribution for sampling of the batch

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None
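Z-directional confinement could be realised by restricting only the third fractional coordinate to a window while the other two still span the whole cell. This is an assumed interpretation, not the class's documented behaviour: the function name and the fractional window bounds z_low and z_high are hypothetical, and cell rows are assumed to be lattice vectors.

```python
import torch


def sample_uniform_confined(
    n_atoms: int, cell: torch.Tensor, z_low: float, z_high: float
) -> torch.Tensor:
    frac = torch.rand(n_atoms, 3, dtype=cell.dtype)
    # Squeeze only the third fractional coordinate into [z_low, z_high).
    frac[:, 2] = z_low + (z_high - z_low) * frac[:, 2]
    # Rows of `cell` are lattice vectors (assumed), so r = f @ cell.
    return frac @ cell


cell = torch.diag(torch.tensor([10.0, 10.0, 20.0]))
pos = sample_uniform_confined(200, cell, z_low=0.25, z_high=0.5)
```

With this 20 Å tall cell, a fractional window of [0.25, 0.5) places every atom between z = 5 Å and z = 10 Å, as one might want for a slab or surface region.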

class agedi.diffusion.distributions.Constant(value: float = 0, key: str = 'x', dtype: Type = torch.int64, **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Constant Integer Distribution

value = 0
dtype
get_hparams() Dict

Return hyperparameters for this distribution.

_setup(batch: agedi.data.AtomsGraph) None

Prepare the distribution for sampling from batch.

Sets self.shape based on the total number of atoms in the batch.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.

_sample(shape: torch.Size | None = None) torch.Tensor

Return a tensor of the requested shape filled with value

Parameters:

shape (torch.Size | None) – The shape of the sample

Returns:

Sampled tensor

Return type:

torch.Tensor
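A constant "distribution" involves no randomness: every entry receives value, cast to dtype (torch.int64 by default, matching the signature above). The helper below is a hypothetical sketch of that behaviour, not the class's own _sample().

```python
import torch


def sample_constant(shape, value: float = 0, dtype: torch.dtype = torch.int64) -> torch.Tensor:
    # Every entry gets the same value; nothing is drawn at random.
    return torch.full(shape, value, dtype=dtype)


t = sample_constant((5,), 6)
```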

class agedi.diffusion.distributions.Categorical

Bases: agedi.diffusion.distributions.Distribution

Categorical Distribution

Implements hard sampling using the Gumbel-Max trick.

_sample(probs: torch.Tensor) torch.Tensor

Sample from the categorical distribution, where the probabilities define the likelihood that each value is set to the masked value (0)

Parameters:

probs (torch.Tensor) – The probabilities of each category

Returns:

Sampled tensor

Return type:

torch.Tensor
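The Gumbel-Max trick mentioned above draws a category by adding independent Gumbel(0, 1) noise to the log-probabilities and taking the argmax. The result is an exact sample from Categorical(probs). The sketch is a generic version of that technique, not the package's _sample(). The name gumbel_max_sample is hypothetical, and the sketch assumes categories lie along the last dimension of probs.

```python
import torch


def gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    # Gumbel(0, 1) noise via inverse CDF: g = -log(-log(u)), u ~ U(0, 1).
    # Clamping keeps u away from 0 so the inner log stays finite.
    u = torch.rand_like(probs).clamp_min(torch.finfo(probs.dtype).tiny)
    gumbel = -torch.log(-torch.log(u))
    # argmax of log-probabilities plus Gumbel noise is a draw from
    # Categorical(probs) along the last dimension.
    return torch.argmax(torch.log(probs) + gumbel, dim=-1)


probs = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
idx = gumbel_max_sample(probs)
```

Zero-probability categories get a log-probability of -inf and can never win the argmax, so one-hot rows are sampled deterministically. Because the argmax is unchanged by adding a constant to every logit, probs does not need to be normalised.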