agedi.diffusion.distributions.base
==================================

.. py:module:: agedi.diffusion.distributions.base


Classes
-------

.. autoapisummary::

   agedi.diffusion.distributions.base.Distribution


Module Contents
---------------

.. py:class:: Distribution(key: Optional[str] = None, **kwargs)

   Bases: :py:obj:`abc.ABC`


   Base class for noise distributions.

   :param key: Key identifying the property in the batch.
   :type key: str, optional
   :rtype: Distribution


   .. py:attribute:: key
      :value: None


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this distribution.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class
      name) plus any constructor arguments stored on the base class.
      Subclasses should call ``super().get_hparams()`` and merge in their own
      parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: _sample(**kwargs) -> torch.Tensor
      :abstractmethod:


      Sample from the distribution.

      Draw a sample and return a tensor with the same shape as the batch
      property identified by ``self.key``.

      :param kwargs: The parameters of the distribution
      :type kwargs: dict
      :returns: Sampled tensor
      :rtype: torch.Tensor


   .. py:method:: _setup(batch: agedi.data.AtomsGraph) -> None

      Prepare the distribution.

      Prepare the distribution for sampling on the given batch.

      :param batch: Batch of data
      :type batch: AtomsGraph
      :rtype: None


   .. py:method:: get_callable(batch: agedi.data.AtomsGraph) -> Callable

      Get a sampling function.

      Return a callable that draws samples from the distribution for the
      given batch.

      :param batch: Batch of data
      :type batch: AtomsGraph
      :returns: Callable function that samples from the distribution
      :rtype: Callable
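
The interface documented above can be illustrated with a minimal sketch. It does not import ``agedi``: the ``Distribution`` class below is a hypothetical stand-in that only mirrors the documented method names, a plain ``dict`` of tensors stands in for an ``AtomsGraph`` batch, and the behaviour given to ``get_callable`` (call ``_setup`` on the batch, then return ``_sample``) is an assumption, not the library's confirmed implementation. ``ScaledNormal`` is likewise a hypothetical subclass showing how ``_setup``, ``_sample`` and ``get_hparams`` might be overridden.

.. code-block:: python

   import abc
   from typing import Any, Callable, Dict, Optional

   import torch


   class Distribution(abc.ABC):
       """Hypothetical stand-in mirroring the documented interface (not agedi's code)."""

       def __init__(self, key: Optional[str] = None, **kwargs: Any) -> None:
           self.key = key

       def get_hparams(self) -> Dict[str, Any]:
           # Fully-qualified class name plus the base-class constructor argument.
           cls = type(self)
           return {"_target_": f"{cls.__module__}.{cls.__qualname__}", "key": self.key}

       @abc.abstractmethod
       def _sample(self, **kwargs: Any) -> torch.Tensor:
           ...

       def _setup(self, batch: Dict[str, torch.Tensor]) -> None:
           # Default: nothing to prepare.
           return None

       def get_callable(self, batch: Dict[str, torch.Tensor]) -> Callable[..., torch.Tensor]:
           # Assumed behaviour: prepare for this batch, then hand back the sampler.
           self._setup(batch)
           return self._sample


   class ScaledNormal(Distribution):
       """Hypothetical subclass: Gaussian noise with standard deviation `scale`."""

       def __init__(self, key: Optional[str] = None, scale: float = 1.0, **kwargs: Any) -> None:
           super().__init__(key=key, **kwargs)
           self.scale = scale
           self._shape: Optional[torch.Size] = None

       def get_hparams(self) -> Dict[str, Any]:
           # Merge the subclass parameter into the base-class dictionary.
           hparams = super().get_hparams()
           hparams["scale"] = self.scale
           return hparams

       def _setup(self, batch: Dict[str, torch.Tensor]) -> None:
           # Remember the shape of the property this distribution noises.
           self._shape = batch[self.key].shape

       def _sample(self, **kwargs: Any) -> torch.Tensor:
           if self._shape is None:
               raise RuntimeError("call get_callable(batch) before sampling")
           return self.scale * torch.randn(self._shape)


   batch = {"pos": torch.zeros(5, 3)}  # stand-in for an AtomsGraph batch
   dist = ScaledNormal(key="pos", scale=0.5)
   sample = dist.get_callable(batch)
   noise = sample()
   print(noise.shape)  # torch.Size([5, 3])
   print(dist.get_hparams())

In this sketch the batch-dependent state lives in ``_setup``, so one distribution object can be reused across batches of different sizes: the shape is recomputed each time ``get_callable`` is called.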