agedi.models.translator
=======================

.. py:module:: agedi.models.translator


Classes
-------

.. autoapisummary::

   agedi.models.translator.Translator


Module Contents
---------------

.. py:class:: Translator(input_modules: Optional[List[Callable]] = None)

   Bases: :py:obj:`abc.ABC`


   Base class for all translators.

   Translators convert a batch of data into a format that the model can consume.
   This is useful when the data is not in the format the model expects, or when
   it needs to be preprocessed before being fed into the model.

   :param input_modules: A list of functions that are applied to the input data
       after it has been translated.
   :type input_modules: Optional[List[Callable]]


   .. py:attribute:: input_modules


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters sufficient to reconstruct this translator.

      Returns a dictionary with a ``_target_`` key (the fully-qualified class
      name) plus ``input_modules`` (each serialised with its own ``_target_``
      key where available). Subclasses should call ``super().get_hparams()``
      and merge in their own constructor parameters.

      :returns: Hyperparameter dictionary.
      :rtype: dict


   .. py:method:: get_representation_hparams(representation: Any) -> Dict
      :abstractmethod:


      Extract hyperparameters from a representation object.

      This method is called by :meth:`~agedi.models.ScoreModel.get_hparams` to
      serialise the representation (e.g. a PaiNN network) that the translator
      wraps. The base implementation raises :class:`NotImplementedError`;
      subclasses must override it for the specific representation type they
      support.

      :param representation: The instantiated representation object.
      :type representation: Any

      :returns: Hyperparameter dictionary that can be used to reconstruct the
          representation (should contain a ``_target_`` key).
      :rtype: dict

      :raises NotImplementedError: If the subclass has not implemented this
          method.


   .. py:method:: _translate(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph
      :abstractmethod:


      Translate the batch of data.

      Abstract method that must be implemented by all subclasses. It converts
      the batch of data into the format expected by the model.

      :param batch: The batch of data to be translated.
      :type batch: AtomsGraph

      :returns: The translated batch of data.
      :rtype: AtomsGraph


   .. py:method:: _get_representation(batch: agedi.data.AtomsGraph, out: Any) -> agedi.data.Representation
      :abstractmethod:


      Get the representation of the batch of data.

      Abstract method that must be implemented by all subclasses. It extracts
      the representation from the model output so that it can be added to the
      original batch of data (see :meth:`add_representation`).

      :param batch: The original batch of data.
      :type batch: AtomsGraph
      :param out: The output of the model.
      :type out: Any

      :returns: The representation given by the model.
      :rtype: Representation


   .. py:method:: _translate_representation(rep: agedi.data.Representation, translated_batch: Any) -> Any
      :abstractmethod:


      Translate the representation of the batch of data.

      Abstract method that must be implemented by all subclasses. It translates
      the representation given by the model and injects it into the translated
      batch of data.

      :param rep: The representation given by the model.
      :type rep: Representation
      :param translated_batch: The translated batch of data.
      :type translated_batch: Any

      :returns: **translated_batch** -- The translated batch of data with the
          representation injected.
      :rtype: Any


   .. py:method:: __call__(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Translate the batch of data.

      Calling the translator object as a function translates ``batch``. When
      ``batch.representation`` is set, the stored representation is also
      injected via :meth:`_translate_representation`.

      :param batch: The batch of data to be translated.
      :type batch: AtomsGraph

      :returns: The translated batch of data.
      :rtype: AtomsGraph


   .. py:method:: translate_input(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Translate the batch without injecting any stored representation.

      Unlike :meth:`__call__`, this method always skips the
      :meth:`_translate_representation` step, regardless of whether
      ``batch.representation`` is set. Use it for the *first* forward pass
      through the backbone, before the representation has been computed.

      :param batch: The batch of data to translate.
      :type batch: AtomsGraph

      :returns: The translated batch.
      :rtype: AtomsGraph


   .. py:method:: translate_with_representation(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

      Translate the batch and inject the stored representation.

      Like :meth:`translate_input`, but always calls
      :meth:`_translate_representation` to inject the representation that was
      previously attached via :meth:`add_representation`. Use it for the
      *second* forward pass through the backbone, after the representation has
      been computed and stored on ``batch``.

      :param batch: The batch of data to translate. ``batch.representation``
          must not be ``None`` when this method is called.
      :type batch: AtomsGraph

      :returns: The translated batch with the representation injected.
      :rtype: AtomsGraph


   .. py:method:: add_representation(batch: agedi.data.AtomsGraph, out: Any) -> agedi.data.AtomsGraph

      Add the representation given by the model to the original batch of data.

      :param batch: The original batch of data.
      :type batch: AtomsGraph
      :param out: The output of the model.
      :type out: Any

      :returns: The original batch of data with the representation added.
      :rtype: AtomsGraph


   .. py:method:: add_scores(batch: agedi.data.AtomsGraph, scores: Dict[str, torch.Tensor]) -> agedi.data.AtomsGraph

      Add the scores predicted by the model to the original batch of data.

      :param batch: The original batch of data.
      :type batch: AtomsGraph
      :param scores: The scores predicted by the model, formatted as
          ``{head key: head predicted scores}``.
      :type scores: Dict[str, torch.Tensor]

      :returns: The original batch of data with the scores added.
      :rtype: AtomsGraph


   .. py:method:: add_prediction(batch: agedi.data.AtomsGraph, targets: Dict[str, torch.Tensor], type: Optional[str] = None) -> agedi.data.AtomsGraph

      Add the targets predicted by the model to the original batch of data.

      :param batch: The original batch of data.
      :type batch: AtomsGraph
      :param targets: The targets predicted by the model, formatted as
          ``{head key: head predicted target}``.
      :type targets: Dict[str, torch.Tensor]

      :returns: The original batch of data with the predicted targets added.
      :rtype: AtomsGraph
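
The two-pass flow described by :meth:`Translator.translate_input`,
:meth:`Translator.add_representation` and
:meth:`Translator.translate_with_representation` can be illustrated with a
minimal, self-contained sketch that does not import ``agedi``. It is not the
library's implementation: ``ToyBatch`` (a stand-in for ``AtomsGraph``),
``ToyTranslator``, ``DictTranslator``, ``double_positions`` and
``toy_backbone`` are hypothetical names, the model's input format is assumed to
be a plain ``dict``, and both the serialisation of input modules by qualified
name and running the input modules before the representation is injected are
assumptions of this sketch.

.. code-block:: python

   from abc import ABC, abstractmethod
   from dataclasses import dataclass
   from typing import Any, Callable, Dict, List, Optional


   @dataclass
   class ToyBatch:
       """Hypothetical stand-in for agedi.data.AtomsGraph."""
       positions: List[List[float]]
       representation: Optional[Any] = None


   class ToyTranslator(ABC):
       """Hypothetical re-creation of the documented Translator contract."""

       def __init__(self, input_modules: Optional[List[Callable]] = None):
           self.input_modules = input_modules or []

       def get_hparams(self) -> Dict:
           # Assumption: input modules are serialised by their qualified names.
           cls = type(self)
           return {
               "_target_": f"{cls.__module__}.{cls.__qualname__}",
               "input_modules": [
                   {"_target_": f"{m.__module__}.{m.__qualname__}"}
                   for m in self.input_modules
               ],
           }

       @abstractmethod
       def _translate(self, batch: ToyBatch) -> Any: ...

       @abstractmethod
       def _get_representation(self, batch: ToyBatch, out: Any) -> Any: ...

       @abstractmethod
       def _translate_representation(self, rep: Any, translated_batch: Any) -> Any: ...

       def translate_input(self, batch: ToyBatch) -> Any:
           # First pass: translate and apply the input modules, never inject.
           translated = self._translate(batch)
           for module in self.input_modules:
               translated = module(translated)
           return translated

       def translate_with_representation(self, batch: ToyBatch) -> Any:
           # Second pass: as translate_input, then inject the stored representation.
           if batch.representation is None:
               raise ValueError("batch.representation must not be None")
           translated = self.translate_input(batch)
           return self._translate_representation(batch.representation, translated)

       def __call__(self, batch: ToyBatch) -> Any:
           # Inject only when a representation is stored on the batch.
           if batch.representation is None:
               return self.translate_input(batch)
           return self.translate_with_representation(batch)

       def add_representation(self, batch: ToyBatch, out: Any) -> ToyBatch:
           batch.representation = self._get_representation(batch, out)
           return batch


   class DictTranslator(ToyTranslator):
       """Hypothetical subclass whose model input is a plain dict."""

       def _translate(self, batch):
           return {"x": batch.positions}

       def _get_representation(self, batch, out):
           return out["features"]

       def _translate_representation(self, rep, translated_batch):
           return {**translated_batch, "features": rep}


   def double_positions(inputs: Dict) -> Dict:
       # Hypothetical input module: copy of inputs with every coordinate doubled.
       return {**inputs, "x": [[2.0 * c for c in p] for p in inputs["x"]]}


   def toy_backbone(inputs: Dict) -> Dict:
       # Hypothetical backbone: one scalar feature per atom (sum of coordinates).
       return {"features": [sum(p) for p in inputs["x"]]}


   translator = DictTranslator(input_modules=[double_positions])
   batch = ToyBatch(positions=[[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]])
   first = translator.translate_input(batch)                          # pass 1
   batch = translator.add_representation(batch, toy_backbone(first))  # store
   second = translator.translate_with_representation(batch)           # pass 2
   print(second)  # {'x': [[0.0, 2.0, 4.0], [2.0, 2.0, 2.0]], 'features': [6.0, 6.0]}
   print(translator.get_hparams()["_target_"])  # e.g. __main__.DictTranslator

In this sketch ``_translate_representation`` returns a new dictionary rather
than mutating ``translated_batch``; either choice fits the documented contract
as long as the translated batch is returned.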