agedi.models.schnetpack.translator

Classes

SchNetPackTranslator

Translator for SchNetPack models.

Module Contents

class agedi.models.schnetpack.translator.SchNetPackTranslator(input_modules: List[Callable] | None = None)

Bases: agedi.models.translator.Translator

Translator for SchNetPack models.

This class translates input data into the format required by SchNetPack models.
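SchNetPack's own `NeuralNetworkPotential` takes an `input_modules` list, for example `spk.atomistic.PairwiseDistances()`. It applies those modules one after another to the input dictionary before the representation runs. This page does not say whether the `input_modules` argument of `SchNetPackTranslator` plays the same role; the sketch below assumes it does.

The names `apply_input_modules` and `pairwise_vectors` are hypothetical. The key strings are assumed to match `schnetpack.properties`. Plain Python lists stand in for `torch.Tensor`.

```python
from typing import Callable, Dict, List, Optional

# Assumed string values of the schnetpack.properties keys (check your installed version).
R, IDX_I, IDX_J, OFFSETS, RIJ = "_positions", "_idx_i", "_idx_j", "_offsets", "_Rij"


def pairwise_vectors(inputs: Dict) -> Dict:
    """Hypothetical input module: Rij = R[j] - R[i] + offset for every neighbor pair."""
    pos = inputs[R]
    inputs[RIJ] = [
        [pos[j][k] - pos[i][k] + off[k] for k in range(3)]
        for i, j, off in zip(inputs[IDX_I], inputs[IDX_J], inputs[OFFSETS])
    ]
    return inputs


def apply_input_modules(inputs: Dict, input_modules: Optional[List[Callable]] = None) -> Dict:
    """Apply each module in order, passing the output of one to the next."""
    for module in input_modules or []:
        inputs = module(inputs)
    return inputs


# Two atoms 1 unit apart along x, each listed as the other's neighbor.
batch = {
    R: [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    IDX_I: [0, 1],
    IDX_J: [1, 0],
    OFFSETS: [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
}
out = apply_input_modules(batch, [pairwise_vectors])
```

Keeping the modules as plain callables over one dictionary means new derived quantities can be added without changing the translator itself.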

get_representation_hparams(representation: Any) → Dict

Extract hyperparameters from a SchNetPack representation object.

Supports schnetpack.representation.PaiNN. Extracts n_atom_basis, n_interactions, and nested configs for the radial_basis and cutoff_fn so that the representation can be fully reconstructed with _instantiate_from_config().

Parameters:

representation (Any) – An instantiated SchNetPack representation object.

Returns:

Hyperparameter dictionary with a _target_ key and representation-specific parameters.

Return type:

dict

Raises:

NotImplementedError – If the representation type is not recognised.
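The round trip described above can be sketched as follows: extract a nested config, then rebuild an equivalent object from it. Each level carries a `_target_` key naming the class to construct, the same convention Hydra's `instantiate` uses. In Hydra, `_target_` is normally a full import path; here a small registry resolves plain class names instead.

Everything in the sketch is an assumption made for illustration:

- `PaiNN`, `GaussianRBF` and `CosineCutoff` are minimal stand-ins, not the SchNetPack classes, and their attribute names are assumed.
- `get_hparams` and `build_from_config` are hypothetical helpers, not agedi's `get_representation_hparams` or `_instantiate_from_config()`.

```python
from typing import Any, Dict


# Hypothetical stand-ins for PaiNN and its sub-modules; attribute names are assumed.
class CosineCutoff:
    def __init__(self, cutoff: float):
        self.cutoff = cutoff


class GaussianRBF:
    def __init__(self, n_rbf: int, cutoff: float):
        self.n_rbf = n_rbf
        self.cutoff = cutoff


class PaiNN:
    def __init__(self, n_atom_basis: int, n_interactions: int,
                 radial_basis: GaussianRBF, cutoff_fn: CosineCutoff):
        self.n_atom_basis = n_atom_basis
        self.n_interactions = n_interactions
        self.radial_basis = radial_basis
        self.cutoff_fn = cutoff_fn


# Name -> class lookup; stands in for importing the dotted path stored in _target_.
REGISTRY = {cls.__name__: cls for cls in (CosineCutoff, GaussianRBF, PaiNN)}


def get_hparams(rep: Any) -> Dict:
    """Build a nested config whose _target_ keys say which class to rebuild."""
    if not isinstance(rep, PaiNN):
        raise NotImplementedError(f"unsupported representation: {type(rep).__name__}")
    rbf, cut = rep.radial_basis, rep.cutoff_fn
    return {
        "_target_": "PaiNN",
        "n_atom_basis": rep.n_atom_basis,
        "n_interactions": rep.n_interactions,
        "radial_basis": {"_target_": "GaussianRBF", "n_rbf": rbf.n_rbf, "cutoff": rbf.cutoff},
        "cutoff_fn": {"_target_": "CosineCutoff", "cutoff": cut.cutoff},
    }


def build_from_config(config: Dict) -> Any:
    """Recursively build nested configs first, then call the _target_ class."""
    kwargs = {k: build_from_config(v) if isinstance(v, dict) else v
              for k, v in config.items() if k != "_target_"}
    return REGISTRY[config["_target_"]](**kwargs)


painn = PaiNN(128, 3, GaussianRBF(20, 5.0), CosineCutoff(5.0))
config = get_hparams(painn)
rebuilt = build_from_config(config)
```

Storing the sub-modules as nested configs, rather than as opaque objects, is what allows the whole representation to be serialized and rebuilt from the dictionary alone.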

_translate(batch: AtomsGraph) → Dict[str, torch.Tensor]

Translate the input batch to the format required by the model.

SchNetPack models take their input as a single dictionary of tensors.

The dictionary keys are defined in schnetpack.properties:

  • n_atoms: number of atoms in each structure

  • Z: atomic numbers

  • R: atomic positions

  • cell: cell vectors

  • pbc: periodic boundary conditions

  • idx_i: index of the central atom i of each neighbor pair

  • idx_j: index of the neighbor atom j of each neighbor pair

  • offsets: periodic shift vectors of each neighbor pair

  • idx_m: index of the structure each atom belongs to

Energy and force targets can also be added to the dictionary.

Parameters:

batch (AtomsGraph) – The input batch of data.

Returns:

The translated batch of data.

Return type:

Dict
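The part of this translation that is easiest to get wrong is batching. SchNetPack expects all structures in a batch to be concatenated into flat per-atom and per-pair arrays. `idx_m` maps each atom to its structure, and the neighbor indices `idx_i` and `idx_j` must be shifted by the number of atoms in all preceding structures.

The sketch below makes several simplifications, all of them assumptions:

- Plain Python lists stand in for `torch.Tensor`.
- A hypothetical per-structure dict stands in for `AtomsGraph`, and `collate` is a hypothetical helper.
- The key strings are assumed to match `schnetpack.properties`.
- `cell`, `pbc` and the energy/force targets are left out for brevity.

```python
from itertools import accumulate
from typing import Dict, List, Sequence

# Assumed string values of the schnetpack.properties keys (check your installed version).
N_ATOMS, Z, R, IDX_M, IDX_I, IDX_J, OFFSETS = (
    "_n_atoms", "_atomic_numbers", "_positions", "_idx_m", "_idx_i", "_idx_j", "_offsets")


def collate(structures: Sequence[Dict]) -> Dict[str, List]:
    """Concatenate per-structure data into one flat SchNetPack-style dictionary.

    Each structure is a hypothetical dict with keys "numbers", "positions",
    "idx_i", "idx_j" and "offsets"; its neighbor indices are local to it.
    """
    n_atoms = [len(s["numbers"]) for s in structures]
    # Position of each structure's first atom in the flat per-atom arrays.
    starts = [0, *accumulate(n_atoms)][:-1]
    out: Dict[str, List] = {N_ATOMS: n_atoms, Z: [], R: [], IDX_M: [],
                            IDX_I: [], IDX_J: [], OFFSETS: []}
    for m, (s, start) in enumerate(zip(structures, starts)):
        out[Z] += s["numbers"]
        out[R] += s["positions"]
        out[IDX_M] += [m] * len(s["numbers"])
        # Shift local neighbor indices so they address the flat atom arrays.
        out[IDX_I] += [i + start for i in s["idx_i"]]
        out[IDX_J] += [j + start for j in s["idx_j"]]
        out[OFFSETS] += s["offsets"]
    return out


h2 = {"numbers": [1, 1], "positions": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
      "idx_i": [0, 1], "idx_j": [1, 0], "offsets": [[0.0] * 3, [0.0] * 3]}
water = {"numbers": [8, 1, 1],
         "positions": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
         "idx_i": [0, 0], "idx_j": [1, 2], "offsets": [[0.0] * 3, [0.0] * 3]}
batch = collate([h2, water])
```

In this example the water molecule's local pair (0, 1) becomes the global pair (2, 3), because the two hydrogen atoms of H2 occupy flat indices 0 and 1.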

_translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

Translate the representation to the format required by the model.

SchNetPack stores scalar (invariant) and vector (equivariant) features under the keys scalar_representation and vector_representation, respectively.

Parameters:
  • representation (Representation) – The input representation.

  • translated_batch (Dict) – The translated batch of data.

Returns:

The translated batch with representation keys.

Return type:

Dict
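The attributes of `agedi.data.Representation` are not documented on this page, so the stand-in below assumes fields named `scalar` and `vector`. The expected shapes follow SchNetPack's PaiNN, which writes scalar features as (n_atoms, n_atom_basis) and vector features as (n_atoms, 3, n_atom_basis). `translate_representation` is a hypothetical helper that checks those shapes before writing the two keys. Nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Representation:
    """Hypothetical stand-in for agedi.data.Representation; field names are assumed."""
    scalar: List[List[float]]        # (n_atoms, n_features), invariant features
    vector: List[List[List[float]]]  # (n_atoms, 3, n_features), equivariant features


def translate_representation(rep: Representation, batch: Dict) -> Dict:
    """Return a copy of batch with the representation stored under SchNetPack's keys."""
    n_atoms = len(batch["_atomic_numbers"])  # key string assumed from schnetpack.properties
    n_features = len(rep.scalar[0]) if rep.scalar else 0
    if len(rep.scalar) != n_atoms or len(rep.vector) != n_atoms:
        raise ValueError("representation and batch disagree on the number of atoms")
    if any(len(v) != 3 or any(len(c) != n_features for c in v) for v in rep.vector):
        raise ValueError("vector features must have shape (n_atoms, 3, n_features)")
    out = dict(batch)  # leave the caller's dictionary untouched
    out["scalar_representation"] = rep.scalar
    out["vector_representation"] = rep.vector
    return out


rep = Representation(scalar=[[1.0, 2.0], [3.0, 4.0]],
                     vector=[[[0.0, 0.0]] * 3, [[1.0, 1.0]] * 3])
original = {"_atomic_numbers": [1, 1]}
translated = translate_representation(rep, original)
```

Failing early on a shape mismatch is cheaper to debug than an indexing error deep inside a SchNetPack interaction block.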

_get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) → agedi.data.Representation

Get the representation from the output of the model.

Parameters:
  • batch (AtomsGraph) – The input batch of data.

  • translated_batch (Dict[str, torch.Tensor]) – The output of the model.

Returns:

The representation output of the model.

Return type:

Representation
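In the reverse direction, the features are read back out of the dictionary after the SchNetPack representation module has run. The sketch below reads only the translated dictionary; the `batch` argument of the real method is not needed for this step.

The vector part is treated as optional because SchNetPack's invariant SchNet representation writes only scalar_representation. Whether agedi's `Representation` allows a missing vector part is an assumption, as are the field names of the stand-in class. `get_representation` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Representation:
    """Hypothetical stand-in for agedi.data.Representation; field names are assumed."""
    scalar: List[List[float]]
    vector: Optional[List[List[List[float]]]] = None


def get_representation(translated_batch: Dict) -> Representation:
    """Read the features a SchNetPack representation wrote into the dictionary."""
    if "scalar_representation" not in translated_batch:
        raise KeyError("no 'scalar_representation' found; has the representation been run?")
    return Representation(
        scalar=translated_batch["scalar_representation"],
        # Invariant-only models such as SchNet write no vector features.
        vector=translated_batch.get("vector_representation"),
    )


schnet_like = get_representation({"scalar_representation": [[0.5]]})
painn_like = get_representation({"scalar_representation": [[0.5]],
                                 "vector_representation": [[[1.0], [2.0], [3.0]]]})
```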