agedi.diffusion.distributions.normal
Attributes
_CONFINEMENT_CLAMP_EPS
Classes
StandardNormal | Standard Normal Distribution
Normal | Normal Distribution
TruncatedNormal | Truncated Normal Distribution
Module Contents
- agedi.diffusion.distributions.normal._CONFINEMENT_CLAMP_EPS = 0.0001
- class agedi.diffusion.distributions.normal.StandardNormal
Bases: agedi.diffusion.distributions.Distribution
Standard Normal Distribution
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling from batch.
Sets self.shape to (n_atoms, *trailing), where n_atoms is read from batch.n_atoms and the trailing dimensions come from the existing attribute. Using n_atoms rather than the attribute's leading dimension avoids a shape mismatch when the method is called during graph initialisation (via initialize_graph()), where the attribute tensor may still be empty even though n_atoms has already been set.
- Parameters:
batch (AtomsGraph) – Batch of atomistic data.
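A minimal sketch of the shape rule above, assuming batch.n_atoms holds one atom count per structure (so the total is their sum) and using a hypothetical attribute name pos for an (n_atoms, 3) positions tensor. The batch here is a SimpleNamespace stand-in rather than an agedi.data.AtomsGraph, and standard_normal_shape is a hypothetical helper, not part of agedi.

```python
from types import SimpleNamespace

import torch


def standard_normal_shape(batch, key: str) -> torch.Size:
    """Hypothetical helper: return (n_atoms, *trailing) for attribute `key`."""
    # Assumption: batch.n_atoms stores one atom count per structure in the batch.
    n_atoms = int(torch.as_tensor(batch.n_atoms).sum())
    # Trailing dimensions come from the attribute, e.g. (3,) for positions.
    trailing = getattr(batch, key).shape[1:]
    return torch.Size((n_atoms, *trailing))


# During graph initialisation the attribute may still be empty (0 rows),
# but n_atoms is already known, so the shape still comes out right.
batch = SimpleNamespace(n_atoms=torch.tensor([2, 3]), pos=torch.empty(0, 3))
print(standard_normal_shape(batch, "pos"))  # torch.Size([5, 3])
```

Reading the leading dimension from the empty pos tensor instead would give 0 rows, which is the mismatch the docstring describes.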
- _sample(shape: torch.Size | None = None, **kwargs) → torch.Tensor
Sample from the standard normal distribution (mean 0, standard deviation 1).
- Parameters:
shape (torch.Size | None) – Shape of the sampled tensor.
**kwargs – Additional keyword arguments.
- Returns:
Sampled tensor
- Return type:
torch.Tensor
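A sketch of a shape-driven standard normal sampler. StandardNormalSketch is a hypothetical stand-in, not the agedi class; falling back to a shape stored during setup when shape is None is an assumption based on the _setup() description above, not confirmed behaviour.

```python
from typing import Optional

import torch


class StandardNormalSketch:
    """Hypothetical stand-in for a standard normal distribution."""

    def __init__(self) -> None:
        # Filled in by a _setup-like step, e.g. torch.Size([n_atoms, 3]).
        self.shape: Optional[torch.Size] = None

    def _sample(self, shape: Optional[torch.Size] = None, **kwargs) -> torch.Tensor:
        # Assumption: an explicit shape wins, otherwise use the prepared one.
        if shape is None:
            shape = self.shape
        if shape is None:
            raise ValueError("no shape given and setup has not been run")
        # Independent N(0, 1) entries; extra kwargs are ignored in this sketch.
        return torch.randn(shape)


dist = StandardNormalSketch()
dist.shape = torch.Size([4, 3])
print(dist._sample().shape)  # torch.Size([4, 3])
```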
- class agedi.diffusion.distributions.normal.Normal
Bases: agedi.diffusion.distributions.Distribution
Normal Distribution
- _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor
Sample from the normal distribution.
- Parameters:
mu (torch.Tensor) – Mean of the distribution
sigma (torch.Tensor) – Standard deviation of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
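Sampling from a normal distribution with a given mean and standard deviation is commonly done by shifting and scaling standard normal noise, x = mu + sigma * eps with eps ~ N(0, 1). The function below is a hypothetical sketch of that technique, not the agedi implementation; it assumes mu and sigma are broadcast against each other.

```python
import torch


def sample_normal(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Draw x ~ N(mu, sigma^2) elementwise as x = mu + sigma * eps."""
    # mu and sigma may have different but broadcastable shapes,
    # so a 0-dim sigma applies one standard deviation to every entry.
    shape = torch.broadcast_shapes(mu.shape, sigma.shape)
    eps = torch.randn(shape, dtype=mu.dtype, device=mu.device)
    return mu + sigma * eps


torch.manual_seed(0)
x = sample_normal(torch.full((20000,), 2.0), torch.tensor(0.5))
print(x.mean().item(), x.std().item())  # close to 2.0 and 0.5
```

Because the noise is separate from mu and sigma, gradients can flow through both parameters.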
- class agedi.diffusion.distributions.normal.TruncatedNormal(index: int = 2, **kwargs)
Bases: agedi.diffusion.distributions.Distribution
Truncated Normal Distribution
- Parameters:
index (int) – The index of the property to truncate. Defaults to 2.
- index = 2
- get_hparams() → Dict
Return hyperparameters for this distribution.
- _setup(batch: agedi.data.AtomsGraph) → None
Set up the distribution: prepare it for sampling from the batch.
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
- _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor
Sample from the truncated normal distribution.
- Parameters:
mu (torch.Tensor) – Mean of the distribution
sigma (torch.Tensor) – Standard deviation of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
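The source does not state how the truncation bounds are chosen. The sketch below uses one standard technique, inverse-CDF sampling with torch.special.ndtr and torch.special.ndtri, under the assumptions that only the column selected by index is truncated (for example z in an (n_atoms, 3) positions tensor when index = 2) and that the bounds are the hypothetical arguments low and high. truncated_normal_column and eps are illustrative names; eps is a local numerical guard, not the module's _CONFINEMENT_CLAMP_EPS.

```python
import torch


def truncated_normal_column(mu, sigma, index=2, low=0.0, high=1.0, eps=1e-6):
    """Hypothetical sketch: N(mu, sigma^2) everywhere, truncated to [low, high] in column `index`."""
    shape = torch.broadcast_shapes(mu.shape, sigma.shape)
    mu, sigma = mu.expand(shape), sigma.expand(shape)
    # Untruncated columns: plain shifted-and-scaled Gaussian noise.
    out = mu + sigma * torch.randn(shape, dtype=mu.dtype, device=mu.device)

    m, s = mu[..., index], sigma[..., index]
    # Standard normal CDF evaluated at the standardised bounds.
    a = torch.special.ndtr((low - m) / s)
    b = torch.special.ndtr((high - m) / s)
    # Uniform draw restricted to [a, b), then mapped back through the inverse CDF.
    u = a + (b - a) * torch.rand_like(m)
    u = u.clamp(eps, 1 - eps)  # keep ndtri finite near 0 and 1
    out[..., index] = m + s * torch.special.ndtri(u)
    return out
```

For example, with mu = 0.5 and sigma = 1 in every entry, column 2 stays inside [0, 1] while columns 0 and 1 are ordinary Gaussian samples.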