agedi.diffusion.distributions.constant
Classes
Constant | Constant Integer Distribution
Module Contents
- class agedi.diffusion.distributions.constant.Constant(value: float = 0, key: str = 'x', dtype: Type = torch.int64, **kwargs)
Bases: agedi.diffusion.distributions.Distribution
Constant Integer Distribution
- value = 0
- dtype
- get_hparams() → Dict
Return hyperparameters for this distribution.
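The following is a minimal sketch of what get_hparams could return. It is not the agedi implementation: ConstantSketch is a hypothetical stand-in, dtype is stored as a string so the example does not need torch, and it assumes the hyperparameters are exactly the constructor arguments.

```python
from typing import Any, Dict


class ConstantSketch:
    """Hypothetical stand-in for Constant; NOT the agedi implementation."""

    def __init__(self, value: float = 0, key: str = "x", dtype: str = "int64") -> None:
        self.value = value
        self.key = key
        # dtype is kept as a string so this sketch does not depend on torch
        self.dtype = dtype

    def get_hparams(self) -> Dict[str, Any]:
        # Assumption: the hyperparameters are exactly the constructor arguments
        return {"value": self.value, "key": self.key, "dtype": self.dtype}


dist = ConstantSketch(value=3, key="n")
print(dist.get_hparams())  # {'value': 3, 'key': 'n', 'dtype': 'int64'}
```

Returning the constructor arguments in this way means an equivalent object can be rebuilt with ConstantSketch(**dist.get_hparams()), which is a common reason to expose hyperparameters as a plain dict.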
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling from batch.
Sets self.shape based on the total number of atoms in the batch.
- Parameters:
batch (AtomsGraph) – Batch of atomistic data.
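A minimal sketch of the _setup step, under stated assumptions: FakeBatch is a hypothetical stand-in for AtomsGraph that only records the number of atoms per structure, and the resulting shape is assumed to be (total number of atoms,). The real AtomsGraph attributes and the exact shape agedi uses are not documented here.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeBatch:
    """Hypothetical batch: number of atoms in each structure of the batch."""
    n_atoms: List[int]


class ConstantSetupSketch:
    """Hypothetical stand-in for Constant; only models _setup."""

    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.shape: Tuple[int, ...] = ()

    def _setup(self, batch: FakeBatch) -> None:
        # Assumption: one sample per atom, so the shape is (total number of atoms,)
        self.shape = (sum(batch.n_atoms),)


dist = ConstantSetupSketch(value=1)
dist._setup(FakeBatch(n_atoms=[3, 5, 2]))
print(dist.shape)  # (10,)
```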
- _sample(shape: torch.Size | None = None) → torch.Tensor
Sample from the constant integer distribution, i.e. return a tensor whose entries all equal value.
- Parameters:
shape (torch.Size | None) – Shape of the sample. Defaults to None, in which case the shape set by _setup is expected to be used.
- Returns:
Sampled tensor
- Return type:
torch.Tensor
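A minimal, dependency-free sketch of constant sampling. It assumes that shape=None falls back to the shape stored by _setup and uses int() to mimic an integer dtype; ConstantSampleSketch and full are hypothetical helpers, and nested lists stand in for torch tensors. With torch, the same fill could be written as torch.full(shape, value, dtype=torch.int64).

```python
from typing import Any, Optional, Sequence


def full(shape: Sequence[int], value: Any) -> Any:
    """Build a nested list of the given shape with every entry equal to value."""
    if len(shape) == 0:
        return value
    return [full(shape[1:], value) for _ in range(shape[0])]


class ConstantSampleSketch:
    """Hypothetical stand-in for Constant; only models _sample."""

    def __init__(self, value: float = 0) -> None:
        # int() mimics casting to an integer dtype such as torch.int64
        self.value = int(value)
        self.shape: Sequence[int] = (0,)

    def _sample(self, shape: Optional[Sequence[int]] = None) -> Any:
        # Assumption: shape=None falls back to the shape stored by _setup
        if shape is None:
            shape = self.shape
        return full(tuple(shape), self.value)


dist = ConstantSampleSketch(value=2.0)
dist.shape = (3,)
print(dist._sample())        # [2, 2, 2]
print(dist._sample((2, 2)))  # [[2, 2], [2, 2]]
```

Because every draw is identical, the distribution carries no randomness; it is useful when a quantity should stay fixed while other parts of the model are sampled.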