agedi.models.regressor

Classes

RegressorModel

Class that defines a regressor model.

Module Contents

class agedi.models.regressor.RegressorModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, heads: List[agedi.models.head.Head] = [], head_weights={}, use_weighting: bool = False, mask_forces: bool = True, **kwargs)

Bases: lightning.LightningModule

Class that defines a regressor model.

It is a combination of a translator, a representation and a list of heads.

Parameters:
  • translator (Translator) – The translator that will be used to translate the input batch.

  • representation (Representation) – The representation that will be used to represent the translated batch.

  • heads (List[Head]) – The list of heads that will be used to compute scores.
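The composition described above can be pictured with a minimal, library-free sketch. It assumes the three components are applied in order (translator, then representation, then each head); the source does not confirm this ordering. `ToyRegressor`, the helper functions, and the dictionary-based batch are hypothetical stand-ins, not the agedi API.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for torch_geometric.data.Batch: a plain dictionary.
Batch = Dict[str, object]
Stage = Callable[[Batch], Batch]


class ToyRegressor:
    """Illustrative composition of a translator, a representation and heads."""

    def __init__(self, translator: Stage, representation: Stage, heads: List[Stage]):
        self.translator = translator
        self.representation = representation
        self.heads = heads

    def forward(self, batch: Batch) -> Batch:
        # Assumed order: translate the input, build features, then let each head add its score.
        batch = self.translator(batch)
        batch = self.representation(batch)
        for head in self.heads:
            batch = head(batch)
        return batch


def toy_translator(batch: Batch) -> Batch:
    # Copies the raw positions under the key the representation expects.
    return {**batch, "translated": batch["positions"]}


def toy_representation(batch: Batch) -> Batch:
    # One scalar feature per atom: the sum of its coordinates.
    return {**batch, "features": [sum(p) for p in batch["translated"]]}


def toy_energy_head(batch: Batch) -> Batch:
    # Scores the structure by summing the per-atom features.
    return {**batch, "energy": sum(batch["features"])}


model = ToyRegressor(toy_translator, toy_representation, [toy_energy_head])
output = model.forward({"positions": [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]]})
```

In this sketch each stage returns a new dictionary, so the input batch is left unmodified; whether the real model mutates its batch in place is not stated in this reference.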

translator
representation
head_weights
use_weighting = False
mask_forces = True
head_keys
heads
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this regressor model.

Returns:

Hyperparameter dictionary with a _target_ key and nested translator, representation, and heads entries.

Return type:

dict
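A dictionary with a `_target_` key and nested entries resembles the convention used by configuration tools such as Hydra, where `_target_` names the class to construct. The sketch below shows how such a dictionary could be turned back into objects. It is an assumption about how the output might be consumed, not the reconstruction code agedi uses; the `instantiate` helper and the example targets (`types.SimpleNamespace`) are hypothetical placeholders for the real classes.

```python
import importlib
from typing import Any


def instantiate(config: Any) -> Any:
    """Recursively build objects from dictionaries carrying a ``_target_`` key."""
    if isinstance(config, list):
        return [instantiate(item) for item in config]
    if not isinstance(config, dict):
        return config  # plain values are passed through unchanged

    # Build nested entries first so the parent receives ready-made objects.
    kwargs = {key: instantiate(value) for key, value in config.items() if key != "_target_"}
    if "_target_" not in config:
        return kwargs

    # Split "package.module.ClassName" into the module path and the attribute name.
    module_name, _, attribute = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attribute)
    return target(**kwargs)


# Hypothetical hyperparameter layout; real targets would be agedi classes.
hparams = {
    "_target_": "types.SimpleNamespace",
    "translator": {"_target_": "types.SimpleNamespace", "name": "toy-translator"},
    "representation": {"_target_": "types.SimpleNamespace", "name": "toy-representation"},
    "heads": [{"_target_": "types.SimpleNamespace", "key": "energy"}],
}
rebuilt = instantiate(hparams)
```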

forward(batch: torch_geometric.data.Batch) → torch_geometric.data.Batch

Forward pass of the model.

Parameters:

batch (Batch) – The input batch that will be used to compute the scores.

Returns:

The output batch containing the scores.

Return type:

Batch

loss(batch: torch_geometric.data.Batch) → Dict

Compute the loss of the model.

Parameters:

batch (Batch) – The input batch that will be used to compute the loss.

Returns:

A dictionary containing the loss and the individual head losses.

Return type:

dict
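One plausible shape for the returned dictionary is a total loss alongside one entry per head. The sketch below assumes the total is a weighted sum of the head losses, with `head_weights` supplying the weights and missing heads defaulting to 1.0. The key names (`"loss"`, `"<head>_loss"`) and the `combine_losses` helper are hypothetical, and the roles of `use_weighting` and `mask_forces` are not modeled because this reference does not describe them.

```python
from typing import Dict, Optional


def combine_losses(
    head_losses: Dict[str, float],
    head_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Combine per-head losses into a total plus individual entries."""
    weights = head_weights or {}

    # Assumed rule: weighted sum, with an implicit weight of 1.0 for unlisted heads.
    total = sum(weights.get(key, 1.0) * value for key, value in head_losses.items())

    # Report the unweighted head losses next to the combined total.
    result = {f"{key}_loss": value for key, value in head_losses.items()}
    result["loss"] = total
    return result


losses = combine_losses({"energy": 2.0, "forces": 0.5}, head_weights={"forces": 10.0})
```

Keeping the individual head losses in the result makes it straightforward to log each term separately while optimizing only the combined value.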