agedi.models.schnetpack.heads
Classes
PositionsScore | Predict the positions score of the atoms in the structure.
TypesScore | Predict the types score of the atoms in the structure.
Functions
build_gated_equivariant_mlp | Build a neural network analogous to an MLP, using `GatedEquivariantBlock`s instead of dense layers.
Module Contents
- agedi.models.schnetpack.heads.build_gated_equivariant_mlp(s_in: int, v_in: int, n_out: int, n_layers: int = 2, activation: Callable = F.silu, sactivation: Callable = F.silu) → torch.nn.Sequential
Build a neural network analogous to an MLP, using `GatedEquivariantBlock`s instead of dense layers.
- Parameters:
s_in (int) – Number of scalar input features.
v_in (int) – Number of vector input features.
n_out (int) – Number of output nodes.
n_layers (int) – Number of layers.
activation (Callable) – Activation function.
sactivation (Callable) – Activation function for the skip connection.
- Return type:
torch.nn.Sequential
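The sketch below illustrates the kind of network this function assembles. It is modeled on the PaiNN-style gated equivariant block (schnetpack exposes one as `GatedEquivariantBlock`) and is written in plain NumPy so it runs without torch. `GatedBlockSketch`, `build_gated_mlp_sketch`, the halving layer widths, the gating width, and the `(n_atoms, 3, channels)` vector layout are all hypothetical choices for illustration, not agedi's implementation.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU / swish activation, x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


class GatedBlockSketch:
    """Hypothetical PaiNN-style gated equivariant block (not agedi's code).

    scalars: (n_atoms, s_in), vectors: (n_atoms, 3, v_in).
    """

    def __init__(self, s_in, v_in, s_out, v_out, rng, scalar_activation=None):
        self.s_out, self.v_out = s_out, v_out
        self.scalar_activation = scalar_activation
        # Bias-free linear mixing over channels keeps the vector path equivariant.
        self.w_vec = rng.normal(size=(v_in, 2 * v_out)) / np.sqrt(v_in)
        n_hidden = s_in  # hypothetical gating width
        self.w1 = rng.normal(size=(s_in + v_out, n_hidden)) / np.sqrt(s_in + v_out)
        self.b1 = np.zeros(n_hidden)
        self.w2 = rng.normal(size=(n_hidden, s_out + v_out)) / np.sqrt(n_hidden)
        self.b2 = np.zeros(s_out + v_out)

    def __call__(self, scalars, vectors):
        mixed = vectors @ self.w_vec  # (n_atoms, 3, 2 * v_out)
        v_a, v_b = mixed[..., : self.v_out], mixed[..., self.v_out :]
        norms = np.linalg.norm(v_a, axis=-2)  # rotation-invariant, (n_atoms, v_out)
        context = np.concatenate([scalars, norms], axis=-1)
        out = silu(context @ self.w1 + self.b1) @ self.w2 + self.b2
        s_out, gates = out[..., : self.s_out], out[..., self.s_out :]
        if self.scalar_activation is not None:
            s_out = self.scalar_activation(s_out)
        # Each output vector channel is scaled by an invariant gate.
        return s_out, gates[:, None, :] * v_b


def build_gated_mlp_sketch(s_in, v_in, n_out, n_layers=2, seed=0):
    """Stack n_layers gated blocks; widths halve per layer (an assumption)."""
    rng = np.random.default_rng(seed)
    hidden = range(1, n_layers)
    s_dims = [s_in] + [max(s_in // 2**i, n_out) for i in hidden] + [n_out]
    v_dims = [v_in] + [max(v_in // 2**i, n_out) for i in hidden] + [n_out]
    blocks = [
        GatedBlockSketch(
            s_dims[i], v_dims[i], s_dims[i + 1], v_dims[i + 1], rng,
            # Assumption: no activation on the final scalar output, as in a plain MLP.
            scalar_activation=silu if i < n_layers - 1 else None,
        )
        for i in range(n_layers)
    ]

    def forward(scalars, vectors):
        for block in blocks:
            scalars, vectors = block(scalars, vectors)
        return scalars, vectors

    return forward
```

Rotating every input vector by an orthogonal matrix rotates the vector output the same way and leaves the scalar output unchanged, which is the property the gated design is meant to preserve where a dense layer on raw coordinates would not.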
- class agedi.models.schnetpack.heads.PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)
Bases: agedi.models.head.Head
Predict the positions score of the atoms in the structure.
- Parameters:
input_dim_scalar (int) – The dimension of the scalar input.
input_dim_vector (int) – The dimension of the vector input.
gated_blocks (int) – The number of gated blocks in the network.
- _key = 'pos'
- input_dim_scalar = 66
- input_dim_vector = 64
- gated_blocks = 3
- net
- get_hparams() → Dict
Return hyperparameters for this positions score head.
- _score(batch: dict) → torch.Tensor
Predict the positions score of the atoms in the structure.
- Parameters:
batch (dict) – The translated input batch with `scalar_representation` and `vector_representation` keys.
- Returns:
The predicted positions score.
- Return type:
torch.Tensor
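To make the `_score` contract concrete, here is a hedged NumPy sketch of a positions-score head. It assumes `scalar_representation` has shape `(n_atoms, input_dim_scalar)` and `vector_representation` has shape `(n_atoms, 3, input_dim_vector)`, and it replaces the documented stack of `gated_blocks` gated blocks with a single hypothetical gating step. `ToyPositionsScore` and its weights are illustrative only. It also shows one plausible reading of `get_hparams()`: returning the architecture-defining constructor arguments so an equivalent head can be rebuilt.

```python
import numpy as np


class ToyPositionsScore:
    """Hypothetical stand-in for PositionsScore (not agedi's implementation)."""

    _key = "pos"  # the documented key; everything else here is assumed

    def __init__(self, input_dim_scalar=66, input_dim_vector=64, seed=0):
        self.input_dim_scalar = input_dim_scalar
        self.input_dim_vector = input_dim_vector
        rng = np.random.default_rng(seed)
        self.w_gate = rng.normal(size=(input_dim_scalar, input_dim_vector)) / np.sqrt(input_dim_scalar)
        self.w_mix = rng.normal(size=(input_dim_vector, 1)) / np.sqrt(input_dim_vector)

    def get_hparams(self):
        # Assumed pattern: hyperparameters are the architecture-defining constructor args.
        return {
            "input_dim_scalar": self.input_dim_scalar,
            "input_dim_vector": self.input_dim_vector,
        }

    def _score(self, batch):
        s = batch["scalar_representation"]  # (n_atoms, input_dim_scalar)
        v = batch["vector_representation"]  # (n_atoms, 3, input_dim_vector)
        gate = 1.0 / (1.0 + np.exp(-(s @ self.w_gate)))  # invariant, (n_atoms, input_dim_vector)
        # Gate each vector channel, then mix channels down to one 3-vector per atom.
        return ((v * gate[:, None, :]) @ self.w_mix)[..., 0]  # (n_atoms, 3)
```

Because the gate depends only on rotation-invariant scalars and the vector path is linear and bias-free, rotating the input vectors rotates the predicted score, which is the behaviour one would expect from an equivariant positions score.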
- class agedi.models.schnetpack.heads.TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)
Bases: agedi.models.head.Head
Predict the types score of the atoms in the structure.
- Parameters:
input_dim_scalar (int) – The dimension of the scalar input.
input_dim_vector (int) – The dimension of the vector input.
n_classes (int) – The number of atom-type classes.
- _key = 'x'
- input_dim_scalar = 66
- input_dim_vector = 64
- n_classes = 100
- net
- get_hparams() → Dict
Return hyperparameters for this types score head.
- _score(batch: dict) → torch.Tensor
Predict the types score of the atoms in the structure.
- Parameters:
batch (dict) – The translated input batch with a `scalar_representation` key.
- Returns:
The predicted types score.
- Return type:
torch.Tensor
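A hedged sketch of a types-score head: following the documented `_score`, it reads only `scalar_representation`, and here it maps each atom's scalar features to `n_classes` unnormalized scores through a two-layer MLP. The MLP, its hidden width `n_hidden`, and the name `ToyTypesScore` are assumptions for illustration, not agedi's network.

```python
import numpy as np


class ToyTypesScore:
    """Hypothetical stand-in for TypesScore (not agedi's implementation)."""

    _key = "x"  # the documented key; the network below is assumed

    def __init__(self, input_dim_scalar=66, n_classes=100, n_hidden=64, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(size=(input_dim_scalar, n_hidden)) / np.sqrt(input_dim_scalar)
        self.b1 = np.zeros(n_hidden)
        self.w2 = rng.normal(size=(n_hidden, n_classes)) / np.sqrt(n_hidden)
        self.b2 = np.zeros(n_classes)

    def _score(self, batch):
        s = batch["scalar_representation"]  # (n_atoms, input_dim_scalar)
        h = s @ self.w1 + self.b1
        h = h / (1.0 + np.exp(-h))  # SiLU
        return h @ self.w2 + self.b2  # (n_atoms, n_classes), one score per class
```

Since only scalar features enter the computation, the output is unaffected by any `vector_representation` entry in the batch; `scores.argmax(axis=-1)` would pick the highest-scoring class per atom if the scores are read as logits.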