agedi.diffusion.distributions
=============================

.. py:module:: agedi.diffusion.distributions


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/agedi/diffusion/distributions/base/index
   /autoapi/agedi/diffusion/distributions/categorical/index
   /autoapi/agedi/diffusion/distributions/constant/index
   /autoapi/agedi/diffusion/distributions/normal/index
   /autoapi/agedi/diffusion/distributions/uniform/index


Classes
-------

.. autoapisummary::

   agedi.diffusion.distributions.Distribution
   agedi.diffusion.distributions.StandardNormal
   agedi.diffusion.distributions.Normal
   agedi.diffusion.distributions.TruncatedNormal
   agedi.diffusion.distributions.Uniform
   agedi.diffusion.distributions.UniformCell
   agedi.diffusion.distributions.UniformCellConfined
   agedi.diffusion.distributions.Constant
   agedi.diffusion.distributions.Categorical


Package Contents
----------------

.. py:class:: Distribution(key: Optional[str] = None, **kwargs)

   Bases: :py:obj:`abc.ABC`


   Base class for noise distributions.

   :param key: Key to identify the property from the batch
   :type key: str, optional

   :rtype: Distribution


   .. py:attribute:: key
      :value: None


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this distribution.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class
      name) plus any constructor arguments stored on the base class. Subclasses
      should call ``super().get_hparams()`` and merge in their own parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: _sample(**kwargs) -> torch.Tensor
      :abstractmethod:


      Sample from the distribution.

      Return a tensor with the shape of the batch property identified by
      ``self.key``.

      :param kwargs: The parameters of the distribution
      :type kwargs: dict

      :returns: Sampled tensor
      :rtype: torch.Tensor


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      :param batch: Batch of data
      :type batch: AtomsGraph

      :rtype: None


   .. py:method:: get_callable(batch: agedi.data.AtomsGraph) -> Callable

      Return a callable that samples from the distribution.

      :param batch: Batch of data
      :type batch: AtomsGraph

      :returns: Callable that samples from the distribution
      :rtype: Callable


.. py:class:: StandardNormal

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Standard normal distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      Sets ``self.shape`` to ``(n_atoms, *trailing)``, where ``n_atoms`` is read
      from ``batch.n_atoms`` and the trailing dimensions come from the existing
      attribute. Using ``n_atoms`` rather than the attribute's leading dimension
      avoids a shape mismatch when called during graph initialisation (via
      :meth:`~agedi.diffusion.noisers.Noiser.initialize_graph`), where the
      attribute tensor may still be empty even though ``n_atoms`` has already
      been set.

      :param batch: Batch of atomistic data.
      :type batch: AtomsGraph


   .. py:method:: _sample(shape: Optional[torch.Size] = None, **kwargs) -> torch.Tensor

      Sample from the standard normal distribution.

      :param shape: The shape of the sample
      :type shape: torch.Size, optional

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: Normal

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Normal distribution.


   .. py:method:: _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor

      Sample from the normal distribution.

      :param mu: Mean of the distribution
      :type mu: torch.Tensor
      :param sigma: Standard deviation of the distribution
      :type sigma: torch.Tensor

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: TruncatedNormal(index: int = 2, **kwargs)

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Truncated normal distribution.

   :param index: The index of the property to truncate
   :type index: int


   .. py:attribute:: index
      :value: 2


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Set up the distribution for sampling from *batch*.

      :param batch: Batch of data
      :type batch: AtomsGraph

      :rtype: None


   .. py:method:: _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor

      Sample from the truncated normal distribution.

      :param mu: Mean of the distribution
      :type mu: torch.Tensor
      :param sigma: Standard deviation of the distribution
      :type sigma: torch.Tensor

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: Uniform(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Uniform distribution.

   :param low: The lower bound of the distribution
   :type low: float
   :param high: The upper bound of the distribution
   :type high: float


   .. py:attribute:: low
      :value: 0.0


   .. py:attribute:: high
      :value: 1.0


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      Sets ``self.shape`` to the shape of the target attribute in the batch.

      :param batch: Batch of atomistic data.
      :type batch: AtomsGraph


   .. py:method:: _sample(shape: Optional[torch.Size] = None, **kwargs) -> torch.Tensor

      Sample from the uniform distribution.

      :param shape: The shape of the sample
      :type shape: torch.Size, optional

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: UniformCell(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

   Bases: :py:obj:`Uniform`


   Uniform prior distribution for cell parameters.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      :param batch: Batch of data
      :type batch: AtomsGraph

      :rtype: None


   .. py:method:: _sample(**kwargs) -> torch.Tensor

      Sample from the uniform distribution.

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: UniformCellConfined(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

   Bases: :py:obj:`UniformCell`


   Uniform prior distribution for cell parameters, with confinement along the
   z direction.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      :param batch: Batch of data
      :type batch: AtomsGraph

      :rtype: None


.. py:class:: Constant(value: float = 0, key: str = 'x', dtype: Type = torch.int64, **kwargs)

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Constant distribution: every sample equals ``value``, with data type
   ``dtype`` (``torch.int64`` by default).


   .. py:attribute:: value
      :value: 0


   .. py:attribute:: dtype


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      Sets ``self.shape`` based on the total number of atoms in the batch.

      :param batch: Batch of atomistic data.
      :type batch: AtomsGraph


   .. py:method:: _sample(shape: Optional[torch.Size] = None) -> torch.Tensor

      Sample from the constant distribution.

      :param shape: The shape of the sample
      :type shape: torch.Size, optional

      :returns: Sampled tensor
      :rtype: torch.Tensor


.. py:class:: Categorical

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`


   Categorical distribution.

   Implements hard sampling using the Gumbel-Max trick.


   .. py:method:: _sample(probs: torch.Tensor) -> torch.Tensor

      Sample from the categorical distribution.

      The probabilities define the likelihood of each value being set to the
      masked value, 0.

      :param probs: The probabilities of each category
      :type probs: torch.Tensor

      :returns: Sampled tensor
      :rtype: torch.Tensor
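
``Categorical`` is documented as using the Gumbel-Max trick for hard sampling. The sketch below shows the trick in plain standard-library Python on lists. It is an illustration only, not the agedi implementation, which works on ``torch.Tensor`` inputs. The helpers ``gumbel_max_sample`` and ``sample_rows`` are hypothetical names.

```python
import math
import random
from typing import List, Optional, Sequence


def gumbel_max_sample(probs: Sequence[float], rng: Optional[random.Random] = None) -> int:
    """Return one category index drawn with probability proportional to ``probs``.

    Gumbel-Max trick: argmax_i (log p_i + g_i), with g_i ~ Gumbel(0, 1) i.i.d.
    """
    rng = rng if rng is not None else random.Random()
    best_index, best_score = -1, -math.inf
    for i, p in enumerate(probs):
        if p <= 0.0:
            continue  # log(0) = -inf, so this category can never win the argmax
        u = rng.random()  # uniform on [0, 1)
        while u == 0.0:  # avoid log(0) in the Gumbel transform
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # standard Gumbel noise
        score = math.log(p) + gumbel
        if score > best_score:
            best_index, best_score = i, score
    if best_index < 0:
        raise ValueError("probs must contain at least one positive entry")
    return best_index


def sample_rows(rows: Sequence[Sequence[float]], seed: int = 0) -> List[int]:
    """Apply the trick independently to each row (e.g. one row of probabilities per atom)."""
    rng = random.Random(seed)
    return [gumbel_max_sample(row, rng) for row in rows]
```

The argmax is unchanged when every ``log p_i`` is shifted by the same constant, so ``probs`` does not need to be normalised. A batched tensor version would add Gumbel noise to ``log(probs)`` and take an ``argmax`` over the last dimension.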
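
The page describes ``UniformCell`` only as a uniform prior for cell parameters. One common way to build a uniform prior for atomic positions inside a simulation cell works in two steps. First, draw fractional coordinates uniformly in ``[low, high]``. Then map them through the cell matrix. The sketch below assumes that interpretation and the ASE convention that the rows of ``cell`` are the lattice vectors. ``sample_uniform_in_cell`` is a hypothetical helper, not part of agedi.

```python
import random
from typing import List, Optional, Sequence

Matrix = Sequence[Sequence[float]]


def sample_uniform_in_cell(
    cell: Matrix,
    n_atoms: int,
    low: float = 0.0,
    high: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Draw Cartesian positions whose fractional coordinates are uniform in [low, high].

    Assumes ``cell`` is 3x3 with lattice vectors as rows, so cartesian = fractional @ cell.
    """
    rng = rng if rng is not None else random.Random()
    positions = []
    for _ in range(n_atoms):
        # Independent uniform fractional coordinate along each lattice vector.
        frac = [rng.uniform(low, high) for _ in range(3)]
        # Row-vector times matrix: component j is sum_k frac[k] * cell[k][j].
        cart = [sum(frac[k] * cell[k][j] for k in range(3)) for j in range(3)]
        positions.append(cart)
    return positions
```

The map from fractional to Cartesian coordinates is linear. Uniform fractional coordinates on ``[0, 1)`` therefore give a uniform density over the whole cell, even when the cell is not orthogonal.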
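
``Distribution.get_hparams`` is documented to return a ``_target_`` entry holding the fully-qualified class name, plus constructor arguments. Subclasses call ``super().get_hparams()`` and merge in their own parameters. The sketch below illustrates that contract with simplified, hypothetical stand-ins for ``Distribution`` and ``Uniform`` that contain no sampling logic. ``qualified_name`` and ``rebuild`` are hypothetical helpers. The ``_target_`` key matches the convention of Hydra's ``hydra.utils.instantiate``, but this page does not say whether agedi relies on Hydra.

```python
from typing import Any, Dict, Optional, Type


def qualified_name(cls: type) -> str:
    """Fully-qualified class name, e.g. ``package.module.Class``."""
    return f"{cls.__module__}.{cls.__qualname__}"


class Distribution:
    """Hypothetical stand-in for the documented base class (no sampling logic)."""

    def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        # _target_ identifies the class; the other entries are constructor arguments.
        return {"_target_": qualified_name(type(self)), "key": self.key}


class Uniform(Distribution):
    """Hypothetical stand-in mirroring the documented Uniform signature."""

    def __init__(self, low: float = 0.0, high: float = 1.0, key: str = "x", **kwargs: Any) -> None:
        super().__init__(key=key, **kwargs)
        self.low = low
        self.high = high

    def get_hparams(self) -> Dict[str, Any]:
        hparams = super().get_hparams()  # base-class entries first
        hparams.update(low=self.low, high=self.high)  # then subclass parameters
        return hparams


def rebuild(hparams: Dict[str, Any], registry: Dict[str, Type[Distribution]]) -> Distribution:
    """Reconstruct a distribution; ``registry`` maps qualified names to classes."""
    params = dict(hparams)  # copy so the caller's dictionary is not mutated
    cls = registry[params.pop("_target_")]
    return cls(**params)
```

The sketch uses an explicit registry for the name-to-class lookup. A real loader could instead import the class from its module path.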