agedi.diffusion.distributions
Submodules
Classes
- Distribution – Base Class for noise distributions
- StandardNormal – Standard Normal Distribution
- Normal – Normal Distribution
- TruncatedNormal – Truncated Normal Distribution
- Uniform – Uniform Distribution
- UniformCell – Uniform Prior Distribution for cell parameters
- UniformCellConfined – Uniform Prior Distribution for cell parameters with Z-directional confinement
- Constant – Constant Integer Distribution
- Categorical – Categorical Distribution
Package Contents
- class agedi.diffusion.distributions.Distribution(key: str | None = None, **kwargs)
Bases: abc.ABC
Base Class for noise distributions
- Parameters:
key (str) – Key to identify the property from the batch
- key = None
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this distribution.
Returns a dictionary with a _target_ key (the fully-qualified class name) plus any constructor arguments stored on the base class. Subclasses should call super().get_hparams() and merge in their own parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
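A minimal sketch of the merge pattern described above, using hypothetical classes BaseDistribution and UniformSketch rather than the library's own. It assumes the base class stores only key and that the _target_ value is built from the class's module and qualified name; the real implementation may differ.

```python
from typing import Any, Dict, Optional


class BaseDistribution:
    """Hypothetical stand-in for the documented base class."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        # "_target_" records the fully-qualified name of the concrete subclass.
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}", "key": self.key}


class UniformSketch(BaseDistribution):
    """Hypothetical subclass with two extra constructor arguments."""

    def __init__(self, low: float = 0.0, high: float = 1.0, key: str = "x", **kwargs: Any) -> None:
        super().__init__(key=key, **kwargs)
        self.low = low
        self.high = high

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary, then merge in this class's parameters.
        hparams = super().get_hparams()
        hparams.update({"low": self.low, "high": self.high})
        return hparams


hparams = UniformSketch(low=-1.0, high=1.0).get_hparams()
# Everything except "_target_" can be fed straight back into the constructor.
rebuilt = UniformSketch(**{k: v for k, v in hparams.items() if k != "_target_"})
```

Because each subclass only adds its own keys on top of super().get_hparams(), the dictionary stays complete even through several levels of inheritance.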
- abstractmethod _sample(**kwargs) → torch.Tensor
Sample distribution
Sample from the distribution and return a tensor shaped like the property identified by self.key
- Parameters:
kwargs (dict) – The parameters of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare distribution
Prepare the distribution for sampling the given batch
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
- get_callable(batch: agedi.data.AtomsGraph) → Callable
Get callable function
Return a callable function that samples from the distribution
- Parameters:
batch (AtomsGraph) – Batch of data
- Returns:
Callable function that samples from the distribution
- Return type:
Callable
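The three methods above fit together roughly as follows: _setup inspects the batch, _sample draws values, and get_callable packages them as a function. This is a framework-free sketch; ToyBatch, ToyDistribution and ToyUniform are hypothetical, and the assumption that get_callable simply calls _setup and returns the bound _sample is an illustration, not the library's confirmed behaviour.

```python
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple


class ToyBatch:
    """Hypothetical stand-in for AtomsGraph; only the atom count is modelled."""

    def __init__(self, n_atoms: int) -> None:
        self.n_atoms = n_atoms


class ToyDistribution(ABC):
    """Hypothetical base class mirroring the documented method names."""

    def __init__(self, key: Optional[str] = None) -> None:
        self.key = key
        self.shape: Tuple[int, ...] = ()

    def _setup(self, batch: ToyBatch) -> None:
        # Assumed default: one scalar per atom in the batch.
        self.shape = (batch.n_atoms,)

    @abstractmethod
    def _sample(self, **kwargs) -> List[float]:
        """Subclasses return self.shape[0] sampled values."""

    def get_callable(self, batch: ToyBatch) -> Callable[..., List[float]]:
        # Prepare once for this batch, then hand back the bound sampler.
        self._setup(batch)
        return self._sample


class ToyUniform(ToyDistribution):
    def _sample(self, **kwargs) -> List[float]:
        return [random.random() for _ in range(self.shape[0])]


sampler = ToyUniform(key="x").get_callable(ToyBatch(n_atoms=4))
values = sampler()
```

Separating preparation from sampling lets batch-dependent work (such as computing shapes) happen once, while the returned callable can be invoked repeatedly.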
- class agedi.diffusion.distributions.StandardNormal
Bases: agedi.diffusion.distributions.Distribution
Standard Normal Distribution
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling from the batch.
Sets self.shape to (n_atoms, *trailing), where n_atoms is read from batch.n_atoms and the trailing dimensions come from the existing attribute identified by self.key. Using n_atoms rather than the attribute's leading dimension avoids a shape mismatch when called during graph initialisation (via initialize_graph()), where the attribute tensor may still be empty even though n_atoms has already been set.
- Parameters:
batch (AtomsGraph) – Batch of atomistic data.
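The shape rule above can be stated in a few lines. This plain-Python sketch works on shape tuples only; the function name standard_normal_shape is hypothetical, and n_atoms is assumed here to be a single integer. In PyTorch the resulting tuple could then be passed to torch.randn.

```python
from typing import Tuple


def standard_normal_shape(n_atoms: int, attr_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return (n_atoms, *trailing): leading size from n_atoms, the rest from the attribute."""
    trailing = attr_shape[1:]
    return (n_atoms, *trailing)


# During graph initialisation the attribute may still be empty, e.g. shape (0, 3),
# while n_atoms is already 5; the leading dimension is taken from n_atoms.
print(standard_normal_shape(5, (0, 3)))  # (5, 3)
```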
- _sample(shape: torch.Size | None = None, **kwargs) → torch.Tensor
Sample from the standard normal distribution
- Parameters:
shape (torch.Size) – The shape of the sample
- Returns:
Sampled tensor
- Return type:
torch.Tensor
- class agedi.diffusion.distributions.Normal
Bases: agedi.diffusion.distributions.Distribution
Normal Distribution
- _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor
Sample from the normal distribution
- Parameters:
mu (torch.Tensor) – Mean of the distribution
sigma (torch.Tensor) – Standard deviation of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
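A common way to sample a normal with per-element mean and standard deviation is to shift and scale standard-normal noise, x = mu + sigma * eps with eps ~ N(0, 1). The sketch below shows this elementwise on Python lists; the function name is hypothetical, and whether _sample uses exactly this form (for example, mu + sigma * torch.randn_like(mu)) is an assumption.

```python
import random
from typing import List, Optional


def sample_normal(mu: List[float], sigma: List[float],
                  rng: Optional[random.Random] = None) -> List[float]:
    """Elementwise draw x_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1)."""
    if len(mu) != len(sigma):
        raise ValueError("mu and sigma must have the same length")
    if rng is None:
        rng = random.Random()
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


x = sample_normal([0.0, 10.0], [1.0, 0.0], rng=random.Random(0))
# The second entry has sigma = 0, so it equals its mean exactly.
```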
- class agedi.diffusion.distributions.TruncatedNormal(index: int = 2, **kwargs)
Bases: agedi.diffusion.distributions.Distribution
Truncated Normal Distribution
- Parameters:
index (int) – The index of the property to truncate
- index = 2
- get_hparams() → Dict
Return hyperparameters for this distribution.
- _setup(batch: agedi.data.AtomsGraph) → None
Set up the distribution
Prepare the distribution for sampling the given batch
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
- _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) → torch.Tensor
Sample from the truncated normal distribution
- Parameters:
mu (torch.Tensor) – Mean of the distribution
sigma (torch.Tensor) – Standard deviation of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
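This page does not state the truncation bounds or the sampling method. As one possible reading, the sketch below truncates only component index of a row to a hypothetical interval [low, high] by rejection sampling and leaves the other components untruncated; the bounds, the per-row layout and the choice of rejection sampling are all assumptions.

```python
import random
from typing import List, Optional


def sample_truncated_row(
    mu: List[float],
    sigma: List[float],
    index: int = 2,
    low: float = 0.0,   # hypothetical lower bound
    high: float = 1.0,  # hypothetical upper bound
    rng: Optional[random.Random] = None,
    max_tries: int = 10_000,
) -> List[float]:
    """Normal draw per component; component `index` is redrawn until it lies in [low, high]."""
    if rng is None:
        rng = random.Random()
    row = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
    for _ in range(max_tries):
        z = mu[index] + sigma[index] * rng.gauss(0.0, 1.0)
        if low <= z <= high:
            row[index] = z
            return row
    raise RuntimeError("acceptance rate too low for rejection sampling")
```

Rejection sampling is simple but becomes slow when the interval lies far in the tail of the normal; inverse-CDF sampling is the usual alternative in that case.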
- class agedi.diffusion.distributions.Uniform(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)
Bases: agedi.diffusion.distributions.Distribution
Uniform Distribution
- Parameters:
low (float) – The lower bound of the distribution
high (float) – The upper bound of the distribution
- low = 0.0
- high = 1.0
- get_hparams() → Dict
Return hyperparameters for this distribution.
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling from the batch.
Sets self.shape to the shape of the target attribute in the batch.
- Parameters:
batch (AtomsGraph) – Batch of atomistic data.
- _sample(shape: torch.Size | None = None, **kwargs) → torch.Tensor
Sample from the uniform distribution
- Parameters:
shape (torch.Size) – The shape of the sample
- Returns:
Sampled tensor
- Return type:
torch.Tensor
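Drawing from U(low, high) is usually done by rescaling a unit-interval sample, x = low + (high - low) * u. The sketch below does this for n scalars in plain Python; the function name is hypothetical, and the tensor analogue low + (high - low) * torch.rand(shape) is only an assumed equivalent of what _sample does.

```python
import random
from typing import List, Optional


def sample_uniform(n: int, low: float = 0.0, high: float = 1.0,
                   rng: Optional[random.Random] = None) -> List[float]:
    """Draw n values from U(low, high) by rescaling unit-interval samples."""
    if rng is None:
        rng = random.Random()
    return [low + (high - low) * rng.random() for _ in range(n)]


xs = sample_uniform(5, low=-2.0, high=3.0, rng=random.Random(0))
```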
- class agedi.diffusion.distributions.UniformCell(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)
Bases: Uniform
Uniform Prior Distribution for cell parameters
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling the given batch
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
- _sample(**kwargs) → torch.Tensor
Sample from the uniform distribution
- Parameters:
kwargs (dict) – The parameters of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
- class agedi.diffusion.distributions.UniformCellConfined(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)
Bases: UniformCell
Uniform Prior Distribution for cell parameters with Z-directional confinement
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling the given batch
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
- class agedi.diffusion.distributions.Constant(value: float = 0, key: str = 'x', dtype: Type = torch.int64, **kwargs)
Bases: agedi.diffusion.distributions.Distribution
Constant Integer Distribution
- value = 0
- dtype
- get_hparams() → Dict
Return hyperparameters for this distribution.
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare the distribution for sampling from the batch.
Sets self.shape based on the total number of atoms in the batch.
- Parameters:
batch (AtomsGraph) – Batch of atomistic data.
- _sample(shape: torch.Size | None = None) → torch.Tensor
Sample from the constant distribution
- Parameters:
shape (torch.Size) – The shape of the sample
- Returns:
Sampled tensor
- Return type:
torch.Tensor
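Per the _setup description, the output size follows the total number of atoms in the batch. A sketch, assuming batch.n_atoms holds one atom count per structure (not confirmed above) and using a hypothetical function name; in PyTorch the analogue would be something like torch.full((total,), value, dtype=torch.int64).

```python
from typing import List


def constant_sample(n_atoms_per_structure: List[int], value: int = 0) -> List[int]:
    """Return one entry per atom in the whole batch, every entry equal to `value`."""
    total_atoms = sum(n_atoms_per_structure)
    return [value] * total_atoms


filled = constant_sample([2, 3], value=7)  # five entries, all 7
```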
- class agedi.diffusion.distributions.Categorical
Bases: agedi.diffusion.distributions.Distribution
Categorical Distribution
Implements hard sampling using the Gumbel-Max trick.
- _sample(probs: torch.Tensor) → torch.Tensor
Sample from the categorical distribution, where the probabilities define the likelihood of each value being set to the masked value, 0
- Parameters:
probs (torch.Tensor) – The probabilities of each category
- Returns:
Sampled tensor
- Return type:
torch.Tensor
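The Gumbel-Max trick draws a category by adding independent Gumbel(0, 1) noise g_i = -log(-log u_i), with u_i ~ U(0, 1), to the log-probabilities and taking the argmax; the winning index then follows the normalised probabilities exactly. The single-draw sketch below is plain Python with a hypothetical function name; how the library maps the chosen category onto the masked value 0 is not shown.

```python
import math
import random
from typing import List, Optional


def gumbel_max_sample(probs: List[float], rng: Optional[random.Random] = None) -> int:
    """Draw one category index with P(i) proportional to probs[i] via the Gumbel-Max trick."""
    if rng is None:
        rng = random.Random()
    best_index, best_score = -1, -math.inf
    for i, p in enumerate(probs):
        if p <= 0.0:
            continue  # log(0) = -inf, so this category can never win
        u = rng.random()
        while u == 0.0:  # avoid log(0) inside the Gumbel transform
            u = rng.random()
        score = math.log(p) - math.log(-math.log(u))
        if score > best_score:
            best_index, best_score = i, score
    if best_index < 0:
        raise ValueError("at least one probability must be positive")
    return best_index
```

A vectorised tensor version would be along the lines of torch.argmax(torch.log(probs) - torch.log(-torch.log(torch.rand_like(probs))), dim=-1). Because only the argmax matters, probs need not be normalised.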