agedi.models.conditionings.base
===============================

.. py:module:: agedi.models.conditionings.base


Classes
-------

.. autoapisummary::

   agedi.models.conditionings.base.Conditioning


Module Contents
---------------

.. py:class:: Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

   Bases: :py:obj:`abc.ABC`, :py:obj:`lightning.LightningModule`

   Conditioning base class.

   :param property: The property of the batch to condition on.
   :type property: str
   :param input_dim: The dimension of the input conditioning.
   :type input_dim: int
   :param output_dim: The dimension of the output conditioning.
   :type output_dim: int
   :param concatenation_type: The type of concatenation to use. Default is ``"scalar"``.
   :type concatenation_type: str
   :param probability: The probability of applying the conditioning. Default is 0.8. Only used in training mode.
   :type probability: float

   :rtype: Conditioning


   .. py:attribute:: property


   .. py:attribute:: input_dim


   .. py:attribute:: output_dim


   .. py:attribute:: concatenation_type
      :value: 'scalar'


   .. py:attribute:: probability
      :value: 0.8


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this conditioning module.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class name) plus
      ``property``, ``probability``, ``input_dim``, and ``output_dim`` from the base class.
      Subclasses should call ``super().get_hparams()`` and merge in their own constructor
      parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: get_conditioning(x: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Abstract method to compute the conditioning from the input.

      Must be implemented by the subclass.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: The conditioning tensor.
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor
      :abstractmethod:

      Abstract method to create an empty conditioning tensor.

      Must be implemented by the subclass.

      :param n: The number of nodes in the batch.
      :type n: int

      :returns: The empty conditioning tensor.
      :rtype: torch.Tensor


   .. py:method:: forward(batch: AtomsGraph, empty: bool = False) -> AtomsGraph

      Compute the conditioning from the input batch and add it to the batch representation.

      :param batch: The input batch.
      :type batch: AtomsGraph
      :param empty: If True, use an empty conditioning tensor.
      :type empty: bool

      :returns: The batch with the conditioning added to the representation.
      :rtype: AtomsGraph


   .. py:method:: concatenate(batch: AtomsGraph, c: torch.Tensor) -> None

      Concatenate the conditioning to the batch.

      ``c`` must already be expanded to node level (``c.shape[0] == batch.pos.shape[0]``)
      before this method is called. The pre-expansion is performed by :meth:`forward` to keep
      this method free of dynamic shape comparisons, which would otherwise cause graph breaks
      under ``torch.compile``.

      :param batch: The input batch.
      :type batch: AtomsGraph
      :param c: Node-level conditioning tensor of shape ``(n_nodes, output_dim)`` or
                ``(n_nodes, output_dim, 1)``.
      :type c: torch.Tensor

      :rtype: None


   .. py:method:: sample_mode() -> None

      Set the model to sample mode.

      :rtype: None


   .. py:method:: training_mode() -> None

      Set the model to training mode.

      :rtype: None
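The ``get_hparams`` contract (a ``_target_`` class name plus the base-class entries, extended by subclasses through ``super().get_hparams()``) can be illustrated with a minimal, framework-free sketch. ``BaseConditioning``, ``ScalarEmbeddingConditioning`` and the ``n_basis`` parameter are hypothetical names used only for illustration; they are not part of ``agedi``. The real base class is also an ``abc.ABC`` and a ``lightning.LightningModule`` working on PyTorch tensors, all of which is omitted here.

```python
from typing import Any, Dict


class BaseConditioning:
    """Hypothetical stand-in for the documented base class.

    Only the constructor arguments and the get_hparams() contract are
    mirrored; no torch or Lightning machinery is included.
    """

    def __init__(
        self,
        property: str,
        input_dim: int,
        output_dim: int,
        concatenation_type: str = "scalar",
        probability: float = 0.8,
        **kwargs: Any,
    ) -> None:
        self.property = property
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.concatenation_type = concatenation_type
        self.probability = probability

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        # "_target_" holds the fully-qualified class name; the remaining
        # keys are the base-class entries listed in the documentation.
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "property": self.property,
            "probability": self.probability,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        }


class ScalarEmbeddingConditioning(BaseConditioning):
    """Hypothetical subclass with one extra constructor parameter."""

    def __init__(self, property: str, input_dim: int, output_dim: int,
                 n_basis: int = 16, **kwargs: Any) -> None:
        super().__init__(property, input_dim, output_dim, **kwargs)
        self.n_basis = n_basis

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary, then merge in our own parameter.
        hparams = super().get_hparams()
        hparams["n_basis"] = self.n_basis
        return hparams


cond = ScalarEmbeddingConditioning("energy", input_dim=1, output_dim=32,
                                   n_basis=8, probability=0.8)
hparams = cond.get_hparams()

# Rebuild an equivalent object from every entry except "_target_".
kwargs = {k: v for k, v in hparams.items() if k != "_target_"}
clone = ScalarEmbeddingConditioning(**kwargs)
```

Keeping ``_target_`` next to plain constructor arguments matches the convention of Hydra's ``hydra.utils.instantiate``; whether ``agedi`` relies on Hydra for this is not stated here. Note also that, as documented, the base dictionary does not include ``concatenation_type``, so a subclass using a non-default value would need to add it itself for a faithful round trip.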