agedi.models.score
==================

.. py:module:: agedi.models.score


Classes
-------

.. autoapisummary::

   agedi.models.score.ScoreModel


Module Contents
---------------

.. py:class:: ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: Optional[List[agedi.models.conditionings.Conditioning]] = None, heads: Optional[List[agedi.models.head.Head]] = None, w: float = -1.0, **kwargs)

   Bases: :py:obj:`lightning.LightningModule`


   Class that defines the score model.

   It is a combination of a translator, a representation, a list of conditionings and a list of heads.

   :param translator: The translator that will be used to translate the input batch.
   :type translator: Translator
   :param representation: The representation that will be used to represent the translated batch.
   :type representation: Representation
   :param conditionings: The list of conditionings that will be applied to the representation.
   :type conditionings: List[Conditioning]
   :param heads: The list of heads that will be used to compute scores.
   :type heads: List[Head]
   :param w: The guidance weight.
   :type w: float


   .. py:attribute:: translator


   .. py:attribute:: representation


   .. py:attribute:: conditionings


   .. py:attribute:: heads


   .. py:attribute:: w


   .. py:attribute:: guidance
      :value: True



   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this score model.

      Collects hyperparameters from the translator, representation (via
      :meth:`~agedi.models.translator.Translator.get_representation_hparams`),
      conditionings, and heads, as well as the guidance weight ``w``.

      :returns: Hyperparameter dictionary with a ``_target_`` key and nested
                ``translator``, ``representation``, ``conditionings``, ``heads``,
                and ``w`` entries.
      :rtype: dict



   .. py:method:: forward(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Forward pass of the model.

      Dispatches to :meth:`forward_sample` when the model is in sampling mode,
      and to :meth:`forward_train` otherwise.

      This keeps ``self.sample`` as a compile-time constant for each compiled
      subgraph, avoiding retracing on mode changes and eliminating the
      Python-level branch from the compiled region.

      :param batch: The input batch that will be used to compute the scores.
      :type batch: Batch

      :returns: The output batch containing the scores.
      :rtype: Batch



   .. py:method:: forward_train(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Training-mode forward pass.

      Computes the backbone representation, applies all conditionings
      unconditionally, translates the conditioned batch, and evaluates every
      score head.

      :param batch: The input batch.
      :type batch: Batch

      :returns: Batch with score tensors attached.
      :rtype: Batch



   .. py:method:: forward_sample(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Sampling-mode forward pass.

      Computes the backbone representation, applies classifier-free guidance
      (when ``self.guidance`` is ``True``), translates the conditioned batch,
      and evaluates every score head with optional guidance mixing.

      :param batch: The input batch.
      :type batch: Batch

      :returns: Batch with score tensors attached.
      :rtype: Batch



   .. py:method:: sample_mode() -> None

      Switch the model to sampling mode.

      Sets ``self.sample = True`` and calls ``sample_mode()`` on all
      conditioning modules so that classifier-free guidance is applied during
      inference.



   .. py:method:: training_mode() -> None

      Switch the model to training mode.

      Sets ``self.sample = False`` and calls ``training_mode()`` on all
      conditioning modules so that conditioning is applied unconditionally
      during the forward pass.
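
The mode switching and dispatch described above can be illustrated with a minimal, dependency-free sketch. It is not the ``agedi`` implementation: ``ToyConditioning`` and ``ToyScoreModel`` are hypothetical stand-ins, scores are plain floats rather than batches, and the guidance mixing rule ``(1 + w) * cond - w * uncond`` is one common classifier-free guidance convention assumed here for illustration only, since this page does not state which formula ``ScoreModel`` uses.

```python
from typing import List


class ToyConditioning:
    """Hypothetical stand-in for a conditioning module (not an agedi class)."""

    def __init__(self) -> None:
        self.sample = False

    def sample_mode(self) -> None:
        self.sample = True

    def training_mode(self) -> None:
        self.sample = False


class ToyScoreModel:
    """Toy model mirroring the documented mode switching and dispatch."""

    def __init__(self, conditionings: List[ToyConditioning], w: float = 0.5) -> None:
        self.conditionings = conditionings
        self.w = w
        self.guidance = True
        self.sample = False

    def forward(self, cond_score: float, uncond_score: float) -> float:
        # Dispatch on the mode flag, as the documented forward() does.
        if self.sample:
            return self.forward_sample(cond_score, uncond_score)
        return self.forward_train(cond_score)

    def forward_train(self, cond_score: float) -> float:
        # Training mode: the conditioned score is used as-is.
        return cond_score

    def forward_sample(self, cond_score: float, uncond_score: float) -> float:
        if not self.guidance:
            return cond_score
        # ASSUMPTION: a common classifier-free guidance mixing rule,
        # not necessarily the one agedi implements.
        return (1.0 + self.w) * cond_score - self.w * uncond_score

    def sample_mode(self) -> None:
        # Flip the flag and propagate the mode to every conditioning.
        self.sample = True
        for conditioning in self.conditionings:
            conditioning.sample_mode()

    def training_mode(self) -> None:
        self.sample = False
        for conditioning in self.conditionings:
            conditioning.training_mode()


model = ToyScoreModel([ToyConditioning(), ToyConditioning()], w=0.5)
train_out = model.forward(2.0, 1.0)   # training mode: forward_train is used
model.sample_mode()
sample_out = model.forward(2.0, 1.0)  # sampling mode: forward_sample is used
```

Keeping the two paths in separate methods means each one can be traced or compiled on its own, with the mode check confined to the thin ``forward`` wrapper.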