agedi.diffusion.distributions.normal

Attributes

Classes

StandardNormal

Standard Normal Distribution

Normal

Normal Distribution

TruncatedNormal

Truncated Normal Distribution

Module Contents

agedi.diffusion.distributions.normal._CONFINEMENT_CLAMP_EPS = 0.0001
class agedi.diffusion.distributions.normal.StandardNormal

Bases: agedi.diffusion.distributions.Distribution

Standard Normal Distribution

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution for sampling from batch.

Sets self.shape to (n_atoms, *trailing), where n_atoms is read from batch.n_atoms and the trailing dimensions are taken from the existing attribute tensor. Using n_atoms rather than the attribute’s leading dimension avoids a shape mismatch when this method is called during graph initialisation (via initialize_graph()), when the attribute tensor may still be empty even though n_atoms has already been set.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.
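The shape rule above can be illustrated with a small sketch. It assumes n_atoms is a plain Python int holding the total atom count and that the attribute is a tensor whose first dimension indexes atoms; infer_sample_shape is a hypothetical helper, not part of agedi.

```python
import torch


def infer_sample_shape(n_atoms: int, attribute: torch.Tensor) -> torch.Size:
    """Hypothetical helper: build (n_atoms, *trailing) as _setup() describes."""
    # Keep every dimension after the leading (per-atom) one.
    trailing = attribute.shape[1:]
    # Take the leading size from n_atoms, not from attribute.shape[0].
    return torch.Size((n_atoms, *trailing))


# During graph initialisation the attribute may still be empty (0 rows),
# but its trailing dimension (here 3) is already known.
empty_positions = torch.empty(0, 3)
print(infer_sample_shape(5, empty_positions))  # torch.Size([5, 3])
```

Reading the leading size from attribute.shape[0] instead would yield (0, 3) here, which is the mismatch the docstring describes.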

_sample(shape: torch.Size | None = None, **kwargs) → torch.Tensor

Sample from the standard normal distribution

Parameters:
  • shape (torch.Size | None) – Shape of the tensor to sample. Defaults to None.

Returns:

Sampled tensor

Return type:

torch.Tensor
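A minimal sketch of standard-normal sampling with torch.randn. It assumes that a shape of None falls back to a previously prepared shape, standing in for self.shape from _setup(); the helper sample_standard_normal and its default_shape argument are hypothetical.

```python
from typing import Optional

import torch


def sample_standard_normal(
    shape: Optional[torch.Size] = None,
    default_shape: torch.Size = torch.Size((1, 3)),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Hypothetical helper: draw i.i.d. N(0, 1) entries."""
    # Assumption: fall back to a prepared shape (like self.shape) when none is given.
    size = default_shape if shape is None else shape
    return torch.randn(size, generator=generator)


g = torch.Generator().manual_seed(0)  # seeded for reproducibility
x = sample_standard_normal(torch.Size((4, 3)), generator=g)
print(x.shape)  # torch.Size([4, 3])
```

Because the mean is fixed at 0 and the standard deviation at 1, no mu or sigma arguments are needed; only the output shape varies.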

class agedi.diffusion.distributions.normal.Normal

Bases: agedi.diffusion.distributions.Distribution

Normal Distribution

_sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor

Sample from the normal distribution

Parameters:
  • mu (torch.Tensor) – Mean of the distribution

  • sigma (torch.Tensor) – Standard deviation of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor
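A common way to sample from N(mu, sigma²) is to shift and scale standard-normal noise. The sketch below shows that technique; whether Normal._sample is implemented exactly this way is an assumption, and sample_normal is a hypothetical helper.

```python
import torch


def sample_normal(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: draw x ~ N(mu, sigma^2) element-wise."""
    # Noise shape follows the broadcast of mu and sigma (e.g. per-atom mu, scalar sigma).
    shape = torch.broadcast_shapes(mu.shape, sigma.shape)
    eps = torch.randn(shape, dtype=mu.dtype, device=mu.device)
    # Reparameterisation: x = mu + sigma * eps with eps ~ N(0, 1).
    return mu + sigma * eps


mu = torch.zeros(5, 3)
sigma = torch.tensor(0.5)
print(sample_normal(mu, sigma).shape)  # torch.Size([5, 3])
```

Writing the sample as mu + sigma * eps keeps it differentiable with respect to mu and sigma, which is the usual reason for this form.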

class agedi.diffusion.distributions.normal.TruncatedNormal(index: int = 2, **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Truncated Normal Distribution

Parameters:

index (int) – The index of the property to truncate

index = 2
get_hparams() → Dict

Return hyperparameters for this distribution.

_setup(batch: agedi.data.AtomsGraph) → None

Set up the distribution for sampling from the batch.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.

Return type:

None

_sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor

Sample from the truncated normal distribution

Parameters:
  • mu (torch.Tensor) – Mean of the distribution

  • sigma (torch.Tensor) – Standard deviation of the distribution

Returns:

Sampled tensor

Return type:

torch.Tensor
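The truncation bounds and the sampling method are not documented above. As a hedged illustration only, the sketch below draws a Gaussian sample and then truncates column index to a hypothetical interval [low, high] using inverse-transform sampling. The helper name and the low and high arguments are assumptions. The module constant _CONFINEMENT_CLAMP_EPS suggests that some clamping takes place, but since its use is not stated here, the sketch does not rely on it.

```python
import torch


def sample_truncated_normal(mu: torch.Tensor, sigma: torch.Tensor,
                            low: float, high: float, index: int = 2) -> torch.Tensor:
    """Hypothetical helper: N(mu, sigma^2) with column `index` truncated to [low, high].

    Assumes mu and sigma share the shape (n_atoms, d) and sigma > 0.
    """
    # Untruncated draw for every column.
    x = mu + sigma * torch.randn_like(mu)
    m, s = mu[:, index], sigma[:, index]
    # CDF values of the standardised bounds, Phi(alpha) and Phi(beta).
    cdf_low = torch.special.ndtr((low - m) / s)
    cdf_high = torch.special.ndtr((high - m) / s)
    # Uniform draw in [Phi(alpha), Phi(beta)], mapped back through the inverse CDF.
    u = cdf_low + (cdf_high - cdf_low) * torch.rand_like(m)
    z = m + s * torch.special.ndtri(u)
    # Clamp guards against round-off (and +/-inf if u reaches 0 or 1).
    x[:, index] = z.clamp(low, high)
    return x


mu = torch.zeros(4, 3)
sigma = torch.ones(4, 3)
x = sample_truncated_normal(mu, sigma, low=0.0, high=1.0)  # column 2 lies in [0, 1]
```

Inverse-transform sampling is simple and vectorises well, but it loses precision when both bounds lie far out in the same tail; rejection-based samplers are more robust in that regime.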