agedi.diffusion.distributions.normal
====================================

.. py:module:: agedi.diffusion.distributions.normal


Attributes
----------

.. autoapisummary::

   agedi.diffusion.distributions.normal._CONFINEMENT_CLAMP_EPS


Classes
-------

.. autoapisummary::

   agedi.diffusion.distributions.normal.StandardNormal
   agedi.diffusion.distributions.normal.Normal
   agedi.diffusion.distributions.normal.TruncatedNormal


Module Contents
---------------

.. py:data:: _CONFINEMENT_CLAMP_EPS
   :value: 0.0001


.. py:class:: StandardNormal

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`

   Standard normal distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution for sampling from *batch*.

      Sets ``self.shape`` to ``(n_atoms, *trailing)``, where ``n_atoms`` is read
      from ``batch.n_atoms`` and the trailing dimensions come from the existing
      attribute. Using ``n_atoms`` rather than the attribute's leading dimension
      avoids a shape mismatch when this method is called during graph
      initialisation (via :meth:`~agedi.diffusion.noisers.Noiser.initialize_graph`),
      where the attribute tensor may still be empty even though ``n_atoms`` has
      already been set.

      :param batch: Batch of atomistic data.
      :type batch: AtomsGraph


   .. py:method:: _sample(shape: Optional[torch.Size] = None, **kwargs) -> torch.Tensor

      Sample from the standard normal distribution.

      :param shape: Shape of the sampled tensor, or ``None``.
      :type shape: Optional[torch.Size]
      :returns: Sampled tensor.
      :rtype: torch.Tensor


.. py:class:: Normal

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`

   Normal distribution.


   .. py:method:: _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor

      Sample from the normal distribution.

      :param mu: Mean of the distribution.
      :type mu: torch.Tensor
      :param sigma: Standard deviation of the distribution.
      :type sigma: torch.Tensor
      :returns: Sampled tensor.
      :rtype: torch.Tensor


.. py:class:: TruncatedNormal(index: int = 2, **kwargs)

   Bases: :py:obj:`agedi.diffusion.distributions.Distribution`

   Truncated normal distribution.

   :param index: Index of the property to truncate.
   :type index: int


   .. py:attribute:: index
      :value: 2


   .. py:method:: get_hparams() -> Dict

      Return the hyperparameters of this distribution.


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Set up the distribution and prepare it for sampling from *batch*.

      :param batch: Batch of data.
      :type batch: AtomsGraph
      :rtype: None


   .. py:method:: _sample(mu: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor

      Sample from the truncated normal distribution.

      :param mu: Mean of the distribution.
      :type mu: torch.Tensor
      :param sigma: Standard deviation of the distribution.
      :type sigma: torch.Tensor
      :returns: Sampled tensor.
      :rtype: torch.Tensor
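
The shape rule described for ``StandardNormal._setup`` and the sampling done by
``Normal._sample`` can be illustrated with a small framework-free sketch. It
assumes, without confirmation from this page, that a normal sample is drawn with
the usual reparameterisation ``mu + sigma * eps`` where ``eps`` comes from a
standard normal. The helper names ``sample_shape`` and ``sample_normal`` are
hypothetical, and nested Python lists stand in for ``torch.Tensor``.

.. code-block:: python

   import random
   from typing import List, Sequence, Tuple


   def sample_shape(n_atoms: int, attr_shape: Tuple[int, ...]) -> Tuple[int, ...]:
       """Hypothetical helper: leading dim from n_atoms, trailing dims from the attribute."""
       # The attribute may still be empty (e.g. shape (0, 3)) during graph
       # initialisation, so only its trailing dimensions are trusted.
       return (n_atoms, *attr_shape[1:])


   def sample_normal(
       mu: Sequence[Sequence[float]],
       sigma: Sequence[Sequence[float]],
       rng: random.Random,
   ) -> List[List[float]]:
       """Hypothetical helper: element-wise mu + sigma * eps with eps ~ N(0, 1)."""
       return [
           [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, sigma_row)]
           for mu_row, sigma_row in zip(mu, sigma)
       ]


   # Five atoms whose 3D attribute has not been populated yet.
   shape = sample_shape(n_atoms=5, attr_shape=(0, 3))
   print(shape)  # (5, 3)

   rng = random.Random(0)
   mu = [[0.0] * shape[1] for _ in range(shape[0])]
   sigma = [[0.0] * shape[1] for _ in range(shape[0])]
   # With sigma = 0 every sample collapses onto the mean.
   print(sample_normal(mu, sigma, rng)[0])  # [0.0, 0.0, 0.0]

Taking the leading dimension from ``n_atoms`` means the sample shape is correct
even when the attribute tensor is still empty, which is the situation the
``_setup`` docstring describes.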
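
How ``TruncatedNormal`` truncates the property selected by ``index`` is not
described on this page, and neither is the role of ``_CONFINEMENT_CLAMP_EPS``.
The sketch below shows one standard technique, inverse-CDF sampling of a normal
distribution truncated to an interval, applied only to column ``index`` of each
row while the other columns use a plain normal. The bounds ``low`` and ``high``,
the helper names ``truncated_normal`` and ``sample_truncated_column``, and
clamping the uniform draw with an epsilon equal to ``1e-4`` are all assumptions
made for illustration; ``sigma`` must be strictly positive.

.. code-block:: python

   import random
   from statistics import NormalDist
   from typing import List, Sequence

   CLAMP_EPS = 1e-4  # assumption: reuses the value of _CONFINEMENT_CLAMP_EPS


   def truncated_normal(mu: float, sigma: float, low: float, high: float,
                        rng: random.Random) -> float:
       """Hypothetical helper: draw from N(mu, sigma^2) restricted to [low, high]."""
       dist = NormalDist(mu, sigma)  # requires sigma > 0
       cdf_low, cdf_high = dist.cdf(low), dist.cdf(high)
       # Uniform draw inside the CDF range covered by [low, high].
       u = cdf_low + rng.random() * (cdf_high - cdf_low)
       # Keep u strictly inside (0, 1) so inv_cdf stays finite.
       u = min(max(u, CLAMP_EPS), 1.0 - CLAMP_EPS)
       # Final clamp guards against the epsilon pushing u outside the interval.
       return min(max(dist.inv_cdf(u), low), high)


   def sample_truncated_column(mu: Sequence[Sequence[float]],
                               sigma: Sequence[Sequence[float]],
                               index: int, low: float, high: float,
                               rng: random.Random) -> List[List[float]]:
       """Hypothetical helper: plain normal everywhere, truncated normal on `index`."""
       out = []
       for mu_row, sigma_row in zip(mu, sigma):
           row = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, sigma_row)]
           row[index] = truncated_normal(mu_row[index], sigma_row[index], low, high, rng)
           out.append(row)
       return out


   rng = random.Random(0)
   mu = [[0.5, 0.5, 0.95]] * 4
   sigma = [[0.1, 0.1, 0.2]] * 4
   samples = sample_truncated_column(mu, sigma, index=2, low=0.0, high=1.0, rng=rng)
   print(all(0.0 <= row[2] <= 1.0 for row in samples))  # True

Inverse-CDF sampling needs one uniform draw per value and never rejects, which
keeps the cost fixed even when the mean sits close to a bound, as with the
third column in the usage example.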