agedi.models.schnetpack.translator
Classes
SchNetPackTranslator | Translator for SchNetPack models.
Module Contents
- class agedi.models.schnetpack.translator.SchNetPackTranslator(input_modules: List[Callable] | None = None)
Bases: agedi.models.translator.Translator
Translator for SchNetPack models.
This class translates input data into the format required by SchNetPack models.
- get_representation_hparams(representation: Any) → Dict
Extract hyperparameters from a SchNetPack representation object.
Supports schnetpack.representation.PaiNN. Extracts n_atom_basis, n_interactions, and nested configs for the radial_basis and cutoff_fn so that the representation can be fully reconstructed with _instantiate_from_config().
- Parameters:
representation (any) – An instantiated SchNetPack representation object.
- Returns:
Hyperparameter dictionary with a _target_ key and representation-specific parameters.
- Return type:
dict
- Raises:
NotImplementedError – If the representation type is not recognised.
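The shape of the returned dictionary can be illustrated with a minimal sketch. It does not reproduce the actual implementation: FakePaiNN, FakeRadialBasis, FakeCutoff, target_of, and sketch_hparams are hypothetical stand-ins so the example runs without SchNetPack installed, and the fields chosen for the nested configs (n_rbf, cutoff) are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Dict


# Hypothetical stand-ins for SchNetPack objects (not the real classes).
@dataclass
class FakeRadialBasis:
    n_rbf: int
    cutoff: float


@dataclass
class FakeCutoff:
    cutoff: float


@dataclass
class FakePaiNN:
    n_atom_basis: int
    n_interactions: int
    radial_basis: FakeRadialBasis
    cutoff_fn: FakeCutoff


def target_of(obj: Any) -> str:
    """Return a dotted import path for the object's class."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


def sketch_hparams(representation: Any) -> Dict[str, Any]:
    """Collect a nested, serialisable config for a PaiNN-like object."""
    if not isinstance(representation, FakePaiNN):
        raise NotImplementedError(
            f"Unsupported representation: {type(representation).__name__}"
        )
    rbf, cut = representation.radial_basis, representation.cutoff_fn
    return {
        "_target_": target_of(representation),
        "n_atom_basis": representation.n_atom_basis,
        "n_interactions": representation.n_interactions,
        # Nested configs carry their own _target_ so they can be rebuilt too.
        "radial_basis": {"_target_": target_of(rbf), "n_rbf": rbf.n_rbf, "cutoff": rbf.cutoff},
        "cutoff_fn": {"_target_": target_of(cut), "cutoff": cut.cutoff},
    }


hparams = sketch_hparams(FakePaiNN(64, 3, FakeRadialBasis(20, 5.0), FakeCutoff(5.0)))
```

Storing a _target_ entry at every level is what would let a generic factory rebuild the whole object tree from the dictionary alone.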
- _translate(batch: AtomsGraph) → Dict[str, torch.Tensor]
Translate the input batch to the format required by the model.
SchNetPack models take their input as a dictionary.
The dictionary keys are defined in schnetpack.properties and describe:
- n_atoms: number of atoms in the system
- Z: atomic numbers
- R: atomic positions
- cell: cell vectors
- pbc: periodic boundary conditions
- idx_i: edge indices
- idx_j: edge indices
- offsets: shift vectors
- idx_m: batch indices describing which atoms belong to which structure
Additionally, energy and force targets can be added to the dictionary.
- Parameters:
batch (AtomsGraph) – The input batch of data.
- Returns:
The translated batch of data.
- Return type:
Dict
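To make the batched layout concrete, here is a small sketch using plain Python lists instead of tensors. It illustrates only how the entries relate to each other, not how the translator builds them: sketch_translate and the input format are hypothetical, the short key names stand in for the actual strings defined in schnetpack.properties, and cell, pbc, and offsets are left out for brevity.

```python
from typing import Any, Dict, List


def sketch_translate(structures: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Concatenate per-structure data into one flat, batched dictionary."""
    keys = ("n_atoms", "Z", "R", "idx_i", "idx_j", "idx_m")
    out: Dict[str, List[Any]] = {k: [] for k in keys}
    atom_offset = 0
    for m, s in enumerate(structures):
        n = len(s["Z"])
        out["n_atoms"].append(n)
        out["Z"].extend(s["Z"])
        out["R"].extend(s["R"])
        # Shift local edge indices so they address atoms in the concatenated arrays.
        out["idx_i"].extend(i + atom_offset for i, _ in s["edges"])
        out["idx_j"].extend(j + atom_offset for _, j in s["edges"])
        # idx_m records which structure each atom belongs to.
        out["idx_m"].extend([m] * n)
        atom_offset += n
    return out


water = {
    "Z": [8, 1, 1],
    "R": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    "edges": [(0, 1), (0, 2)],
}
h2 = {"Z": [1, 1], "R": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]], "edges": [(0, 1)]}

batch = sketch_translate([water, h2])
print(batch["idx_m"])                  # [0, 0, 0, 1, 1]
print(batch["idx_i"], batch["idx_j"])  # [0, 0, 3] [1, 2, 4]
```

All per-atom entries share one flat atom axis, so idx_m is what allows per-structure quantities such as energies to be recovered from the batch.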
- _translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]
Translate the representation to the format required by the model.
SchNetPack stores the two types of representation under the keys scalar_representation and vector_representation.
- Parameters:
representation (Representation) – The input representation.
translated_batch (Dict) – The translated batch of data.
- Returns:
The translated batch with representation keys.
- Return type:
Dict
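As a rough illustration of this step, the sketch below adds the two representation keys to an already translated batch. FakeRepresentation and its scalar and vector fields are hypothetical (the real agedi.data.Representation may be structured differently), and plain lists are used in place of tensors.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeRepresentation:
    """Hypothetical container; not the real agedi.data.Representation."""
    scalar: List[List[float]]        # one feature vector per atom
    vector: List[List[List[float]]]  # one 3 x n_features block per atom


def sketch_translate_representation(
    rep: FakeRepresentation, translated_batch: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a copy of the batch with the two representation keys added."""
    out = dict(translated_batch)  # shallow copy; the input batch is left untouched
    out["scalar_representation"] = rep.scalar
    out["vector_representation"] = rep.vector
    return out


rep = FakeRepresentation(
    scalar=[[0.1, 0.2], [0.3, 0.4]],
    vector=[[[0.0, 0.0]] * 3, [[1.0, 1.0]] * 3],
)
merged = sketch_translate_representation(rep, {"Z": [1, 1]})
```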
- _get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) → agedi.data.Representation
Get the representation from the output of the model.
- Parameters:
batch (AtomsGraph) – The input batch of data.
translated_batch (Dict[str, torch.Tensor]) – The output of the model.
- Returns:
The representation output of the model.
- Return type: