agedi.diffusion.distributions.base
Classes
Distribution | Base class for noise distributions
Module Contents
- class agedi.diffusion.distributions.base.Distribution(key: str | None = None, **kwargs)
Bases:
abc.ABC
Base class for noise distributions.
- Parameters:
key (str | None) – Key identifying the property in the batch
- key = None
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this distribution.
Returns a dictionary with a _target_ key (the fully-qualified class name) plus any constructor arguments stored on the base class. Subclasses should call super().get_hparams() and merge in their own parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
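The super() merging pattern described above can be sketched in plain Python. This is a minimal illustration, not agedi's implementation: both classes are hypothetical stand-ins, the abstract sampling methods are omitted, and the base class is assumed to store only key.

```python
from typing import Any, Dict, Optional


class Distribution:
    """Hypothetical stand-in for the documented base class (sampling omitted)."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        # "_target_" holds the fully-qualified class name of the concrete class.
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}", "key": self.key}


class GaussianNoise(Distribution):
    """Hypothetical subclass with one extra constructor argument, sigma."""

    def __init__(self, key: Optional[str] = None, sigma: float = 1.0, **kwargs: Any) -> None:
        super().__init__(key=key, **kwargs)
        self.sigma = sigma

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary, then merge in subclass parameters.
        hparams = super().get_hparams()
        hparams["sigma"] = self.sigma
        return hparams


noise = GaussianNoise(key="positions", sigma=0.5)
hparams = noise.get_hparams()
# Drop "_target_" and pass the remaining entries back to the constructor.
rebuilt = GaussianNoise(**{k: v for k, v in hparams.items() if k != "_target_"})
assert rebuilt.get_hparams() == hparams
```

The _target_ entry matches the convention Hydra's hydra.utils.instantiate uses to locate a class; this reference does not state whether agedi rebuilds distributions through Hydra.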
- abstractmethod _sample(**kwargs) → torch.Tensor
Sample distribution
Sample from the distribution and return a tensor shaped like the batch property identified by self.key
- Parameters:
kwargs (dict) – The parameters of the distribution
- Returns:
Sampled tensor
- Return type:
torch.Tensor
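Because Distribution derives from abc.ABC and _sample is abstract, the base class cannot be instantiated until a subclass implements _sample. The sketch below shows that contract with a hypothetical UniformNoise subclass. To stay dependency-free it returns a Python list where agedi returns a torch.Tensor, and its keyword arguments (n, low, high) are assumptions.

```python
import abc
import random
from typing import Any, List, Optional


class Distribution(abc.ABC):
    """Hypothetical stand-in for the documented abstract base class."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    @abc.abstractmethod
    def _sample(self, **kwargs: Any) -> List[float]:
        """Subclasses must return a sample (a torch.Tensor in agedi)."""


class UniformNoise(Distribution):
    """Hypothetical subclass: n values drawn uniformly between low and high."""

    def _sample(self, n: int = 1, low: float = 0.0, high: float = 1.0, **kwargs: Any) -> List[float]:
        return [random.uniform(low, high) for _ in range(n)]


# Distribution() would raise TypeError because _sample is still abstract.
values = UniformNoise(key="positions")._sample(n=5, low=-1.0, high=1.0)
```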
- _setup(batch: agedi.data.AtomsGraph) → None
Prepare distribution
Prepare the distribution for sampling the given batch
- Parameters:
batch (AtomsGraph) – Batch of data
- Return type:
None
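One plausible use of _setup is to read batch-dependent information once, before any sampling. The sketch assumes a plain dict as a stand-in for AtomsGraph and assumes the base-class _setup does nothing. The hypothetical ShapeMatchedGaussian subclass caches the shape of batch[self.key] so that _sample can return noise of matching shape.

```python
import abc
import random
from typing import Any, Dict, List, Optional

# Hypothetical stand-in for agedi.data.AtomsGraph: property name -> rows of floats.
Batch = Dict[str, List[List[float]]]


class Distribution(abc.ABC):
    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    @abc.abstractmethod
    def _sample(self, **kwargs: Any) -> List[List[float]]: ...

    def _setup(self, batch: Batch) -> None:
        """Assumed default: nothing to prepare."""


class ShapeMatchedGaussian(Distribution):
    """Hypothetical subclass: samples match the shape of batch[self.key]."""

    def _setup(self, batch: Batch) -> None:
        rows = batch[self.key]
        # Cache the (num_rows, num_cols) shape of the keyed property.
        self.shape = (len(rows), len(rows[0]) if rows else 0)

    def _sample(self, sigma: float = 1.0, **kwargs: Any) -> List[List[float]]:
        n_rows, n_cols = self.shape
        return [[random.gauss(0.0, sigma) for _ in range(n_cols)] for _ in range(n_rows)]


dist = ShapeMatchedGaussian(key="positions")
dist._setup({"positions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]})
noise = dist._sample(sigma=0.1)  # 2 rows of 3 values, like the positions
```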
- get_callable(batch: agedi.data.AtomsGraph) → Callable
Get callable function
Return a callable that samples from the distribution
- Parameters:
batch (AtomsGraph) – Batch of data
- Returns:
Callable function that samples from the distribution
- Return type:
Callable
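This reference does not say how get_callable combines _setup and _sample. A natural reading, assumed here, is that it prepares the distribution for the batch and then returns the bound sampler. The sketch uses a dict in place of AtomsGraph and a hypothetical StandardNormal subclass.

```python
import abc
import random
from typing import Any, Callable, Dict, List, Optional

# Hypothetical stand-in for agedi.data.AtomsGraph: property name -> values.
Batch = Dict[str, List[float]]


class Distribution(abc.ABC):
    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    @abc.abstractmethod
    def _sample(self, **kwargs: Any) -> List[float]: ...

    def _setup(self, batch: Batch) -> None:
        pass

    def get_callable(self, batch: Batch) -> Callable[..., List[float]]:
        # Assumed flow: prepare once for this batch, then hand back the sampler.
        self._setup(batch)
        return self._sample  # bound method; setup state lives on the instance


class StandardNormal(Distribution):
    """Hypothetical subclass: one N(0, 1) draw per entry of batch[self.key]."""

    def _setup(self, batch: Batch) -> None:
        self.n = len(batch[self.key])

    def _sample(self, **kwargs: Any) -> List[float]:
        return [random.gauss(0.0, 1.0) for _ in range(self.n)]


sampler = StandardNormal(key="charges").get_callable({"charges": [0.1, -0.2, 0.3, 0.0]})
noise = sampler()  # four standard-normal values
```

Returning the bound method keeps the batch-specific state on the instance, so repeated calls to sampler() reuse the setup without reading the batch again.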