agedi.models.schnetpack

Submodules

Classes

SchNetPackTranslator

Translator for SchNetPack models.

PositionsScore

Predict the positions score of the atoms in the structure.

TypesScore

Predict the types score of the atoms in the structure.

Package Contents

class agedi.models.schnetpack.SchNetPackTranslator(input_modules: List[Callable] | None = None)

Bases: agedi.models.translator.Translator

Translator for SchNetPack models.

This class is used to translate the input data to the format required by the SchNetPack models.

get_representation_hparams(representation: Any) → Dict

Extract hyperparameters from a SchNetPack representation object.

Currently supports only schnetpack.representation.PaiNN. Extracts n_atom_basis, n_interactions, and nested configs for radial_basis and cutoff_fn so that the representation can be fully reconstructed with _instantiate_from_config().

Parameters:

representation (Any) – An instantiated SchNetPack representation object.

Returns:

Hyperparameter dictionary with a _target_ key and representation-specific parameters.

Return type:

dict

Raises:

NotImplementedError – If the representation type is not recognised.
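The sketch below shows why a _target_ key plus nested configs is enough to rebuild a representation. It uses the common "_target_" instantiation pattern, under several assumptions:

  • instantiate_from_config is a hypothetical helper written for this example. It is not agedi's _instantiate_from_config().

  • The PaiNN dictionary's values and its nested _target_ paths (schnetpack.nn.GaussianRBF, schnetpack.nn.CosineCutoff) only illustrate what such a config could look like. They are not output captured from agedi.

The runnable part uses standard-library classes, so it works without SchNetPack installed.

```python
import importlib
from typing import Any


def instantiate_from_config(config: Any) -> Any:
    """Recursively build objects from dicts that carry a '_target_' import path.

    Hypothetical helper illustrating the pattern; not agedi's _instantiate_from_config.
    """
    if isinstance(config, dict):
        # Build nested configs first (e.g. radial_basis, cutoff_fn).
        kwargs = {k: instantiate_from_config(v) for k, v in config.items() if k != "_target_"}
        if "_target_" not in config:
            return kwargs  # plain dict without a target: keep it as a dict
        module_path, _, attr = config["_target_"].rpartition(".")
        cls = getattr(importlib.import_module(module_path), attr)
        return cls(**kwargs)
    if isinstance(config, list):
        return [instantiate_from_config(item) for item in config]
    return config  # scalars (ints, floats, strings) pass through unchanged


# Assumed shape of a PaiNN hyperparameter dict; values and nested paths are illustrative.
PAINN_HPARAMS_EXAMPLE = {
    "_target_": "schnetpack.representation.PaiNN",
    "n_atom_basis": 64,
    "n_interactions": 3,
    "radial_basis": {"_target_": "schnetpack.nn.GaussianRBF", "n_rbf": 20, "cutoff": 5.0},
    "cutoff_fn": {"_target_": "schnetpack.nn.CosineCutoff", "cutoff": 5.0},
}

# Runnable without SchNetPack: the same nesting pattern with stdlib classes.
tz = instantiate_from_config({
    "_target_": "datetime.timezone",
    "offset": {"_target_": "datetime.timedelta", "hours": 1},
})
```

Nested dictionaries are built first. By the time the outer class is called, its arguments (here offset, for PaiNN radial_basis and cutoff_fn) are already objects rather than dicts.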

_translate(batch: AtomsGraph) → Dict[str, torch.Tensor]

Translate the input batch to the format required by the model.

SchNetPack models take their input as a dictionary of tensors.

The dictionary keys are defined in schnetpack.properties and describe:
  • n_atoms: number of atoms in each structure

  • Z: atomic numbers

  • R: atomic positions

  • cell: cell vectors

  • pbc: periodic boundary conditions

  • idx_i: index of the first (center) atom of each neighbor pair

  • idx_j: index of the second (neighbor) atom of each neighbor pair

  • offsets: cell shift vectors for neighbors in periodic images

  • idx_m: index of the structure that each atom belongs to

Energy and force targets can additionally be added to the dictionary.

Parameters:

batch (AtomsGraph) – The input batch of data.

Returns:

The translated batch of data.

Return type:

Dict
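The following pure-Python sketch shows how the batch index keys relate to each other:

  • Atoms from all structures are concatenated.

  • idx_m labels each atom with the structure it belongs to.

  • Neighbor pairs are shifted from per-structure to global atom indices.

The function name build_batch_indices and its per-structure pair input are hypothetical. The real translator works on torch tensors and also fills Z, R, cell, pbc and offsets, which are omitted here.

```python
from typing import Dict, List, Sequence, Tuple


def build_batch_indices(
    n_atoms: Sequence[int],
    local_pairs: Sequence[Sequence[Tuple[int, int]]],
) -> Dict[str, List[int]]:
    """Flatten per-structure atom counts and neighbor pairs into batch indices.

    Hypothetical helper. Keys mirror the schnetpack.properties attribute
    names (n_atoms, idx_m, idx_i, idx_j), not their actual string values.
    """
    idx_m: List[int] = []
    idx_i: List[int] = []
    idx_j: List[int] = []
    start = 0  # global index of the first atom of the current structure
    for m, (n, pairs) in enumerate(zip(n_atoms, local_pairs)):
        idx_m.extend([m] * n)  # every atom of structure m is labeled m
        for i, j in pairs:  # shift local atom indices to global ones
            idx_i.append(start + i)
            idx_j.append(start + j)
        start += n
    return {"n_atoms": list(n_atoms), "idx_m": idx_m, "idx_i": idx_i, "idx_j": idx_j}


# Two structures (2 and 3 atoms), each with one symmetric neighbor pair.
indices = build_batch_indices([2, 3], [[(0, 1), (1, 0)], [(0, 2), (2, 0)]])
```

For this input, idx_m is [0, 0, 1, 1, 1]. The pair (0, 2) of the second structure becomes (2, 4), because that structure's atoms start at global index 2.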

_translate_representation(representation: agedi.data.Representation, translated_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

Translate the representation to the format required by the model.

SchNetPack stores the two kinds of representation under the scalar_representation and vector_representation keys.

Parameters:
  • representation (Representation) – The input representation.

  • translated_batch (Dict) – The translated batch of data.

Returns:

The translated batch with representation keys.

Return type:

Dict

_get_representation(batch: AtomsGraph, translated_batch: Dict[str, torch.Tensor]) → agedi.data.Representation

Get the representation from the output of the model.

Parameters:
  • batch (AtomsGraph) – The input batch of data.

  • translated_batch (Dict[str, torch.Tensor]) – The output of the model.

Returns:

The representation output of the model.

Return type:

Representation
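The two methods above move a representation into and out of the scalar_representation and vector_representation keys. Below is a minimal round-trip sketch of that idea. Several parts are assumptions:

  • RepresentationSketch is a hypothetical stand-in for agedi.data.Representation, whose real attributes are not documented here.

  • The shapes in the comments follow the layout SchNetPack's PaiNN uses: scalars of shape (n_atoms, n_features) and vectors of shape (n_atoms, 3, n_features).

  • Copying the input dictionary, instead of mutating it, is a choice made for this sketch.

  • The sketch ignores the batch argument that _get_representation also receives.

```python
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np

SCALAR_KEY = "scalar_representation"
VECTOR_KEY = "vector_representation"


@dataclass
class RepresentationSketch:
    """Hypothetical stand-in for agedi.data.Representation."""
    scalar: Any  # e.g. (n_atoms, n_features) invariant features
    vector: Any  # e.g. (n_atoms, 3, n_features) equivariant features


def translate_representation(rep: RepresentationSketch, translated_batch: Dict[str, Any]) -> Dict[str, Any]:
    # Shallow copy so the caller's dictionary is left untouched (sketch choice).
    out = dict(translated_batch)
    out[SCALAR_KEY] = rep.scalar
    out[VECTOR_KEY] = rep.vector
    return out


def get_representation(model_output: Dict[str, Any]) -> RepresentationSketch:
    # Read the two representation entries back out of the model output.
    return RepresentationSketch(scalar=model_output[SCALAR_KEY], vector=model_output[VECTOR_KEY])


rep = RepresentationSketch(scalar=np.zeros((4, 64)), vector=np.zeros((4, 3, 64)))
batch = translate_representation(rep, {"n_atoms": np.array([4])})
back = get_representation(batch)
```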

class agedi.models.schnetpack.PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)

Bases: agedi.models.head.Head

Predict the positions score of the atoms in the structure.

Parameters:
  • input_dim_scalar (int) – The dimension of the scalar input.

  • input_dim_vector (int) – The dimension of the vector input.

  • gated_blocks (int) – The number of gated blocks in the network.

Return type:

Head

_key = 'pos'
input_dim_scalar = 66
input_dim_vector = 64
gated_blocks = 3
net
get_hparams() → Dict

Return hyperparameters for this positions score head.
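A minimal sketch of the round-trip that a get_hparams() method typically enables, assuming the returned dictionary holds the constructor arguments listed as attributes above. HeadSketch is hypothetical and not agedi's Head class. In this sketch the hyperparameters rebuild an equivalent architecture, not the trained weights.

```python
from typing import Any, Dict


class HeadSketch:
    """Hypothetical head that records its constructor arguments."""

    def __init__(self, input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3):
        self.input_dim_scalar = input_dim_scalar
        self.input_dim_vector = input_dim_vector
        self.gated_blocks = gated_blocks

    def get_hparams(self) -> Dict[str, Any]:
        # Exactly the arguments needed to construct an equivalent head.
        return {
            "input_dim_scalar": self.input_dim_scalar,
            "input_dim_vector": self.input_dim_vector,
            "gated_blocks": self.gated_blocks,
        }


head = HeadSketch(gated_blocks=2)
clone = HeadSketch(**head.get_hparams())  # rebuild from the hyperparameters alone
```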

_score(batch: dict) → torch.Tensor

Predict the positions score of the atoms in the structure.

Parameters:

batch (dict) – The translated input batch with scalar_representation and vector_representation keys.

Returns:

The predicted positions score.

Return type:

torch.Tensor
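The gated_blocks parameter suggests a stack of gated equivariant blocks that turns scalar and vector features into one 3-vector per atom. The NumPy sketch below shows one common form of such a block, modeled on the gated equivariant block from PaiNN. It is an assumption about the technique, not agedi's net. All names (GatedBlockSketch, make_positions_head, positions_score) are hypothetical, and the weights are random rather than trained.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (swish) activation."""
    return x / (1.0 + np.exp(-x))


class GatedBlockSketch:
    """Hypothetical PaiNN-style gated equivariant block (not agedi's code)."""

    def __init__(self, rng, f_s_in, f_v_in, f_s_out, f_v_out, hidden=64):
        # Vector weights mix only the feature axis, never the xyz axis,
        # which is what keeps the block rotation equivariant.
        self.W_norm = rng.normal(size=(f_v_in, f_v_in)) / np.sqrt(f_v_in)
        self.W_out = rng.normal(size=(f_v_in, f_v_out)) / np.sqrt(f_v_in)
        # Scalar MLP that sees scalars plus rotation-invariant vector norms.
        self.W_h = rng.normal(size=(f_s_in + f_v_in, hidden)) / np.sqrt(f_s_in + f_v_in)
        self.W_o = rng.normal(size=(hidden, f_s_out + f_v_out)) / np.sqrt(hidden)
        self.f_s_out = f_s_out

    def __call__(self, s, v):
        # s: (n_atoms, f_s_in) invariant; v: (n_atoms, 3, f_v_in) equivariant
        norms = np.linalg.norm(v @ self.W_norm, axis=1)  # (n_atoms, f_v_in)
        h = silu(np.concatenate([s, norms], axis=-1) @ self.W_h)
        out = h @ self.W_o
        s_out = out[:, : self.f_s_out]
        gates = out[:, self.f_s_out :]  # (n_atoms, f_v_out), invariant
        v_out = gates[:, None, :] * (v @ self.W_out)  # scale each vector channel
        return s_out, v_out


def make_positions_head(rng, input_dim_scalar=66, input_dim_vector=64, gated_blocks=3):
    blocks = []
    for k in range(gated_blocks):
        last = k == gated_blocks - 1
        # The final block collapses the vector channels to a single one.
        f_s_out = 1 if last else input_dim_scalar
        f_v_out = 1 if last else input_dim_vector
        blocks.append(GatedBlockSketch(rng, input_dim_scalar, input_dim_vector, f_s_out, f_v_out))
    return blocks


def positions_score(blocks, batch):
    s = batch["scalar_representation"]
    v = batch["vector_representation"]
    for block in blocks:
        s, v = block(s, v)
    return v[:, :, 0]  # (n_atoms, 3): one 3-vector per atom
```

The scalar path only ever sees vector norms, so the gates are rotation invariant. As a result, rotating the input vectors rotates the predicted score by the same rotation, which is the property a positions score needs.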

class agedi.models.schnetpack.TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)

Bases: agedi.models.head.Head

Predict the types score of the atoms in the structure.

Parameters:
  • input_dim_scalar (int) – The dimension of the scalar input.

  • input_dim_vector (int) – The dimension of the vector input.

  • n_classes (int) – The number of atom-type classes to score.

Return type:

Head

_key = 'x'
input_dim_scalar = 66
input_dim_vector = 64
n_classes = 100
net
get_hparams() → Dict

Return hyperparameters for this types score head.

_score(batch: dict) → torch.Tensor

Predict the types score of the atoms in the structure.

Parameters:

batch (dict) – The translated input batch with a scalar_representation key.

Returns:

The predicted types score.

Return type:

torch.Tensor
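A minimal NumPy sketch of a types head under these assumptions:

  • It reads only scalar_representation.

  • It maps each atom's scalar features through a small MLP to n_classes unnormalized scores.

  • Scores are shown converted to per-atom class probabilities with a softmax.

TypesScoreSketch, its hidden size and tanh activation, and the softmax step are all hypothetical. They are not taken from agedi.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


class TypesScoreSketch:
    """Hypothetical two-layer MLP mapping scalar features to per-class scores."""

    def __init__(self, rng, input_dim_scalar: int = 66, n_classes: int = 100, hidden: int = 64):
        self.W1 = rng.normal(size=(input_dim_scalar, hidden)) / np.sqrt(input_dim_scalar)
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(size=(hidden, n_classes)) / np.sqrt(hidden)
        self.b2 = np.zeros(n_classes)

    def __call__(self, batch: dict) -> np.ndarray:
        s = batch["scalar_representation"]  # (n_atoms, input_dim_scalar)
        h = np.tanh(s @ self.W1 + self.b1)
        return h @ self.W2 + self.b2  # (n_atoms, n_classes) unnormalized scores


rng = np.random.default_rng(0)
head = TypesScoreSketch(rng)
logits = head({"scalar_representation": rng.normal(size=(5, 66))})
probs = softmax(logits)
```

Because only the scalar representation is read, this sketch's output does not change when the structure is rotated.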