agedi.models.schnetpack.heads

Classes

PositionsScore

Predict the positions score of the atoms in the structure.

TypesScore

Predict the types score of the atoms in the structure.

Functions

build_gated_equivariant_mlp(...) → torch.nn.Sequential

Build a neural network analogous to an MLP, with `GatedEquivariantBlock`s in place of dense layers.

Module Contents

agedi.models.schnetpack.heads.build_gated_equivariant_mlp(s_in: int, v_in: int, n_out: int, n_layers: int = 2, activation: Callable = F.silu, sactivation: Callable = F.silu) → torch.nn.Sequential

Build a neural network analogous to an MLP, with `GatedEquivariantBlock`s in place of dense layers.

Parameters:
  • s_in (int) – Number of scalar input features.

  • v_in (int) – Number of vector input features.

  • n_out (int) – Number of output nodes.

  • n_layers (int) – Number of layers.

  • activation (Callable) – Activation function.

  • sactivation (Callable) – Activation function for the scalar outputs.

Return type:

torch.nn.Sequential
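The idea of an MLP built from gated equivariant blocks can be illustrated without PyTorch. The pure-Python sketch below is not agedi's implementation; it assumes the PaiNN-style design of schnetpack's `GatedEquivariantBlock` (bias-free mixing of vector channels, a scalar MLP fed with scalar features and vector norms, and one scalar gate per output vector channel). The names `GatedBlock`, `build_gated_mlp`, and `rotate_z`, the halving width schedule, and the random weight initialisation are all hypothetical.

```python
import math
import random
from typing import Callable, List, Optional

Vec3 = List[float]


def silu(x: float) -> float:
    """Numerically stable SiLU: x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def _matrix(rng: random.Random, rows: int, cols: int) -> List[List[float]]:
    """Random weights (hypothetical init), scaled by 1/sqrt(cols)."""
    scale = 1.0 / math.sqrt(cols)
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def _matvec(m: List[List[float]], x: List[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


class GatedBlock:
    """Hypothetical PaiNN-style gated block for one atom's features."""

    def __init__(self, n_sin: int, n_vin: int, n_sout: int, n_vout: int,
                 activation: Callable[[float], float],
                 sactivation: Optional[Callable[[float], float]],
                 rng: random.Random) -> None:
        self.n_sout, self.n_vout = n_sout, n_vout
        self.activation, self.sactivation = activation, sactivation
        # Bias-free mixing of vector channels keeps the vector path equivariant.
        self.w_mix = _matrix(rng, 2 * n_vout, n_vin)
        # Two-layer scalar MLP: n_sout scalar outputs plus one gate per vector channel.
        self.w1 = _matrix(rng, n_sin, n_sin + n_vout)
        self.w2 = _matrix(rng, n_sout + n_vout, n_sin)

    def __call__(self, scalars: List[float], vectors: List[Vec3]):
        # Each mixed channel is a linear combination of the input 3-vectors.
        mixed = [[sum(w * v[d] for w, v in zip(row, vectors)) for d in range(3)]
                 for row in self.w_mix]
        vec_v, vec_w = mixed[: self.n_vout], mixed[self.n_vout:]
        # Norms are rotation-invariant, so they may safely enter the scalar MLP.
        norms = [math.sqrt(sum(c * c for c in v)) for v in vec_v]
        hidden = [self.activation(h) for h in _matvec(self.w1, list(scalars) + norms)]
        out = _matvec(self.w2, hidden)
        s_out, gates = out[: self.n_sout], out[self.n_sout:]
        # Scaling a vector by an invariant scalar gate preserves equivariance.
        v_out = [[g * c for c in v] for g, v in zip(gates, vec_w)]
        if self.sactivation is not None:
            s_out = [self.sactivation(s) for s in s_out]
        return s_out, v_out


def build_gated_mlp(s_in: int, v_in: int, n_out: int, n_layers: int = 2,
                    activation=silu, sactivation=silu, seed: int = 0):
    """Stack n_layers blocks; widths halve per layer (assumed schedule)."""
    rng = random.Random(seed)
    s_dims, v_dims = [s_in], [v_in]
    for _ in range(n_layers - 1):
        s_dims.append(max(n_out, s_dims[-1] // 2))
        v_dims.append(max(n_out, v_dims[-1] // 2))
    s_dims.append(n_out)
    v_dims.append(n_out)
    # As in a plain MLP, the output layer applies no scalar activation.
    blocks = [GatedBlock(s_dims[i], v_dims[i], s_dims[i + 1], v_dims[i + 1], activation,
                         None if i == n_layers - 1 else sactivation, rng)
              for i in range(n_layers)]

    def forward(scalars, vectors):
        for block in blocks:
            scalars, vectors = block(scalars, vectors)
        return scalars, vectors

    return forward


def rotate_z(v: Vec3, theta: float) -> Vec3:
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


# Usage: rotating the input vectors rotates the output vectors; scalars are unchanged.
net = build_gated_mlp(s_in=8, v_in=6, n_out=1, n_layers=2)
rng = random.Random(1)
scalars = [rng.gauss(0.0, 1.0) for _ in range(8)]
vectors = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(6)]
s1, v1 = net(scalars, vectors)
s2, v2 = net(scalars, [rotate_z(v, 0.7) for v in vectors])
```

Gates are computed only from rotation-invariant quantities and vectors are only ever mixed linearly or scaled, which is why the vector outputs rotate with the inputs.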

class agedi.models.schnetpack.heads.PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)

Bases: agedi.models.head.Head

Predict the positions score of the atoms in the structure.

Parameters:
  • input_dim_scalar (int) – The dimension of the scalar input.

  • input_dim_vector (int) – The dimension of the vector input.

  • gated_blocks (int) – The number of gated blocks in the network.

Return type:

Head

_key = 'pos'
input_dim_scalar = 66
input_dim_vector = 64
gated_blocks = 3
net
get_hparams() → Dict

Return hyperparameters for this positions score head.
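A common reason to expose hyperparameters is so that a head can be rebuilt from a saved configuration. The sketch below assumes, without confirmation from the agedi source, that `get_hparams()` returns the constructor arguments keyed by their parameter names; `ToyPositionsHead` is a hypothetical stand-in that stores the three documented attributes.

```python
from typing import Any, Dict


class ToyPositionsHead:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    def __init__(self, input_dim_scalar: int = 66, input_dim_vector: int = 64,
                 gated_blocks: int = 3) -> None:
        self.input_dim_scalar = input_dim_scalar
        self.input_dim_vector = input_dim_vector
        self.gated_blocks = gated_blocks

    def get_hparams(self) -> Dict[str, Any]:
        # Assumption: keys match the constructor parameter names.
        return {
            "input_dim_scalar": self.input_dim_scalar,
            "input_dim_vector": self.input_dim_vector,
            "gated_blocks": self.gated_blocks,
        }


head = ToyPositionsHead(gated_blocks=2)
# Round trip: the returned dict is enough to rebuild an identically configured head.
clone = ToyPositionsHead(**head.get_hparams())
```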

_score(batch: dict) → torch.Tensor

Predict the positions score of the atoms in the structure.

Parameters:

batch (dict) – The translated input batch with `scalar_representation` and `vector_representation` keys.

Returns:

The predicted positions score.

Return type:

torch.Tensor
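Whatever the exact network, a positions score should rotate with the structure: rotating every atom's vector features should rotate that atom's predicted score. The pure-Python sketch below illustrates that contract with a deliberately simplified readout (one invariant gate times one linear mix of vector channels), not the head's actual `net`. It assumes schnetpack-style per-atom shapes, with `input_dim_scalar` floats in `scalar_representation` and a 3 × `input_dim_vector` block in `vector_representation`; `positions_score_sketch`, `w_gate`, `w_mix`, and the rotation helpers are hypothetical.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def positions_score_sketch(batch: Dict[str, list], w_gate: List[float],
                           w_mix: List[float]) -> Matrix:
    """Return one 3-vector score per atom (hypothetical readout)."""
    scores = []
    for s, vec in zip(batch["scalar_representation"], batch["vector_representation"]):
        # Rotation-invariant gate computed from the atom's scalar features.
        gate = sum(w * x for w, x in zip(w_gate, s))
        # Bias-free mix of vector channels, applied to each Cartesian component.
        mixed = [sum(w * c for w, c in zip(w_mix, component)) for component in vec]
        scores.append([gate * m for m in mixed])
    return scores


def rotate_z(v: List[float], theta: float) -> List[float]:
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


def rotate_channels(vec: Matrix, theta: float) -> Matrix:
    """Rotate every channel of a 3 x n_channels block about the z axis."""
    x, y, z = vec
    c, s = math.cos(theta), math.sin(theta)
    return [[c * a - s * b for a, b in zip(x, y)],
            [s * a + c * b for a, b in zip(x, y)],
            list(z)]


batch = {
    "scalar_representation": [[1.0, 0.5], [-0.3, 2.0]],  # 2 atoms, 2 scalar features
    "vector_representation": [  # 2 atoms, each 3 components x 2 channels
        [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],
        [[0.2, -1.0], [1.5, 0.0], [0.0, 0.3]],
    ],
}
w_gate, w_mix = [0.4, -0.2], [1.0, 0.5]
scores = positions_score_sketch(batch, w_gate, w_mix)
rotated = dict(batch, vector_representation=[rotate_channels(v, 0.7)
                                             for v in batch["vector_representation"]])
rotated_scores = positions_score_sketch(rotated, w_gate, w_mix)
```

Because the gate never sees raw vector components, rotating the inputs by any angle rotates each atom's score by the same angle.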

class agedi.models.schnetpack.heads.TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)

Bases: agedi.models.head.Head

Predict the types score of the atoms in the structure.

Parameters:
  • input_dim_scalar (int) – The dimension of the scalar input.

  • input_dim_vector (int) – The dimension of the vector input.

  • n_classes (int) – The number of type classes to score.

Return type:

Head

_key = 'x'
input_dim_scalar = 66
input_dim_vector = 64
n_classes = 100
net
get_hparams() → Dict

Return hyperparameters for this types score head.

_score(batch: dict) → torch.Tensor

Predict the types score of the atoms in the structure.

Parameters:

batch (dict) – The translated input batch with a `scalar_representation` key.

Returns:

The predicted types score.

Return type:

torch.Tensor
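Since this head's `_score` reads only `scalar_representation`, a types score can be computed independently for each atom. The sketch below uses a single linear layer that maps each atom's scalar features to `n_classes` scores; that choice, and the names `types_score_sketch`, `weight`, and `bias`, are assumptions for illustration rather than the head's actual `net`.

```python
from typing import Dict, List


def types_score_sketch(batch: Dict[str, list], weight: List[List[float]],
                       bias: List[float]) -> List[List[float]]:
    """Map each atom's scalar features to one score per class (hypothetical)."""
    scores = []
    for s in batch["scalar_representation"]:
        # One row of `weight` per class: score_c = bias_c + weight_c . s
        scores.append([b + sum(w * x for w, x in zip(row, s))
                       for row, b in zip(weight, bias)])
    return scores


# 2 atoms with 3 scalar features each, scored against 2 classes.
batch = {"scalar_representation": [[1.0, 2.0, 3.0], [0.0, 0.0, 1.0]]}
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
bias = [0.0, 0.5]
scores = types_score_sketch(batch, weight, bias)  # [[1.0, 5.5], [0.0, 1.5]]
```

Because each atom is scored independently, reordering the atoms reorders the rows of the output in the same way.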