agedi.diffusion.distributions.constant

Classes

Constant

Constant Integer Distribution

Module Contents

class agedi.diffusion.distributions.constant.Constant(value: float = 0, key: str = 'x', dtype: Type = torch.int64, **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Constant Integer Distribution

value = 0
dtype
get_hparams() → Dict

Return hyperparameters for this distribution.
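The page does not say which hyperparameters are returned. The sketch below assumes they are just the constructor arguments (value, key, dtype), which would let a distribution be rebuilt from its own hyperparameters. ConstantHParamsSketch is a hypothetical class, not the agedi implementation, and it stores dtype as a plain string so the sketch needs no torch dependency.

```python
from typing import Any, Dict


class ConstantHParamsSketch:
    """Hypothetical stand-in that records its constructor arguments."""

    def __init__(self, value: float = 0, key: str = "x", dtype: str = "int64") -> None:
        self.value = value
        self.key = key
        self.dtype = dtype  # a string here; the documented class uses a torch dtype

    def get_hparams(self) -> Dict[str, Any]:
        # Assumption: the hyperparameters are exactly the constructor arguments.
        return {"value": self.value, "key": self.key, "dtype": self.dtype}


original = ConstantHParamsSketch(value=2, key="z")
# Round trip: rebuild an equivalent object from the reported hyperparameters.
clone = ConstantHParamsSketch(**original.get_hparams())
print(clone.get_hparams())  # {'value': 2, 'key': 'z', 'dtype': 'int64'}
```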

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution for sampling from batch.

Sets self.shape based on the total number of atoms in the batch.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.
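As a rough illustration of deriving a shape from the total atom count, the snippet below sums per-structure atom counts over a batch. It does not use the real AtomsGraph: ToyBatch and its n_atoms field (atoms per structure) are hypothetical, and the one-dimensional (total_atoms,) shape is an assumption about what self.shape holds.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyBatch:
    """Hypothetical stand-in for agedi.data.AtomsGraph."""

    n_atoms: List[int]  # number of atoms in each structure of the batch


def setup_shape(batch: ToyBatch) -> Tuple[int, ...]:
    """Return a per-atom shape: one entry for every atom in the whole batch."""
    total_atoms = sum(batch.n_atoms)
    return (total_atoms,)


batch = ToyBatch(n_atoms=[3, 5, 2])
print(setup_shape(batch))  # (10,)
```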

_sample(shape: torch.Size | None = None) → torch.Tensor

Sample from the constant distribution.

Parameters:

shape (torch.Size | None) – Shape of the tensor to sample. Defaults to None.

Returns:

Sampled tensor, in which every entry equals value.

Return type:

torch.Tensor
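A minimal sketch of what sampling from a constant distribution amounts to, assuming the sample is a tensor of the requested shape filled with value in the configured dtype, and assuming a shape recorded during setup is used when shape is None. ConstantSketch and its sample method are hypothetical and are not the agedi implementation; torch.full is the standard PyTorch call for a filled tensor.

```python
from typing import Optional, Tuple, Union

import torch

ShapeLike = Union[torch.Size, Tuple[int, ...]]


class ConstantSketch:
    """Hypothetical re-implementation of a constant distribution."""

    def __init__(self, value: float = 0, dtype: torch.dtype = torch.int64) -> None:
        self.value = value
        self.dtype = dtype
        self.shape: Optional[ShapeLike] = None  # would be set by a _setup-like step

    def sample(self, shape: Optional[ShapeLike] = None) -> torch.Tensor:
        # Assumed fallback: use the shape recorded during setup.
        if shape is None:
            shape = self.shape
        if shape is None:
            raise ValueError("no shape given and none recorded by setup")
        # Deterministic: every entry equals self.value.
        return torch.full(shape, self.value, dtype=self.dtype)


dist = ConstantSketch(value=1)
dist.shape = torch.Size([4])
print(dist.sample())  # tensor([1, 1, 1, 1])
```

Because the output never varies, a distribution like this can stand in for a quantity that should stay fixed while other keys are sampled.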