agedi.models.regressor
======================

.. py:module:: agedi.models.regressor


Classes
-------

.. autoapisummary::

   agedi.models.regressor.RegressorModel


Module Contents
---------------

.. py:class:: RegressorModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, heads: List[agedi.models.head.Head] = [], head_weights={}, use_weighting: bool = False, mask_forces: bool = True, **kwargs)

   Bases: :py:obj:`lightning.LightningModule`


   A regressor model that combines a translator, a representation, and a list of heads.

   :param translator: The translator used to translate the input batch.
   :type translator: Translator
   :param representation: The representation applied to the translated batch.
   :type representation: Representation
   :param heads: The list of heads used to compute scores.
   :type heads: List[Head]


   .. py:attribute:: translator


   .. py:attribute:: representation


   .. py:attribute:: head_weights


   .. py:attribute:: use_weighting
      :value: False


   .. py:attribute:: mask_forces
      :value: True


   .. py:attribute:: head_keys


   .. py:attribute:: heads


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this regressor model.

      :returns: Hyperparameter dictionary with a ``_target_`` key and nested ``translator``, ``representation``, and ``heads`` entries.
      :rtype: dict


   .. py:method:: forward(batch: torch_geometric.data.Batch) -> torch_geometric.data.Batch

      Run the forward pass of the model.

      :param batch: The input batch from which the scores are computed.
      :type batch: Batch
      :returns: The output batch containing the scores.
      :rtype: Batch


   .. py:method:: loss(batch: torch_geometric.data.Batch) -> Dict

      Compute the loss of the model.

      :param batch: The input batch from which the loss is computed.
      :type batch: Batch
      :returns: A dictionary containing the loss and the individual head losses.
      :rtype: dict
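
This page does not state how ``RegressorModel.loss`` combines the individual head losses, nor exactly what ``head_weights`` and ``use_weighting`` control. One plausible reading, offered only as a sketch under that assumption, is a weighted sum: when ``use_weighting`` is true, each head's loss is scaled by its entry in ``head_weights`` (defaulting to ``1.0``). The helper ``combine_head_losses``, the head names ``energy`` and ``forces``, and the ``"<head>_loss"`` key naming are all hypothetical and are not part of ``agedi``. The sketch also ignores ``mask_forces``, whose behaviour is not documented here.

.. code-block:: python

   from typing import Dict, Mapping


   def combine_head_losses(
       head_losses: Mapping[str, float],
       head_weights: Mapping[str, float],
       use_weighting: bool = False,
   ) -> Dict[str, float]:
       """Combine per-head losses into one training loss (hypothetical helper).

       Only ``+`` and ``*`` are used, so scalar ``torch.Tensor`` values work
       as well as plain floats.
       """
       total = 0.0
       for key, value in head_losses.items():
           # Assumption: weights apply only when use_weighting is True, and
           # heads missing from head_weights default to a weight of 1.0.
           weight = head_weights.get(key, 1.0) if use_weighting else 1.0
           total = total + weight * value
       # Return the combined loss plus each head's unweighted loss for logging.
       result = {"loss": total}
       result.update({f"{key}_loss": value for key, value in head_losses.items()})
       return result

With head losses ``{"energy": 2.0, "forces": 0.5}``, weights ``{"forces": 4.0}``, and ``use_weighting=True``, the combined ``"loss"`` is ``2.0 + 4.0 * 0.5 = 4.0``. With weighting disabled it is ``2.5``.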
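
The ``_target_`` key returned by ``RegressorModel.get_hparams`` resembles the convention used by Hydra's ``hydra.utils.instantiate``, in which ``_target_`` names an importable callable and the remaining keys are its keyword arguments. Assuming the dictionary follows that convention, including its nested ``translator``, ``representation``, and ``heads`` entries, the stdlib-only sketch below shows how it could be turned back into objects. The helpers ``_locate`` and ``instantiate_from_hparams`` are hypothetical and are not part of ``agedi``. With ``hydra-core`` installed, ``hydra.utils.instantiate`` is the usual tool for this.

.. code-block:: python

   import importlib
   from typing import Any


   def _locate(path: str) -> Any:
       """Import ``package.module.attr`` and return ``attr``.

       Only handles module-level attributes, e.g.
       ``agedi.models.regressor.RegressorModel``.
       """
       module_name, _, attr = path.rpartition(".")
       return getattr(importlib.import_module(module_name), attr)


   def instantiate_from_hparams(config: Any) -> Any:
       """Recursively build objects from ``_target_`` dictionaries (hypothetical helper)."""
       if isinstance(config, dict):
           # Resolve nested entries first, so that e.g. "translator" is built
           # before it is passed to the outer constructor.
           resolved = {
               key: instantiate_from_hparams(value)
               for key, value in config.items()
               if key != "_target_"
           }
           if "_target_" in config:
               return _locate(config["_target_"])(**resolved)
           return resolved
       if isinstance(config, list):
           # Lists (e.g. "heads") are resolved element by element.
           return [instantiate_from_hparams(item) for item in config]
       # Plain values (numbers, strings, None) pass through unchanged.
       return config

For example, ``instantiate_from_hparams({"_target_": "datetime.timedelta", "days": 1, "hours": 2})`` builds ``datetime.timedelta(days=1, hours=2)``. Because nested dictionaries are resolved first, the outer constructor receives already-built objects.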