agedi.models.schnetpack
=======================

.. py:module:: agedi.models.schnetpack


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/agedi/models/schnetpack/heads/index
   /autoapi/agedi/models/schnetpack/regressor_heads/index
   /autoapi/agedi/models/schnetpack/translator/index


Classes
-------

.. autoapisummary::

   agedi.models.schnetpack.SchNetPackTranslator
   agedi.models.schnetpack.PositionsScore
   agedi.models.schnetpack.TypesScore


Package Contents
----------------

.. py:class:: SchNetPackTranslator(input_modules: Optional[List[Callable]] = None)

   Bases: :py:obj:`agedi.models.translator.Translator`


   Translator for SchNetPack models.

   This class converts input data into the format required by SchNetPack models.


   .. py:method:: get_representation_hparams(representation: Any) -> Dict

      Extract hyperparameters from a SchNetPack representation object.

      Supports :class:`schnetpack.representation.PaiNN`. Extracts
      ``n_atom_basis``, ``n_interactions``, and nested configs for the
      ``radial_basis`` and ``cutoff_fn`` so that the representation can be
      fully reconstructed with
      :func:`~agedi.functional._instantiate_from_config`.

      :param representation: An instantiated SchNetPack representation object.
      :type representation: Any

      :returns: Hyperparameter dictionary with a ``_target_`` key and representation-specific parameters.
      :rtype: dict

      :raises NotImplementedError: If the representation type is not recognised.



   .. py:method:: _translate(batch: AtomsGraph) -> Dict[str, torch.Tensor]

      Translate the input batch to the format required by the model.

      SchNetPack models take their input as a dictionary of tensors.

      The dictionary keys are given by constants in ``schnetpack.properties`` and describe:

      - ``n_atoms``: number of atoms in each structure
      - ``Z``: atomic numbers
      - ``R``: atomic positions
      - ``cell``: cell vectors
      - ``pbc``: periodic boundary conditions
      - ``idx_i``: index of the central atom of each edge (neighbour pair)
      - ``idx_j``: index of the neighbouring atom of each edge
      - ``offsets``: periodic shift vectors of each edge
      - ``idx_m``: batch indices describing which atoms belong to which structure

      Energy and force targets can additionally be added to the dictionary.

      :param batch: The input batch of data.
      :type batch: AtomsGraph

      :returns: The translated batch of data.
      :rtype: Dict



   .. py:method:: _translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]

      Translate the representation to the format required by the model.

      SchNetPack stores the two kinds of representation under the
      ``scalar_representation`` and ``vector_representation`` keys.

      :param representation: The input representation.
      :type representation: Representation
      :param translated_batch: The translated batch of data.
      :type translated_batch: Dict

      :returns: The translated batch with representation keys.
      :rtype: Dict



   .. py:method:: _get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) -> agedi.data.Representation

      Extract the representation from the output of the model.

      :param batch: The input batch of data.
      :type batch: AtomsGraph
      :param translated_batch: The output of the model.
      :type translated_batch: Dict[str, torch.Tensor]

      :returns: The representation output of the model.
      :rtype: Representation



.. py:class:: PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)

   Bases: :py:obj:`agedi.models.head.Head`


   Predict the positions score of the atoms in the structure.

   :param input_dim_scalar: The dimension of the scalar input.
   :type input_dim_scalar: int
   :param input_dim_vector: The dimension of the vector input.
   :type input_dim_vector: int
   :param gated_blocks: The number of gated blocks in the network.
   :type gated_blocks: int

   :rtype: Head


   .. py:attribute:: _key
      :value: 'pos'



   .. py:attribute:: input_dim_scalar
      :value: 66



   .. py:attribute:: input_dim_vector
      :value: 64



   .. py:attribute:: gated_blocks
      :value: 3



   .. py:attribute:: net


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this positions score head.



   .. py:method:: _score(batch: dict) -> torch.Tensor

      Predict the positions score of the atoms in the structure.

      :param batch: The translated input batch with ``scalar_representation`` and ``vector_representation`` keys.
      :type batch: dict

      :returns: The predicted positions score.
      :rtype: torch.Tensor



.. py:class:: TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)

   Bases: :py:obj:`agedi.models.head.Head`


   Predict the types score of the atoms in the structure.

   :param input_dim_scalar: The dimension of the scalar input.
   :type input_dim_scalar: int
   :param input_dim_vector: The dimension of the vector input.
   :type input_dim_vector: int
   :param n_classes: The number of atom-type classes scored for each atom.
   :type n_classes: int

   :rtype: Head


   .. py:attribute:: _key
      :value: 'x'



   .. py:attribute:: input_dim_scalar
      :value: 66



   .. py:attribute:: input_dim_vector
      :value: 64



   .. py:attribute:: n_classes
      :value: 100



   .. py:attribute:: net


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this types score head.



   .. py:method:: _score(batch: dict) -> torch.Tensor

      Predict the types score of the atoms in the structure.

      :param batch: The translated input batch with a ``scalar_representation`` key.
      :type batch: dict

      :returns: The predicted types score.
      :rtype: torch.Tensor
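
      As a minimal sketch of the ``_score`` input/output contract, assuming
      ``scalar_representation`` holds one feature row of length
      ``input_dim_scalar`` per atom, the plain-Python example below maps each
      row to ``n_classes`` scores. The function ``linear_types_readout`` and its
      random weights are hypothetical stand-ins for ``net``, whose architecture
      is not documented here. Real translated batches hold ``torch.Tensor``
      values rather than nested lists.

      .. code-block:: python

         import random
         from typing import Dict, List

         INPUT_DIM_SCALAR = 66  # TypesScore default
         N_CLASSES = 100  # TypesScore default

         rng = random.Random(0)

         # Toy translated batch: a 2-atom and a 1-atom structure flattened together.
         # Only the key read by the types head is filled in.
         batch: Dict[str, List[List[float]]] = {
             "scalar_representation": [
                 [rng.gauss(0.0, 1.0) for _ in range(INPUT_DIM_SCALAR)]
                 for _ in range(3)
             ],
         }


         def linear_types_readout(
             features: List[List[float]],
             weights: List[List[float]],
             bias: List[float],
         ) -> List[List[float]]:
             """Hypothetical stand-in for ``net``: per-atom linear map to class scores."""
             return [
                 [sum(w * x for w, x in zip(w_row, row)) + b for w_row, b in zip(weights, bias)]
                 for row in features
             ]


         weights = [
             [rng.gauss(0.0, 0.1) for _ in range(INPUT_DIM_SCALAR)] for _ in range(N_CLASSES)
         ]
         bias = [0.0] * N_CLASSES
         scores = linear_types_readout(batch["scalar_representation"], weights, bias)

         # One row of N_CLASSES scores per atom, in the same atom order as the batch.
         print(len(scores), len(scores[0]))  # prints: 3 100

      If, as the parameter description suggests, only ``scalar_representation``
      is read, the type scores do not depend on ``vector_representation``. With
      rotation-invariant scalar features such as those produced by PaiNN, the
      scores are then unchanged when the structure is rotated.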