agedi.models.regressor
Classes
RegressorModel | Class that defines a regressor model.
Module Contents
- class agedi.models.regressor.RegressorModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, heads: List[agedi.models.head.Head] = [], head_weights={}, use_weighting: bool = False, mask_forces: bool = True, **kwargs)
Bases: lightning.LightningModule
Class that defines a regressor model.
It is a combination of a translator, a representation and a list of heads.
- Parameters:
translator (Translator) – The translator that will be used to translate the input batch.
representation (Representation) – The representation that will be used to represent the translated batch.
heads (List[Head]) – The list of heads that will be used to compute scores.
- translator
- representation
- head_weights
- use_weighting = False
- mask_forces = True
- head_keys
- heads
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this regressor model.
- Returns:
Hyperparameter dictionary with a _target_ key and nested translator, representation, and heads entries.
- Return type:
dict
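The _target_ key follows the convention used by Hydra's hydra.utils.instantiate: import the dotted path stored under _target_, then call it with the remaining keys as keyword arguments, building nested entries first. Whether agedi passes get_hparams() to Hydra is an assumption. The dependency-free sketch below re-implements that convention with importlib. instantiate is a hypothetical helper, and types.SimpleNamespace stands in for the real translator, representation, and head classes.

```python
import importlib
from typing import Any


def instantiate(config: Any) -> Any:
    """Recursively build objects from dicts that carry a '_target_' key (hypothetical helper)."""
    if isinstance(config, list):
        return [instantiate(item) for item in config]
    if isinstance(config, dict):
        # Build nested entries first so children are objects, not dicts.
        kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
        if "_target_" not in config:
            return kwargs
        # "package.module.Name" -> import "package.module", then fetch "Name".
        module_name, _, attr_name = config["_target_"].rpartition(".")
        target = getattr(importlib.import_module(module_name), attr_name)
        return target(**kwargs)
    return config  # plain values (numbers, strings, ...) pass through unchanged


# Toy hyperparameters shaped like the documented return value of get_hparams().
hparams = {
    "_target_": "types.SimpleNamespace",
    "translator": {"_target_": "types.SimpleNamespace", "kind": "translator"},
    "representation": {"_target_": "types.SimpleNamespace", "kind": "representation"},
    "heads": [{"_target_": "types.SimpleNamespace", "kind": "head"}],
}
model = instantiate(hparams)
print(model.translator.kind, [h.kind for h in model.heads])  # translator ['head']
```

With the real classes, the SimpleNamespace targets would be replaced by the dotted paths of the actual Translator, Representation, and Head subclasses.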
- forward(batch: torch_geometric.data.Batch) → torch_geometric.data.Batch
Forward pass of the model.
- Parameters:
batch (Batch) – The input batch that will be used to compute the scores.
- Returns:
The output batch containing the scores.
- Return type:
Batch
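The forward pass chains the three components: the translator transforms the input batch, the representation processes the translated batch, and the heads compute scores. The sketch below is a minimal, dependency-free illustration of that composition. It assumes each head writes its score into the batch under its own key, which the head_keys attribute suggests but does not confirm. ToyRegressor, ToyHead, and the dict-based Batch are hypothetical stand-ins for RegressorModel, Head, and torch_geometric.data.Batch. The real model is a LightningModule operating on tensors.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

# Hypothetical stand-in for torch_geometric.data.Batch: a plain dict of fields.
Batch = Dict[str, float]


@dataclass
class ToyHead:
    key: str  # hypothetical: the name under which this head stores its score
    fn: Callable[[Batch], float]

    def __call__(self, batch: Batch) -> float:
        return self.fn(batch)


@dataclass
class ToyRegressor:
    translator: Callable[[Batch], Batch]
    representation: Callable[[Batch], Batch]
    heads: List[ToyHead] = field(default_factory=list)  # fresh list per instance

    def forward(self, batch: Batch) -> Batch:
        batch = self.translator(batch)      # translate the input batch
        batch = self.representation(batch)  # represent the translated batch
        for head in self.heads:
            batch[head.key] = head(batch)   # each head adds one score
        return batch


model = ToyRegressor(
    translator=lambda b: {**b, "x": b["x"] * 2.0},
    representation=lambda b: {**b, "h": b["x"] + 1.0},
    heads=[ToyHead("energy", lambda b: b["h"] ** 2)],
)
out = model.forward({"x": 1.5})
print(out["energy"])  # 16.0
```

The sketch uses field(default_factory=list) for heads. The documented signature instead defaults heads to [] and head_weights to {}; Python evaluates such defaults once, so they are shared between instances and are safe only if the constructor never mutates them.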
- loss(batch: torch_geometric.data.Batch) → Dict
Compute the loss of the model.
- Parameters:
batch (Batch) – The input batch that will be used to compute the loss.
- Returns:
A dictionary containing the loss and the individual head losses.
- Return type:
dict
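The returned dictionary holds a total loss alongside the individual head losses. One plausible way to combine them is a weighted sum, assuming head_weights maps head keys to scalar multipliers and unlisted heads default to 1.0. combine_head_losses and the loss_<key> entry names are hypothetical. Plain floats stand in for tensors. The role of use_weighting is not documented on this page, so the sketch ignores it.

```python
from typing import Dict, Optional


def combine_head_losses(
    head_losses: Dict[str, float],
    head_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Weighted sum of per-head losses (hypothetical helper, not agedi's API)."""
    weights = head_weights or {}
    # Assumption: heads missing from head_weights get weight 1.0.
    total = sum(weights.get(key, 1.0) * value for key, value in head_losses.items())
    # Report the total under "loss" and each head's unweighted loss separately.
    result = {"loss": total}
    result.update({f"loss_{key}": value for key, value in head_losses.items()})
    return result


losses = combine_head_losses({"energy": 0.5, "forces": 2.0}, {"forces": 0.25})
print(losses)  # {'loss': 1.0, 'loss_energy': 0.5, 'loss_forces': 2.0}
```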