agedi.models.schnetpack.heads
=============================

.. py:module:: agedi.models.schnetpack.heads


Classes
-------

.. autoapisummary::

   agedi.models.schnetpack.heads.PositionsScore
   agedi.models.schnetpack.heads.TypesScore


Functions
---------

.. autoapisummary::

   agedi.models.schnetpack.heads.build_gated_equivariant_mlp


Module Contents
---------------

.. py:function:: build_gated_equivariant_mlp(s_in: int, v_in: int, n_out: int, n_layers: int = 2, activation: Callable = F.silu, sactivation: Callable = F.silu) -> torch.nn.Sequential

   Build a neural network analogous to an MLP, with ``GatedEquivariantBlock`` layers in place of dense layers.

   :param s_in: Number of scalar input nodes.
   :type s_in: int
   :param v_in: Number of vector input nodes.
   :type v_in: int
   :param n_out: Number of output nodes.
   :type n_out: int
   :param n_layers: Number of layers.
   :type n_layers: int
   :param activation: Activation function.
   :type activation: Callable
   :param sactivation: Activation function for the skip connection.
   :type sactivation: Callable

   :rtype: torch.nn.Sequential


.. py:class:: PositionsScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, gated_blocks: int = 3, **kwargs)

   Bases: :py:obj:`agedi.models.head.Head`

   Predict the positions score of the atoms in the structure.

   :param input_dim_scalar: The dimension of the scalar input.
   :type input_dim_scalar: int
   :param input_dim_vector: The dimension of the vector input.
   :type input_dim_vector: int
   :param gated_blocks: The number of gated blocks in the network.
   :type gated_blocks: int

   :rtype: Head


   .. py:attribute:: _key
      :value: 'pos'


   .. py:attribute:: input_dim_scalar
      :value: 66


   .. py:attribute:: input_dim_vector
      :value: 64


   .. py:attribute:: gated_blocks
      :value: 3


   .. py:attribute:: net


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this positions score head.


   .. py:method:: _score(batch: dict) -> torch.Tensor

      Predict the positions score of the atoms in the structure.

      :param batch: The translated input batch with ``scalar_representation`` and ``vector_representation`` keys.
      :type batch: dict

      :returns: The predicted positions score.
      :rtype: torch.Tensor


.. py:class:: TypesScore(input_dim_scalar: int = 66, input_dim_vector: int = 64, n_classes: int = 100, **kwargs)

   Bases: :py:obj:`agedi.models.head.Head`

   Predict the types score of the atoms in the structure.

   :param input_dim_scalar: The dimension of the scalar input.
   :type input_dim_scalar: int
   :param input_dim_vector: The dimension of the vector input.
   :type input_dim_vector: int
   :param n_classes: The number of classes.
   :type n_classes: int

   :rtype: Head


   .. py:attribute:: _key
      :value: 'x'


   .. py:attribute:: input_dim_scalar
      :value: 66


   .. py:attribute:: input_dim_vector
      :value: 64


   .. py:attribute:: n_classes
      :value: 100


   .. py:attribute:: net


   .. py:method:: get_hparams() -> Dict

      Return hyperparameters for this types score head.


   .. py:method:: _score(batch: dict) -> torch.Tensor

      Predict the types score of the atoms in the structure.

      :param batch: The translated input batch with a ``scalar_representation`` key.
      :type batch: dict

      :returns: The predicted types score.
      :rtype: torch.Tensor
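

The positions head is built from gated equivariant blocks rather than dense layers. The sketch below is a toy, pure-Python illustration of the general gated-equivariant idea: vector channels are mixed linearly, and each mixed vector is scaled by a gate computed only from rotation-invariant quantities (the scalars and the vector norms), so the output rotates together with the input. It does not use or reproduce the ``agedi`` or ``schnetpack`` code; the function names, weights, and the ``tanh`` gate are all assumptions chosen for illustration.

```python
import math


def mix_channels(vectors, weights):
    """Linearly mix vector channels; vectors[c] is an (x, y, z) tuple."""
    return [
        tuple(sum(w * vectors[c][k] for c, w in enumerate(row)) for k in range(3))
        for row in weights
    ]


def gated_equivariant_step(scalars, vectors, w_mix, w_gate):
    """Hypothetical toy gated step: one output 3-vector per row of w_mix."""
    mixed = mix_channels(vectors, w_mix)
    # Vector norms do not change under rotation, so they may feed the gates.
    norms = [math.sqrt(sum(x * x for x in v)) for v in mixed]
    features = list(scalars) + norms
    # One scalar gate per output vector (tanh is an arbitrary choice here).
    gates = [math.tanh(sum(w * f for w, f in zip(row, features))) for row in w_gate]
    # Scaling a vector by a scalar keeps its direction tied to the input frame.
    return [tuple(g * x for x in v) for g, v in zip(gates, mixed)]


def rotate_z90(v):
    """Rotate a 3-vector by 90 degrees about the z axis."""
    x, y, z = v
    return (-y, x, z)


# Illustrative inputs for a single atom: 2 scalar features, 3 vector channels.
scalars = [0.5, -1.0]
vectors = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0)]
w_mix = [[1.0, 0.5, 0.0], [0.0, 1.0, -1.0]]
w_gate = [[0.3, -0.2, 0.1, 0.4], [-0.5, 0.2, 0.3, 0.1]]

out = gated_equivariant_step(scalars, vectors, w_mix, w_gate)
out_rot = gated_equivariant_step(
    scalars, [rotate_z90(v) for v in vectors], w_mix, w_gate
)

# Equivariance check: rotating the output equals the output of rotated input.
equivariant = all(
    math.isclose(a, b, abs_tol=1e-12)
    for o, o_rot in zip(out, out_rot)
    for a, b in zip(rotate_z90(o), o_rot)
)
print(equivariant)
```

A types score, by contrast, is a per-atom quantity that should not depend on orientation, which is consistent with ``TypesScore._score`` reading only the ``scalar_representation`` key.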