agedi.models.conditionings
==========================

.. py:module:: agedi.models.conditionings


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/agedi/models/conditionings/base/index
   /autoapi/agedi/models/conditionings/integer/index
   /autoapi/agedi/models/conditionings/scalar/index
   /autoapi/agedi/models/conditionings/time/index


Classes
-------

.. autoapisummary::

   agedi.models.conditionings.Conditioning
   agedi.models.conditionings.TimeConditioning
   agedi.models.conditionings.ScalarConditioning
   agedi.models.conditionings.IntegerConditioning


Package Contents
----------------

.. py:class:: Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

   Bases: :py:obj:`abc.ABC`, :py:obj:`lightning.LightningModule`

   Base class for conditioning modules.

   :param property: The property of the batch to condition on
   :type property: str
   :param input_dim: The dimension of the input conditioning
   :type input_dim: int
   :param output_dim: The dimension of the output conditioning
   :type output_dim: int
   :param concatenation_type: The type of concatenation to use. Default is ``'scalar'``.
   :type concatenation_type: str
   :param probability: The probability of applying the conditioning. Default is 0.8. Only used in training mode.
   :type probability: float

   :rtype: Conditioning


   .. py:attribute:: property


   .. py:attribute:: input_dim


   .. py:attribute:: output_dim


   .. py:attribute:: concatenation_type
      :value: 'scalar'


   .. py:attribute:: probability
      :value: 0.8


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this conditioning module.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class
      name) plus ``property``, ``probability``, ``input_dim``, and ``output_dim``
      from the base class. Subclasses should call ``super().get_hparams()`` and
      merge in their own constructor parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict

   .. py:method:: get_conditioning(x: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Compute the conditioning tensor from the input.

      Must be implemented by subclasses.

      :param x: The input tensor
      :type x: torch.Tensor

      :returns: The conditioning tensor
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor
      :abstractmethod:

      Return an empty conditioning tensor.

      Must be implemented by subclasses.

      :param n: The number of nodes in the batch
      :type n: int

      :returns: The empty conditioning tensor
      :rtype: torch.Tensor


   .. py:method:: forward(batch: AtomsGraph, empty: bool = False) -> AtomsGraph

      Compute the conditioning from the batch and add it to the batch representation.

      :param batch: The input batch
      :type batch: AtomsGraph
      :param empty: If True, use an empty conditioning tensor instead
      :type empty: bool

      :returns: The batch with the conditioning added to the representation
      :rtype: AtomsGraph


   .. py:method:: concatenate(batch: AtomsGraph, c: torch.Tensor) -> None

      Concatenate the conditioning to the batch representation.

      ``c`` must already be expanded to node level (``c.shape[0] == batch.pos.shape[0]``)
      before this method is called. The expansion is performed by :meth:`forward`
      so that this method contains no dynamic shape comparisons, which would
      otherwise cause graph breaks under ``torch.compile``.

      :param batch: The input batch
      :type batch: AtomsGraph
      :param c: Node-level conditioning tensor of shape ``(n_nodes, output_dim)`` or ``(n_nodes, output_dim, 1)``.
      :type c: torch.Tensor

      :rtype: None


   .. py:method:: sample_mode() -> None

      Set the module to sampling mode.

      :rtype: None


   .. py:method:: training_mode() -> None

      Set the module to training mode.

      :rtype: None


.. py:class:: TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)

   Bases: :py:obj:`agedi.models.conditionings.base.Conditioning`

   Condition the model on the time ``t``, given as a ``torch.Tensor`` of shape ``(Nodes, 1)``.


   .. py:attribute:: omega

   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this time conditioning module.


   .. py:method:: get_conditioning(t: torch.Tensor) -> torch.Tensor

      Get the conditioning tensor for the time ``t``:

      .. math::

         \mathbf{c} = \begin{bmatrix} \sin(\omega t) \\ \cos(\omega t) \end{bmatrix}

      :param t: Time tensor of shape (Nodes, 1).
      :type t: torch.Tensor

      :returns: Conditioning tensor of shape (Nodes, 2).
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor

      Get an empty conditioning tensor.

      :returns: Empty conditioning tensor of shape (1, 2).
      :rtype: torch.Tensor


   .. py:method:: forward(batch: AtomsGraph, empty: bool = False) -> AtomsGraph

      Compute the time conditioning and add it to the batch representation.

      This method ignores the training mode and the ``empty`` flag.

      :param batch: The input batch
      :type batch: AtomsGraph
      :param empty: Ignored by this class
      :type empty: bool

      :returns: The batch with the conditioning added to the representation
      :rtype: AtomsGraph


.. py:class:: ScalarConditioning(*args, input_dim: int = 1, output_dim: int = 2, **kwargs)

   Bases: :py:obj:`agedi.models.conditionings.base.Conditioning`

   Conditioning module for continuous scalar properties.

   Projects a scalar property through a learned linear layer and encodes it
   with sinusoidal features (``cos`` and ``sin``), producing a 2-dimensional
   conditioning vector.


   .. py:attribute:: embedder


   .. py:method:: get_conditioning(x: torch.Tensor) -> torch.Tensor

      Get the conditioning tensor for ``x``.

      :param x: Scalar property tensor of shape (Nodes, 1).
      :type x: torch.Tensor

      :returns: Conditioning tensor of shape (Nodes, 2).
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor

      Get an empty conditioning tensor.

      :returns: Empty conditioning tensor of shape (n, 2).
      :rtype: torch.Tensor
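
The sinusoidal encodings documented for ``TimeConditioning.get_conditioning`` and ``ScalarConditioning`` can be illustrated with a short plain-Python sketch. It is not the ``agedi`` implementation: it works on lists rather than ``torch.Tensor`` objects, treats ``omega`` as a fixed float (this page does not say whether it is fixed or learned), and replaces the learned linear layer of ``ScalarConditioning`` with hypothetical fixed weights ``w`` and ``b``. The helper names ``time_features`` and ``scalar_features`` are also hypothetical.

```python
import math
from typing import List, Sequence


def time_features(t: Sequence[float], omega: float) -> List[List[float]]:
    """Encode each time value as [sin(omega * t), cos(omega * t)].

    Follows the formula documented for TimeConditioning.get_conditioning;
    omega is assumed to be a plain float here.
    """
    return [[math.sin(omega * ti), math.cos(omega * ti)] for ti in t]


def scalar_features(x: Sequence[float], w: float, b: float) -> List[List[float]]:
    """Project each scalar as z = w * x + b, then encode it as [cos(z), sin(z)].

    w and b are hypothetical stand-ins for the weights of the learned
    linear layer described for ScalarConditioning.
    """
    features = []
    for xi in x:
        z = w * xi + b  # learned in the real module, fixed here
        features.append([math.cos(z), math.sin(z)])
    return features


# Three nodes, one value each (a (Nodes, 1) tensor flattened to a list).
c_time = time_features([0.0, 0.25, 1.0], omega=math.pi / 2)
c_scalar = scalar_features([1.5, -0.3, 2.0], w=0.5, b=0.1)
print(len(c_time), len(c_time[0]))  # 3 2
```

Every encoded row lies on the unit circle, so the features stay bounded whatever the magnitude of ``t`` or ``x``; in this sketch ``w`` acts as a frequency and ``b`` as a phase.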

.. py:class:: IntegerConditioning(max_int: int = 200, input_dim: int = 1, output_dim: int = 64, *args, **kwargs)

   Bases: :py:obj:`agedi.models.conditionings.base.Conditioning`

   Conditioning module for integer-valued properties.

   Embeds an integer property (e.g. the number of atoms) into a fixed-size
   representation using :class:`torch.nn.Embedding`.


   .. py:attribute:: max_int
      :value: 200


   .. py:attribute:: embedder


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this integer conditioning module.


   .. py:method:: get_conditioning(x: torch.Tensor) -> torch.Tensor

      Get the conditioning tensor for ``x``.

      :param x: Integer property tensor of shape (Nodes, 1).
      :type x: torch.Tensor

      :returns: Conditioning tensor of shape (Nodes, output_dim).
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor

      Get an empty conditioning tensor.

      :returns: Empty conditioning tensor of shape (n, output_dim).
      :rtype: torch.Tensor
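
The embedding lookup performed by ``IntegerConditioning``, together with the node-level expansion that ``Conditioning.concatenate`` expects, can be pictured with the plain-Python sketch here. It is illustrative only and rests on assumptions not stated on this page: the property (here, the atom count) is given once per structure and copied to every node through a per-node graph index ``batch_index``; the embedding table is a fixed list of rows rather than learned ``torch.nn.Embedding`` weights; ``output_dim`` is 4 instead of the default 64 to keep the output short; and ``make_table``, ``embed``, ``expand_to_nodes`` and ``concatenate`` are hypothetical helper names.

```python
from typing import List, Sequence

Vector = List[float]


def make_table(max_int: int, output_dim: int) -> List[Vector]:
    """Hypothetical stand-in for torch.nn.Embedding(max_int, output_dim) weights.

    Row k is k / max_int repeated output_dim times, purely for illustration;
    the real module learns these rows.
    """
    return [[k / max_int] * output_dim for k in range(max_int)]


def embed(table: List[Vector], values: Sequence[int]) -> List[Vector]:
    """Look up one row per integer value, as an embedding layer does."""
    for v in values:
        if not 0 <= v < len(table):
            raise IndexError(f"value {v} outside [0, {len(table)})")
    return [list(table[v]) for v in values]


def expand_to_nodes(graph_c: List[Vector], batch_index: Sequence[int]) -> List[Vector]:
    """Copy each graph-level vector to every node of that graph."""
    return [list(graph_c[g]) for g in batch_index]


def concatenate(node_features: List[Vector], c: List[Vector]) -> List[Vector]:
    """Append node-level conditioning to node features.

    Assumes c is already node-level, as the docs require; no shape check here.
    """
    return [h + ci for h, ci in zip(node_features, c)]


# Two structures with 2 and 3 atoms, conditioned on their atom counts.
table = make_table(max_int=200, output_dim=4)
batch_index = [0, 0, 1, 1, 1]      # graph index of each of the 5 nodes
node_features = [[1.0, 0.0]] * 5   # dummy 2-d node features
c_graph = embed(table, [2, 3])                   # 2 rows of length 4
c_nodes = expand_to_nodes(c_graph, batch_index)  # 5 rows of length 4
out = concatenate(node_features, c_nodes)        # 5 rows of length 6
print(len(out), len(out[0]))  # 5 6
```

Keeping the expansion in a separate step means the concatenation itself needs no shape logic, which mirrors the split between :meth:`forward` and :meth:`concatenate` described for the base class.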