agedi.models.schnetpack
Submodules
Classes
SchNetPackTranslator | Translator for SchNetPack models.
PositionsScore | Predict the positions score of the atoms in the structure.
TypesScore | Predict the types score of the atoms in the structure.
Package Contents
- class agedi.models.schnetpack.SchNetPackTranslator(input_modules: List[Callable] | None = None)
Bases: agedi.models.translator.Translator
Translator for SchNetPack models.
This class translates the input data into the format required by SchNetPack models.
- get_representation_hparams(representation: Any) -> Dict
Extract hyperparameters from a SchNetPack representation object.
Supports schnetpack.representation.PaiNN. Extracts n_atom_basis, n_interactions, and nested configs for the radial_basis and cutoff_fn so that the representation can be fully reconstructed with _instantiate_from_config().
- Parameters:
representation (any) – An instantiated SchNetPack representation object.
- Returns:
Hyperparameter dictionary with a _target_ key and representation-specific parameters.
- Return type:
dict
- Raises:
NotImplementedError – If the representation type is not recognised.
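The idea of a hyperparameter dictionary with a _target_ key and nested configs can be illustrated with a minimal sketch. The helper `instantiate_from_config` below is hypothetical and is not agedi's `_instantiate_from_config()`. It assumes that `_target_` holds a dotted import path and that every other entry is a keyword argument. The parameter values, the `n_rbf` and `cutoff` entries, and the use of `types.SimpleNamespace` as a stand-in target are illustrative choices so that the example runs without SchNetPack installed.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a dict with a ``_target_`` key."""
    params = dict(config)  # shallow copy so the caller's dict is left untouched
    module_path, _, attr_name = params.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    # Nested dicts that carry their own ``_target_`` are instantiated first.
    kwargs = {
        name: instantiate_from_config(value)
        if isinstance(value, dict) and "_target_" in value
        else value
        for name, value in params.items()
    }
    return target(**kwargs)


# Illustrative dictionary shaped like the documented return value.
# SimpleNamespace stands in for the real representation classes.
hparams = {
    "_target_": "types.SimpleNamespace",
    "n_atom_basis": 64,
    "n_interactions": 3,
    "radial_basis": {"_target_": "types.SimpleNamespace", "n_rbf": 20, "cutoff": 5.0},
    "cutoff_fn": {"_target_": "types.SimpleNamespace", "cutoff": 5.0},
}

rebuilt = instantiate_from_config(hparams)
print(rebuilt.n_atom_basis, rebuilt.radial_basis.n_rbf)
```

Storing the import path next to the parameters is what makes the dictionary self-describing: the object can be rebuilt from the dictionary alone.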
- _translate(batch: AtomsGraph) -> Dict[str, torch.Tensor]
Translate the input batch to the format required by the model.
The SchNetPack model takes its input data as a dictionary.
The dictionary keys are defined in schnetpack.properties and describe:
  - n_atoms: number of atoms in the system
  - Z: atomic numbers
  - R: atomic positions
  - cell: cell vectors
  - pbc: periodic boundary conditions
  - idx_i: edge indices
  - idx_j: edge indices
  - offsets: shift vectors
  - idx_m: batch indices describing which atoms belong to which structure
In addition, energy and forces targets can be added to the dictionary.
- Parameters:
batch (AtomsGraph) – The input batch of data.
- Returns:
The translated batch of data.
- Return type:
Dict
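As a rough illustration of the batched dictionary layout, the sketch below builds a few of the entries for a batch of two structures. It is not agedi's implementation. It uses plain Python lists instead of tensors, uses the short names from the description above as keys (the actual key strings are defined in schnetpack.properties), and assumes that atoms of all structures are concatenated so that per-structure edge indices must be shifted. The helpers `batch_indices` and `shift_edges` are hypothetical, and positions, cell, pbc and offsets are left out for brevity.

```python
from itertools import accumulate
from typing import List


def batch_indices(n_atoms: List[int]) -> List[int]:
    """Return, for every atom in the batch, the index of its structure."""
    return [m for m, n in enumerate(n_atoms) for _ in range(n)]


def shift_edges(edges: List[List[int]], n_atoms: List[int]) -> List[int]:
    """Shift per-structure atom indices so they address the concatenated batch."""
    starts = [0] + list(accumulate(n_atoms))[:-1]
    return [i + start for start, local in zip(starts, edges) for i in local]


# Two structures: H2 (2 atoms) and H2O (3 atoms, oxygen listed first).
n_atoms = [2, 3]
batch = {
    "n_atoms": n_atoms,
    "Z": [1, 1, 8, 1, 1],
    # Edges are given per structure with local indices, then shifted.
    "idx_i": shift_edges([[0, 1], [0, 0, 1, 2]], n_atoms),
    "idx_j": shift_edges([[1, 0], [1, 2, 0, 0]], n_atoms),
    "idx_m": batch_indices(n_atoms),
}

print(batch["idx_m"])  # [0, 0, 1, 1, 1]
print(batch["idx_i"])  # [0, 1, 2, 2, 3, 4]
```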
- _translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]
Translate the representation to the format required by the model.
SchNetPack uses scalar_representation and vector_representation for the two types of representations.
- Parameters:
representation (Representation) – The input representation.
translated_batch (Dict) – The translated batch of data.
- Returns:
The translated batch with representation keys.
- Return type:
Dict
- _get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) -> agedi.data.Representation
Get the representation from the output of the model.
- Parameters:
batch (AtomsGraph) – The input batch of data.
translated_batch (Dict[str, torch.Tensor]) – The output of the model.
- Returns:
The representation output of the model.
- Return type:
Representation
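The relationship between _translate_representation and _get_representation can be sketched as a round trip through the two dictionary keys. Everything below is an assumption made for illustration: the `Representation` dataclass is a stand-in for agedi.data.Representation with assumed field names `scalar` and `vector`, the two functions are hypothetical free-function versions of the methods, and nested lists replace tensors.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Representation:
    """Stand-in for agedi.data.Representation; the field names are assumed."""
    scalar: Any
    vector: Any


def translate_representation(rep: Representation, translated_batch: Dict[str, Any]) -> Dict[str, Any]:
    """Write the representation into the batch under the SchNetPack key names."""
    out = dict(translated_batch)  # copy, so the input dictionary is not modified
    out["scalar_representation"] = rep.scalar
    out["vector_representation"] = rep.vector
    return out


def get_representation(translated_batch: Dict[str, Any]) -> Representation:
    """Read the representation back from the model output dictionary."""
    return Representation(
        scalar=translated_batch["scalar_representation"],
        vector=translated_batch["vector_representation"],
    )


# One atom with two scalar features and a 3 x 2 block of vector features.
rep = Representation(scalar=[[0.1, 0.2]], vector=[[[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]])
translated = translate_representation(rep, {"idx_m": [0]})
restored = get_representation(translated)
print(restored == rep)  # True
```

Returning a copy in `translate_representation` is a choice made for this sketch only; whether the actual method copies or updates the dictionary in place is not stated in this reference.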
- class agedi.models.schnetpack.PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)
Bases: agedi.models.head.Head
Predict the positions score of the atoms in the structure.
- Parameters:
input_dim_scalar (int) – The dimension of the scalar input.
input_dim_vector (int) – The dimension of the vector input.
gated_blocks (int) – The number of gated blocks in the network.
- _key = 'pos'
- input_dim_scalar = 66
- input_dim_vector = 64
- gated_blocks = 3
- net
- get_hparams() -> Dict
Return hyperparameters for this positions score head.
- _score(batch: dict) -> torch.Tensor
Predict the positions score of the atoms in the structure.
- Parameters:
batch (dict) – The translated input batch with scalar_representation and vector_representation keys.
- Returns:
The predicted positions score.
- Return type:
torch.Tensor
- class agedi.models.schnetpack.TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)
Bases: agedi.models.head.Head
Predict the types score of the atoms in the structure.
- Parameters:
input_dim_scalar (int) – The dimension of the scalar input.
input_dim_vector (int) – The dimension of the vector input.
- _key = 'x'
- input_dim_scalar = 66
- input_dim_vector = 64
- n_classes = 100
- net
- get_hparams() -> Dict
Return hyperparameters for this types score head.
- _score(batch: dict) -> torch.Tensor
Predict the types score of the atoms in the structure.
- Parameters:
batch (dict) – The translated input batch with a scalar_representation key.
- Returns:
The predicted types score.
- Return type:
torch.Tensor