agedi.models.schnetpack.translator
==================================

.. py:module:: agedi.models.schnetpack.translator


Classes
-------

.. autoapisummary::

   agedi.models.schnetpack.translator.SchNetPackTranslator


Module Contents
---------------

.. py:class:: SchNetPackTranslator(input_modules: Optional[List[Callable]] = None)

   Bases: :py:obj:`agedi.models.translator.Translator`

   Translator for SchNetPack models.

   This class translates input data into the format required by SchNetPack models.


   .. py:method:: get_representation_hparams(representation: Any) -> Dict

      Extract hyperparameters from a SchNetPack representation object.

      Supports :class:`schnetpack.representation.PaiNN`. Extracts
      ``n_atom_basis``, ``n_interactions``, and the nested configs for
      ``radial_basis`` and ``cutoff_fn``, so that the representation can be
      fully reconstructed with :func:`~agedi.functional._instantiate_from_config`.

      :param representation: An instantiated SchNetPack representation object.
      :type representation: Any

      :returns: Hyperparameter dictionary with a ``_target_`` key and
                representation-specific parameters.
      :rtype: dict

      :raises NotImplementedError: If the representation type is not recognised.


   .. py:method:: _translate(batch: AtomsGraph) -> Dict[str, torch.Tensor]

      Translate the input batch to the format required by the model.

      SchNetPack models take their input as a dictionary. The dictionary keys
      are defined in ``schnetpack.properties`` and describe:

      - ``n_atoms``: number of atoms in the system
      - ``Z``: atomic numbers
      - ``R``: atomic positions
      - ``cell``: cell vectors
      - ``pbc``: periodic boundary conditions
      - ``idx_i``: edge indices (central atoms)
      - ``idx_j``: edge indices (neighboring atoms)
      - ``offsets``: shift vectors
      - ``idx_m``: batch indices describing which atoms belong to which structure

      Additionally, energy and force targets can be added to the dictionary.

      :param batch: The input batch of data.
      :type batch: AtomsGraph

      :returns: The translated batch of data.
      :rtype: Dict


   .. py:method:: _translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]

      Translate the representation to the format required by the model.

      SchNetPack uses the keys ``scalar_representation`` and
      ``vector_representation`` for the two types of representation.

      :param representation: The input representation.
      :type representation: Representation
      :param translated_batch: The translated batch of data.
      :type translated_batch: Dict

      :returns: The translated batch with representation keys.
      :rtype: Dict


   .. py:method:: _get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) -> agedi.data.Representation

      Get the representation from the output of the model.

      :param batch: The input batch of data.
      :type batch: AtomsGraph
      :param translated_batch: The output of the model.
      :type translated_batch: Dict[str, torch.Tensor]

      :returns: The representation output of the model.
      :rtype: Representation
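
The flat dictionary layout described for ``_translate`` can be illustrated with a minimal sketch. It is not the agedi implementation: ``batch_to_dict`` and the per-structure input format are hypothetical, plain Python lists stand in for ``torch.Tensor`` values, and the short names ``n_atoms``, ``Z``, ``R``, ``idx_i``, ``idx_j`` and ``idx_m`` stand in for the actual key strings defined in ``schnetpack.properties``. The ``cell``, ``pbc`` and ``offsets`` entries are omitted for brevity. The point is only to show how atoms from several structures are concatenated and how ``idx_m`` and the shifted edge indices keep track of which structure each atom belongs to.

```python
from typing import Dict, List


def batch_to_dict(structures: List[dict]) -> Dict[str, list]:
    """Hypothetical helper: concatenate structures into one flat dictionary.

    Each input structure is a dict with "Z" (atomic numbers), "R" (positions)
    and "edges" (pairs of atom indices local to that structure).
    """
    out: Dict[str, list] = {
        "n_atoms": [], "Z": [], "R": [],
        "idx_i": [], "idx_j": [], "idx_m": [],
    }
    atom_offset = 0  # number of atoms already placed in the flat arrays
    for m, s in enumerate(structures):
        n = len(s["Z"])
        out["n_atoms"].append(n)
        out["Z"].extend(s["Z"])
        out["R"].extend(s["R"])
        # Every atom of structure m gets batch index m.
        out["idx_m"].extend([m] * n)
        # Local edge indices are shifted so they point into the flat arrays.
        for i, j in s["edges"]:
            out["idx_i"].append(i + atom_offset)
            out["idx_j"].append(j + atom_offset)
        atom_offset += n
    return out


water = {
    "Z": [8, 1, 1],
    "R": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    "edges": [(0, 1), (1, 0), (0, 2), (2, 0)],
}
h2 = {
    "Z": [1, 1],
    "R": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
    "edges": [(0, 1), (1, 0)],
}

batch = batch_to_dict([water, h2])
print(batch["n_atoms"])  # [3, 2]
print(batch["idx_m"])    # [0, 0, 0, 1, 1]
print(batch["idx_i"])    # [0, 1, 0, 2, 3, 4]
print(batch["idx_j"])    # [1, 0, 2, 0, 4, 3]
```

Because all atoms share one flat index space, the edges of the second structure start at index 3, and ``idx_m`` is what lets per-atom quantities be summed back into per-structure values.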