agedi.models
============

.. py:module:: agedi.models


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/agedi/models/conditionings/index
   /autoapi/agedi/models/head/index
   /autoapi/agedi/models/regressor/index
   /autoapi/agedi/models/schnetpack/index
   /autoapi/agedi/models/score/index
   /autoapi/agedi/models/translator/index


Classes
-------

.. autoapisummary::

   agedi.models.ScoreModel
   agedi.models.TimeConditioning
   agedi.models.Conditioning


Package Contents
----------------

.. py:class:: ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: Optional[List[agedi.models.conditionings.Conditioning]] = None, heads: Optional[List[agedi.models.head.Head]] = None, w: float = -1.0, **kwargs)

   Bases: :py:obj:`lightning.LightningModule`

   Class that defines the score model.

   It combines a translator, a representation, a list of conditionings, and
   a list of heads.

   :param translator: The translator that will be used to translate the input batch.
   :type translator: Translator
   :param representation: The representation that will be used to represent the
      translated batch.
   :type representation: Representation
   :param conditionings: The list of conditionings that will be applied to the
      representation.
   :type conditionings: List[Conditioning]
   :param heads: The list of heads that will be used to compute scores.
   :type heads: List[Head]
   :param w: The guidance weight used for classifier-free guidance during sampling.
   :type w: float


   .. py:attribute:: translator


   .. py:attribute:: representation


   .. py:attribute:: conditionings


   .. py:attribute:: heads


   .. py:attribute:: w


   .. py:attribute:: guidance
      :value: True


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this score model.

      Collects hyperparameters from the translator, representation (via
      :meth:`~agedi.models.translator.Translator.get_representation_hparams`),
      conditionings, and heads, as well as the guidance weight ``w``.
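
      The returned dictionary is nested, with each ``_target_`` entry naming a
      class to import. A minimal sketch of turning such a dictionary back into
      objects follows, assuming a Hydra-style convention in which the remaining
      keys become keyword arguments; how agedi itself consumes the dictionary
      is not documented here. ``instantiate_from_hparams`` is a hypothetical
      helper, and the usage lines substitute standard-library classes
      (``types.SimpleNamespace``, ``fractions.Fraction``,
      ``datetime.timedelta``) for the agedi components so the block runs on its
      own.

      .. code-block:: python

         import importlib
         from typing import Any


         def instantiate_from_hparams(cfg: Any) -> Any:
             """Rebuild objects from a nested ``_target_`` dictionary (hypothetical helper).

             Assumption: a dict with a ``_target_`` key means "import this dotted
             path and call it with the other keys as keyword arguments". Lists and
             dicts without ``_target_`` are walked recursively; other values
             (numbers, None, strings) pass through unchanged.
             """
             if isinstance(cfg, list):
                 return [instantiate_from_hparams(item) for item in cfg]
             if isinstance(cfg, dict):
                 kwargs = {
                     key: instantiate_from_hparams(value)
                     for key, value in cfg.items()
                     if key != "_target_"
                 }
                 if "_target_" not in cfg:
                     return kwargs
                 # Split "package.module.ClassName" into module path and attribute.
                 module_name, _, attr_name = cfg["_target_"].rpartition(".")
                 target = getattr(importlib.import_module(module_name), attr_name)
                 return target(**kwargs)
             return cfg


         # Stand-in hyperparameters: stdlib classes replace the agedi components.
         hparams = {
             "_target_": "types.SimpleNamespace",
             "translator": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2},
             "conditionings": None,
             "heads": [{"_target_": "datetime.timedelta", "days": 1}],
             "w": -1.0,
         }
         model = instantiate_from_hparams(hparams)

      Lists such as ``heads`` are rebuilt element by element, while plain values
      such as ``w`` or an unset ``conditionings`` are passed through as-is.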

      :returns: Hyperparameter dictionary with a ``_target_`` key and nested
         ``translator``, ``representation``, ``conditionings``, ``heads``, and
         ``w`` entries.
      :rtype: dict


   .. py:method:: forward(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Forward pass of the model.

      Dispatches to :meth:`forward_sample` when the model is in sampling mode,
      and to :meth:`forward_train` otherwise. This keeps ``self.sample`` as a
      compile-time constant for each compiled subgraph, avoiding retracing on
      mode changes and eliminating the Python-level branch from the compiled
      region.

      :param batch: The input batch that will be used to compute the scores.
      :type batch: Batch
      :returns: The output batch containing the scores.
      :rtype: Batch


   .. py:method:: forward_train(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Training-mode forward pass.

      Computes the backbone representation, applies all conditionings
      unconditionally, translates the conditioned batch, and evaluates every
      score head.

      :param batch: The input batch.
      :type batch: Batch
      :returns: Batch with score tensors attached.
      :rtype: Batch


   .. py:method:: forward_sample(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Sampling-mode forward pass.

      Computes the backbone representation, applies classifier-free guidance
      (when ``self.guidance`` is ``True``), translates the conditioned batch,
      and evaluates every score head with optional guidance mixing.

      :param batch: The input batch.
      :type batch: Batch
      :returns: Batch with score tensors attached.
      :rtype: Batch


   .. py:method:: sample_mode() -> None

      Switch the model to sampling mode.

      Sets ``self.sample = True`` and calls ``sample_mode()`` on all
      conditioning modules so that classifier-free guidance is applied during
      inference.


   .. py:method:: training_mode() -> None

      Switch the model to training mode.

      Sets ``self.sample = False`` and calls ``training_mode()`` on all
      conditioning modules so that conditioning is applied unconditionally
      during the forward pass.


.. py:class:: TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)

   Bases: :py:obj:`agedi.models.conditionings.base.Conditioning`

   Condition the model on the time :math:`t`.

   :param input_dim: The dimension of the input conditioning. Default is 1.
   :type input_dim: int
   :param output_dim: The dimension of the output conditioning. Default is 2.
   :type output_dim: int


   .. py:attribute:: omega


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this time conditioning module.


   .. py:method:: get_conditioning(t: torch.Tensor) -> torch.Tensor

      Get the conditioning tensor for the time :math:`t`:

      .. math::

         \mathbf{c} = \begin{bmatrix} \sin(\omega t) \\ \cos(\omega t) \end{bmatrix}

      :param t: Time tensor of shape ``(Nodes, 1)``.
      :type t: torch.Tensor
      :returns: Conditioning tensor of shape ``(Nodes, 2)``.
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor

      Get an empty conditioning tensor.

      :param n: The number of nodes in the batch.
      :type n: int
      :returns: Empty conditioning tensor of shape ``(1, 2)``.
      :rtype: torch.Tensor


   .. py:method:: forward(batch: AtomsGraph, empty: bool = False) -> AtomsGraph

      Forward method to get the conditioning from the input.

      This ignores the training-mode setting and the ``empty`` flag.

      :param batch: The input batch.
      :type batch: AtomsGraph
      :param empty: Accepted for compatibility with :meth:`Conditioning.forward`;
         ignored by this class.
      :type empty: bool
      :returns: The batch with the conditioning added to the representation.
      :rtype: AtomsGraph


.. py:class:: Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

   Bases: :py:obj:`abc.ABC`, :py:obj:`lightning.LightningModule`

   Base class for conditioning modules.

   :param property: The property of the batch to condition on.
   :type property: str
   :param input_dim: The dimension of the input conditioning.
   :type input_dim: int
   :param output_dim: The dimension of the output conditioning.
   :type output_dim: int
   :param concatenation_type: The type of concatenation to use.
      Default is ``'scalar'``.
   :type concatenation_type: str
   :param probability: The probability of conditioning. Only used in training
      mode. Default is 0.8.
   :type probability: float


   .. py:attribute:: property


   .. py:attribute:: input_dim


   .. py:attribute:: output_dim


   .. py:attribute:: concatenation_type
      :value: 'scalar'


   .. py:attribute:: probability
      :value: 0.8


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this conditioning module.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class
      name) plus ``property``, ``probability``, ``input_dim``, and
      ``output_dim`` from the base class. Subclasses should call
      ``super().get_hparams()`` and merge in their own constructor parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: get_conditioning(x: torch.Tensor) -> torch.Tensor
      :abstractmethod:

      Abstract method to get the conditioning from the input. Must be
      implemented by the subclass.

      :param x: The input tensor.
      :type x: torch.Tensor
      :returns: The conditioning tensor.
      :rtype: torch.Tensor


   .. py:method:: get_empty_conditioning(n: int) -> torch.Tensor
      :abstractmethod:

      Abstract method to get an empty conditioning tensor. Must be implemented
      by the subclass.

      :param n: The number of nodes in the batch.
      :type n: int
      :returns: The empty conditioning tensor.
      :rtype: torch.Tensor


   .. py:method:: forward(batch: AtomsGraph, empty: bool = False) -> AtomsGraph

      Forward method to get the conditioning from the input.

      :param batch: The input batch.
      :type batch: AtomsGraph
      :param empty: If ``True``, use an empty conditioning tensor.
      :type empty: bool
      :returns: The batch with the conditioning added to the representation.
      :rtype: AtomsGraph


   .. py:method:: concatenate(batch: AtomsGraph, c: torch.Tensor) -> None

      Concatenate the conditioning to the batch.

      ``c`` must already be expanded to node level
      (``c.shape[0] == batch.pos.shape[0]``) before this method is called.
      The pre-expansion is performed by :meth:`forward` to keep this method
      free of dynamic shape comparisons, which would otherwise cause graph
      breaks under ``torch.compile``.

      :param batch: The input batch.
      :type batch: AtomsGraph
      :param c: Node-level conditioning tensor of shape ``(n_nodes, output_dim)``
         or ``(n_nodes, output_dim, 1)``.
      :type c: torch.Tensor
      :rtype: None


   .. py:method:: sample_mode() -> None

      Set the model to sampling mode.

      :rtype: None


   .. py:method:: training_mode() -> None

      Set the model to training mode.

      :rtype: None
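
   Two steps described above can be sketched in plain Python: using the real
   conditioning only with chance ``probability`` during training, and
   expanding graph-level conditioning to node level before :meth:`concatenate`
   is called. This is a minimal sketch, not agedi's implementation. It assumes
   the keep-or-drop decision is drawn independently per graph and that the
   empty conditioning is a fixed vector; ``condition_per_graph``,
   ``expand_to_nodes``, and ``node_to_graph`` are hypothetical names, with
   ``node_to_graph`` standing in for the node-to-graph assignment vector
   (``batch.batch`` in PyTorch Geometric).

   .. code-block:: python

      import random
      from typing import List, Sequence

      Vector = List[float]


      def condition_per_graph(
          graph_values: Sequence[Vector],
          empty_value: Vector,
          probability: float,
          rng: random.Random,
      ) -> List[Vector]:
          """Keep each graph's conditioning with ``probability``, else use the empty one.

          Assumption (hypothetical helper): the decision is drawn independently
          for every graph in the batch.
          """
          return [
              list(value) if rng.random() < probability else list(empty_value)
              for value in graph_values
          ]


      def expand_to_nodes(graph_level: Sequence[Vector], node_to_graph: Sequence[int]) -> List[Vector]:
          """Copy each graph's row to all of its nodes (hypothetical helper)."""
          return [list(graph_level[g]) for g in node_to_graph]


      rng = random.Random(0)
      graph_c = [[1.0, 0.0], [0.0, 1.0]]  # two graphs, output_dim = 2
      node_to_graph = [0, 0, 1]  # nodes 0 and 1 belong to graph 0, node 2 to graph 1
      kept = condition_per_graph(graph_c, [0.0, 0.0], probability=0.8, rng=rng)
      node_c = expand_to_nodes(kept, node_to_graph)
      assert len(node_c) == len(node_to_graph)  # first dimension equals the node count

   With tensors, the expansion is a single indexing operation,
   ``c[node_to_graph]``, which yields the ``(n_nodes, output_dim)`` shape that
   :meth:`concatenate` expects to receive.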