agedi.models.translator
Classes
| Translator | Base class for all translators. |
Module Contents
- class agedi.models.translator.Translator(input_modules: List[Callable] | None = None)
Bases: abc.ABC
Base class for all translators.
Translators convert a batch of data into the format expected by the model. This is useful when the data is not already in that format or must be preprocessed before it is fed into the model.
- Parameters:
input_modules (List[Callable]) – A list of functions that will be applied to the input data after it is translated.
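The contract above can be sketched as a small abstract base class. This is a hypothetical, simplified stand-in (`TranslatorSketch`, `DictTranslator`, `add_flag`, and the plain-dict batch format are all illustrative assumptions), not agedi's implementation. It only shows which hooks a subclass fills in and how `input_modules` could be applied to the data after it has been translated.

```python
import abc
from typing import Any, Callable, Dict, List, Optional

# Hypothetical plain-dict batch format; agedi itself uses agedi.data.AtomsGraph.
Batch = Dict[str, Any]


class TranslatorSketch(abc.ABC):
    """Simplified stand-in mirroring the documented Translator hooks."""

    def __init__(self, input_modules: Optional[List[Callable[[Any], Any]]] = None):
        # Mirror the documented default: no input modules unless given.
        self.input_modules = list(input_modules) if input_modules else []

    @abc.abstractmethod
    def _translate(self, batch: Batch) -> Any:
        """Convert the batch into the model's input format."""

    @abc.abstractmethod
    def _get_representation(self, batch: Batch, out: Any) -> Any:
        """Extract the representation from the model output."""

    @abc.abstractmethod
    def _translate_representation(self, rep: Any, translated_batch: Any) -> Any:
        """Inject a representation into an already-translated batch."""

    def translate_input(self, batch: Batch) -> Any:
        # Translate first, then apply each input module to the result.
        translated = self._translate(batch)
        for module in self.input_modules:
            translated = module(translated)
        return translated


class DictTranslator(TranslatorSketch):
    """Toy subclass: renames keys to what a hypothetical backbone expects."""

    def _translate(self, batch: Batch) -> Batch:
        return {"R": batch["positions"], "Z": batch["numbers"]}

    def _get_representation(self, batch: Batch, out: Dict[str, Any]) -> Any:
        return out["features"]

    def _translate_representation(self, rep: Any, translated_batch: Batch) -> Batch:
        return {**translated_batch, "features": rep}


def add_flag(translated: Batch) -> Batch:
    # Example input module: marks the translated batch as processed.
    return {**translated, "processed": True}


translator = DictTranslator(input_modules=[add_flag])
result = translator.translate_input({"positions": [[0.0, 0.0, 0.0]], "numbers": [1]})
print(result)  # {'R': [[0.0, 0.0, 0.0]], 'Z': [1], 'processed': True}
```

Because the hooks are decorated with `abc.abstractmethod`, instantiating a subclass that leaves any of them unimplemented raises `TypeError`.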
- input_modules
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this translator.
Returns a dictionary with a _target_ key (the fully-qualified class name) plus input_modules (each serialised with its own _target_ key where available). Subclasses should call super().get_hparams() and merge in their own constructor parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
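The serialisation pattern described for get_hparams() can be sketched as follows. `BaseSketch`, `ScaledSketch`, `qualified_name`, and `center` are hypothetical names, and treating "where available" as "use the module's own get_hparams() if it has one, otherwise record its qualified name" is an assumption. The `_target_` key resembles the convention consumed by Hydra's `hydra.utils.instantiate`; whether agedi relies on Hydra is not stated here.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def qualified_name(obj: Any) -> str:
    """Fully-qualified name of a class, a function, or an instance's class."""
    target = obj if inspect.isclass(obj) or inspect.isfunction(obj) else type(obj)
    return f"{target.__module__}.{target.__qualname__}"


class BaseSketch:
    """Hypothetical base class following the documented get_hparams pattern."""

    def __init__(self, input_modules: Optional[List[Callable]] = None):
        self.input_modules = list(input_modules) if input_modules else []

    def get_hparams(self) -> Dict[str, Any]:
        modules = []
        for module in self.input_modules:
            # Assumption: prefer a module's own hparams when it provides them.
            if hasattr(module, "get_hparams"):
                modules.append(module.get_hparams())
            else:
                modules.append({"_target_": qualified_name(module)})
        return {"_target_": qualified_name(self), "input_modules": modules}


class ScaledSketch(BaseSketch):
    """Hypothetical subclass with one extra constructor parameter."""

    def __init__(self, scale: float = 1.0, input_modules: Optional[List[Callable]] = None):
        super().__init__(input_modules)
        self.scale = scale

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base dictionary, then merge in this class's own parameters.
        hparams = super().get_hparams()
        hparams["scale"] = self.scale
        return hparams


def center(batch: Any) -> Any:
    # Hypothetical input module (a no-op here).
    return batch


hparams = ScaledSketch(scale=2.0, input_modules=[center]).get_hparams()
print(hparams)
```

The exact `_target_` strings depend on the module in which the classes are defined, so no fixed output is shown.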
- abstractmethod get_representation_hparams(representation: Any) → Dict
Extract hyperparameters from a representation object.
This method is called by get_hparams() to serialise the representation (e.g. a PaiNN network) that the translator wraps. The base implementation raises NotImplementedError; subclasses must override it for the specific representation type they support.
- Parameters:
representation (Any) – The instantiated representation object.
- Returns:
Hyperparameter dictionary that can be used to reconstruct the representation (it should contain a _target_ key).
- Return type:
dict
- Raises:
NotImplementedError – If the subclass has not implemented this method.
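As a sketch of how a subclass might override get_representation_hparams(), the example below uses a hypothetical `ToyRepresentation` in place of a real network such as PaiNN, together with a hypothetical `ToyTranslator`. The attribute names and the choice to raise `TypeError` for unsupported types are assumptions, not agedi behaviour. The round trip at the end shows why the keys other than `_target_` should match the representation's constructor arguments.

```python
from typing import Any, Dict


class ToyRepresentation:
    """Hypothetical stand-in for a representation network such as PaiNN."""

    def __init__(self, n_features: int = 64, cutoff: float = 5.0):
        self.n_features = n_features
        self.cutoff = cutoff


class ToyTranslator:
    """Hypothetical translator that only supports ToyRepresentation."""

    def get_representation_hparams(self, representation: Any) -> Dict[str, Any]:
        if not isinstance(representation, ToyRepresentation):
            # Design choice of this sketch: reject unsupported types explicitly.
            raise TypeError(f"unsupported representation: {type(representation).__name__}")
        cls = type(representation)
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "n_features": representation.n_features,
            "cutoff": representation.cutoff,
        }


hparams = ToyTranslator().get_representation_hparams(ToyRepresentation(n_features=128))
# Round trip: drop _target_ and pass the remaining keys back to the constructor.
rebuilt = ToyRepresentation(**{k: v for k, v in hparams.items() if k != "_target_"})
print(rebuilt.n_features, rebuilt.cutoff)  # 128 5.0
```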
- abstractmethod _translate(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Translate the batch of data.
Abstract method that must be implemented by all subclasses.
This method is used to translate the batch of data into a format that can be used by the model.
- Parameters:
batch (AtomsGraph) – The batch of data to be translated.
- Returns:
The translated batch of data.
- Return type:
agedi.data.AtomsGraph
- abstractmethod _get_representation(batch: agedi.data.AtomsGraph, out: Any) → agedi.data.Representation
Get the representation of the batch of data.
Abstract method that must be implemented by all subclasses.
This method extracts the representation from the model output so that it can be added to the original batch of data.
- Parameters:
batch (AtomsGraph) – The original batch of data.
out (Any) – The output of the model.
- Returns:
The representation given by the model.
- Return type:
agedi.data.Representation
- abstractmethod _translate_representation(rep: agedi.data.Representation, translated_batch: Any) → Any
Translate the representation of the batch of data.
Abstract method that must be implemented by all subclasses.
This method converts the representation given by the model into the model's input format and injects it into the translated batch.
- Parameters:
rep (Representation) – The representation given by the model.
translated_batch (Any) – The translated batch of data.
- Returns:
translated_batch – The translated batch with the representation injected.
- Return type:
Any
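The two representation hooks form a pair: one extracts the representation from the model output, the other injects it into a freshly translated batch. A minimal sketch using plain dictionaries, where the key names `numbers`, `node_features`, `Z`, and `h0` are all hypothetical and not agedi's actual data layout:

```python
from typing import Any, Dict


def get_representation(batch: Dict[str, Any], out: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: pull per-atom features out of the model output."""
    features = out["node_features"]
    # Sanity check: one feature vector per atom in the batch.
    if len(features) != len(batch["numbers"]):
        raise ValueError("expected one feature vector per atom")
    return {"node_features": features}


def translate_representation(rep: Dict[str, Any], translated_batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: inject the stored features into the translated batch."""
    injected = dict(translated_batch)  # shallow copy: the caller's dict is not mutated
    injected["h0"] = rep["node_features"]
    return injected


batch = {"numbers": [1, 8, 1]}
out = {"node_features": [[0.1], [0.2], [0.3]]}
translated = {"Z": [1, 8, 1]}

rep = get_representation(batch, out)
injected = translate_representation(rep, translated)
print(injected)  # {'Z': [1, 8, 1], 'h0': [[0.1], [0.2], [0.3]]}
```

Copying rather than mutating `translated_batch` is a choice of this sketch; the reference above does not say whether agedi mutates it.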
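- __call__(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Translate a batch by calling the translator as a function.
If batch.representation is set, the stored representation is also injected via _translate_representation(); otherwise that step is skipped. Use translate_input() or translate_with_representation() to choose explicitly.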
- Parameters:
batch (AtomsGraph) – The batch of data to be translated.
- Returns:
The translated batch of data.
- Return type:
agedi.data.AtomsGraph
- translate_input(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Translate the batch without injecting any stored representation.
Unlike __call__(), this method always skips the _translate_representation() step, regardless of whether batch.representation is set. Use it for the first forward pass through the backbone, before the representation has been computed.
- Parameters:
batch (AtomsGraph) – The batch of data to translate.
- Returns:
The translated batch.
- Return type:
agedi.data.AtomsGraph
- translate_with_representation(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Translate the batch and inject the stored representation.
Like translate_input(), but always calls _translate_representation() to inject the representation that was previously attached via add_representation(). Use it for the second forward pass through the backbone, after the representation has been computed and stored on batch.
- Parameters:
batch (AtomsGraph) – The batch of data to translate. batch.representation must not be None when this method is called.
- Returns:
The translated batch with the representation injected.
- Return type:
agedi.data.AtomsGraph
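The two-pass workflow described for translate_input(), add_representation(), and translate_with_representation(), together with the dispatch in __call__(), can be sketched with toy stand-ins. `ToyBatch`, `TwoPassTranslator`, `toy_backbone`, and the dictionary keys are hypothetical; raising `ValueError` when no representation is stored is a choice of this sketch, since the reference only states that batch.representation must not be None.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyBatch:
    """Hypothetical stand-in for AtomsGraph with an optional representation."""
    numbers: List[int]
    representation: Optional[Any] = None


class TwoPassTranslator:
    """Hypothetical sketch of the dispatch between the three entry points."""

    def _translate(self, batch: ToyBatch) -> Dict[str, Any]:
        return {"Z": list(batch.numbers)}

    def _translate_representation(self, rep: Any, translated: Dict[str, Any]) -> Dict[str, Any]:
        return {**translated, "h0": rep}

    def translate_input(self, batch: ToyBatch) -> Dict[str, Any]:
        # First pass: never inject a representation.
        return self._translate(batch)

    def translate_with_representation(self, batch: ToyBatch) -> Dict[str, Any]:
        # Second pass: a representation must already be stored on the batch.
        if batch.representation is None:
            raise ValueError("batch.representation must be set before the second pass")
        return self._translate_representation(batch.representation, self._translate(batch))

    def __call__(self, batch: ToyBatch) -> Dict[str, Any]:
        # Inject only when a representation is present.
        if batch.representation is None:
            return self.translate_input(batch)
        return self.translate_with_representation(batch)

    def add_representation(self, batch: ToyBatch, out: Dict[str, Any]) -> ToyBatch:
        # Attach the backbone output to the original batch and return it.
        batch.representation = out["features"]
        return batch


def toy_backbone(inputs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical backbone: one scalar feature per atom.
    return {"features": [float(z) for z in inputs["Z"]]}


translator = TwoPassTranslator()
batch = ToyBatch(numbers=[1, 8, 1])

first = translator.translate_input(batch)  # pass 1: no representation yet
batch = translator.add_representation(batch, toy_backbone(first))
second = translator.translate_with_representation(batch)  # pass 2: inject it
print(second)  # {'Z': [1, 8, 1], 'h0': [1.0, 8.0, 1.0]}
```

Once the representation is stored, calling `translator(batch)` takes the same path as the second pass.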
- add_representation(batch: agedi.data.AtomsGraph, out: Any) → agedi.data.AtomsGraph
Adds the representation given by the model to the original batch of data.
- Parameters:
batch (AtomsGraph) – The original batch of data.
out (Any) – The output of the model.
- Returns:
The original batch of data with the representation added.
- Return type:
agedi.data.AtomsGraph
- add_scores(batch: agedi.data.AtomsGraph, scores: Dict[str, torch.Tensor]) → agedi.data.AtomsGraph
Adds the scores given by the model to the original batch of data.
- Parameters:
batch (AtomsGraph) – The original batch of data.
scores (Dict[str, torch.Tensor]) – The scores predicted by the model, formatted as {head key: head predicted scores}.
- Returns:
The original batch of data with the scores added.
- Return type:
agedi.data.AtomsGraph
- add_prediction(batch: agedi.data.AtomsGraph, targets: Dict[str, torch.Tensor], type: str | None = None) → agedi.data.AtomsGraph
Adds the targets predicted by the model to the original batch of data.
- Parameters:
batch (AtomsGraph) – The original batch of data.
targets (Dict[str, torch.Tensor]) – The targets predicted by the model, formatted as {head key: head predicted target}.
type (str | None) – Optional. Defaults to None.
- Returns:
The original batch of data with the predictions added.
- Return type:
agedi.data.AtomsGraph
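Both add_scores() and add_prediction() take a dictionary keyed by head. A minimal sketch of that {head key: value} pattern, assuming a hypothetical `ToyBatch` with `scores` and `predictions` dictionaries (agedi's actual storage on AtomsGraph is not described here) and hypothetical head keys `positions` and `energy`. The optional `type` argument of add_prediction() is omitted because its meaning is not documented above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyBatch:
    """Hypothetical stand-in for AtomsGraph."""
    scores: Dict[str, Any] = field(default_factory=dict)
    predictions: Dict[str, Any] = field(default_factory=dict)


def add_scores(batch: ToyBatch, scores: Dict[str, Any]) -> ToyBatch:
    # Store each head's scores under its head key (storage layout is assumed).
    batch.scores.update(scores)
    return batch


def add_prediction(batch: ToyBatch, targets: Dict[str, Any]) -> ToyBatch:
    # Same pattern for predicted targets; the optional `type` argument is omitted.
    batch.predictions.update(targets)
    return batch


batch = ToyBatch()
batch = add_scores(batch, {"positions": [[0.1, 0.0, -0.1]]})
batch = add_prediction(batch, {"energy": [-1.5]})
print(batch.scores, batch.predictions)  # {'positions': [[0.1, 0.0, -0.1]]} {'energy': [-1.5]}
```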