agedi.models.translator

Classes

Translator

Base class for all translators.

Module Contents

class agedi.models.translator.Translator(input_modules: List[Callable] | None = None)

Bases: abc.ABC

Base class for all translators.

Translators are used to convert a batch of data into the format expected by the model. This is useful when the data is not already in that format or needs to be preprocessed before being fed into the model.

Parameters:

input_modules (List[Callable]) – A list of functions that will be applied to the input data after it is translated.
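A minimal sketch of how input modules could be chained after translation, assuming they are applied one after another in list order and that None means "no modules". The helper apply_input_modules and the toy modules center and shift are hypothetical and not part of agedi; a plain dict stands in for the translated batch.

```python
from typing import Any, Callable, List, Optional


def apply_input_modules(translated: Any, input_modules: Optional[List[Callable]] = None) -> Any:
    # Hypothetical helper (not part of agedi): feed the translated data through
    # each module in list order; None is treated as "no modules".
    for module in input_modules or []:
        translated = module(translated)
    return translated


def center(batch: dict) -> dict:
    # Toy module: shift coordinates so their mean is zero.
    mean = sum(batch["x"]) / len(batch["x"])
    return {**batch, "x": [v - mean for v in batch["x"]]}


def shift(batch: dict) -> dict:
    # Toy module: add 1.0 to every coordinate.
    return {**batch, "x": [v + 1.0 for v in batch["x"]]}


result = apply_input_modules({"x": [1.0, 2.0, 3.0]}, [center, shift])
print(result)  # {'x': [0.0, 1.0, 2.0]}
```

Under this assumption the order of the list matters: [shift, center] would yield [-1.0, 0.0, 1.0] instead.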

input_modules
get_hparams() Dict

Return hyperparameters sufficient to reconstruct this translator.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus input_modules (each serialised with its own _target_ key where available). Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
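The get_hparams() contract (a _target_ key with the fully-qualified class name, serialised input_modules, and subclasses merging their own parameters via super()) can be sketched as follows. TranslatorSketch, CutoffTranslatorSketch, the cutoff parameter and the fully_qualified_name helper are hypothetical; treating "where available" as "the module has its own get_hparams()" is an assumption.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def fully_qualified_name(obj: Any) -> str:
    # Classes and plain functions name themselves; instances use their class.
    target = obj if isinstance(obj, type) or inspect.isfunction(obj) else type(obj)
    return f"{target.__module__}.{target.__qualname__}"


class TranslatorSketch:
    # Hypothetical stand-in for Translator; only the hparams logic is shown.
    def __init__(self, input_modules: Optional[List[Callable]] = None):
        self.input_modules = input_modules or []

    def get_hparams(self) -> Dict[str, Any]:
        modules = [
            # Use a module's own get_hparams() when it has one, else just its name.
            m.get_hparams() if hasattr(m, "get_hparams") else {"_target_": fully_qualified_name(m)}
            for m in self.input_modules
        ]
        return {"_target_": fully_qualified_name(self), "input_modules": modules}


class CutoffTranslatorSketch(TranslatorSketch):
    # Hypothetical subclass with one extra constructor parameter.
    def __init__(self, cutoff: float = 5.0, input_modules: Optional[List[Callable]] = None):
        super().__init__(input_modules)
        self.cutoff = cutoff

    def get_hparams(self) -> Dict[str, Any]:
        hparams = super().get_hparams()  # _target_ + input_modules
        hparams["cutoff"] = self.cutoff  # merge in this class's own parameter
        return hparams


def identity(x):
    return x


# When run as a script this prints:
# {'_target_': '__main__.CutoffTranslatorSketch', 'input_modules': [{'_target_': '__main__.identity'}], 'cutoff': 4.0}
print(CutoffTranslatorSketch(cutoff=4.0, input_modules=[identity]).get_hparams())
```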

abstractmethod get_representation_hparams(representation: Any) Dict

Extract hyperparameters from a representation object.

This method is called by get_hparams() to serialise the representation (e.g. a PaiNN network) that the translator wraps. The base implementation raises NotImplementedError; subclasses must override it for the specific representation type they support.

Parameters:

representation (Any) – The instantiated representation object.

Returns:

Hyperparameter dictionary that can be used to reconstruct the representation (should contain a _target_ key).

Return type:

dict

Raises:

NotImplementedError – If the subclass has not implemented this method.
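A minimal sketch of a subclass override, assuming a hypothetical ToyRepresentation class whose constructor arguments are stored as attributes of the same name. Raising NotImplementedError for unsupported types is a design choice of this sketch, not documented behaviour.

```python
from typing import Any, Dict


class ToyRepresentation:
    # Hypothetical representation network (stands in for e.g. a PaiNN model).
    def __init__(self, n_features: int = 64, n_layers: int = 3):
        self.n_features = n_features
        self.n_layers = n_layers


def get_representation_hparams(representation: Any) -> Dict[str, Any]:
    # Sketch of a subclass override that supports exactly one representation type.
    if not isinstance(representation, ToyRepresentation):
        raise NotImplementedError(f"unsupported representation: {type(representation).__name__}")
    cls = type(representation)
    return {
        "_target_": f"{cls.__module__}.{cls.__qualname__}",
        "n_features": representation.n_features,
        "n_layers": representation.n_layers,
    }


hp = get_representation_hparams(ToyRepresentation(n_features=128))

# Rebuild from everything except the _target_ key.
rebuilt = ToyRepresentation(**{k: v for k, v in hp.items() if k != "_target_"})
```

Because the returned dictionary mirrors the constructor, the representation can be rebuilt from it; a generic loader keyed on _target_ (Hydra-style instantiation, for example) would do the class lookup automatically.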

abstractmethod _translate(batch: agedi.data.AtomsGraph) agedi.data.AtomsGraph

Translate the batch of data.

Abstract method that must be implemented by all subclasses.

This method is used to translate the batch of data into a format that can be used by the model.

Parameters:

batch (AtomsGraph) – The batch of data to be translated.

Returns:

The translated batch of data.

Return type:

AtomsGraph
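A minimal sketch of a _translate() override, assuming a hypothetical model that expects atom positions centred at the origin. The Batch dataclass is a stand-in for AtomsGraph; the sketch returns a new batch rather than modifying the input, which is a choice of this example.

```python
from dataclasses import dataclass, replace
from typing import Any, List, Optional


@dataclass
class Batch:
    # Hypothetical stand-in for agedi.data.AtomsGraph.
    positions: List[List[float]]
    numbers: List[int]
    representation: Optional[Any] = None


def translate(batch: Batch) -> Batch:
    # Sketch of a _translate() override: return a *new* batch whose positions
    # are centred at the origin, leaving the input batch untouched.
    n = len(batch.positions)
    centroid = [sum(p[i] for p in batch.positions) / n for i in range(3)]
    centred = [[p[i] - centroid[i] for i in range(3)] for p in batch.positions]
    return replace(batch, positions=centred)


original = Batch(positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], numbers=[1, 1])
translated = translate(original)
print(translated.positions)  # [[0.0, 0.0, -1.0], [0.0, 0.0, 1.0]]
```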

abstractmethod _get_representation(batch: agedi.data.AtomsGraph, out: Any) agedi.data.Representation

Get the representation of the batch of data.

Abstract method that must be implemented by all subclasses.

This method extracts the representation from the model output so that it can be added to the original batch of data.

Parameters:
  • batch (AtomsGraph) – The original batch of data.

  • out (Any) – The output of the model.

Returns:

The representation given by the model.

Return type:

Representation
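A minimal sketch of a _get_representation() override, assuming the backbone returns a dict with one feature row per atom under a hypothetical "node_features" key. Batch and Representation are stand-ins for the agedi classes, and the length check is an extra safeguard of this sketch.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Batch:
    # Hypothetical stand-in for agedi.data.AtomsGraph.
    numbers: List[int]


@dataclass
class Representation:
    # Hypothetical stand-in for agedi.data.Representation.
    node_features: List[List[float]]


def get_representation(batch: Batch, out: Dict[str, Any]) -> Representation:
    # Sketch of a _get_representation() override, assuming the backbone returns
    # a dict with one feature row per atom under the key "node_features".
    features = out["node_features"]
    if len(features) != len(batch.numbers):
        raise ValueError(f"expected {len(batch.numbers)} feature rows, got {len(features)}")
    return Representation(node_features=features)


rep = get_representation(Batch(numbers=[1, 8, 1]), {"node_features": [[0.1], [0.5], [0.1]]})
```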

abstractmethod _translate_representation(rep: agedi.data.Representation, translated_batch: Any) Any

Translate the representation of the batch of data.

Abstract method that must be implemented by all subclasses.

This method is used to inject the representation given by the model into the translated batch of data.

Parameters:
  • rep (Representation) – The representation given by the model.

  • translated_batch (Any) – The translated batch of data.

Returns:

translated_batch – The translated batch of data with the representation injected.

Return type:

Any
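A minimal sketch of a _translate_representation() override, assuming the translated batch is a dict and that the backbone reads precomputed features from a hypothetical "precomputed_features" key on its second pass. The input dict is copied so it is not modified.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Representation:
    # Hypothetical stand-in for agedi.data.Representation.
    node_features: List[List[float]]


def translate_representation(rep: Representation, translated_batch: Dict[str, Any]) -> Dict[str, Any]:
    # Sketch of a _translate_representation() override: copy the translated
    # input and attach the stored features under an assumed key that the
    # backbone would read on its second forward pass.
    injected = dict(translated_batch)
    injected["precomputed_features"] = rep.node_features
    return injected


translated = {"z": [1, 8, 1]}
injected = translate_representation(Representation([[0.1], [0.5], [0.1]]), translated)
print(sorted(injected))  # ['precomputed_features', 'z']
print("precomputed_features" in translated)  # False: the input dict is not modified
```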

__call__(batch: agedi.data.AtomsGraph) agedi.data.AtomsGraph

Translate a batch by calling the translator object as a function.

The batch is translated and, if batch.representation is set, the stored representation is injected via _translate_representation().

Parameters:

batch (AtomsGraph) – The batch of data to be translated.

Returns:

The translated batch of data.

Return type:

AtomsGraph
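A minimal sketch of the __call__ dispatch, assuming it translates first, injects the representation only when batch.representation is set, and applies input_modules last. That ordering is an assumption; TranslatorSketch, ShiftTranslator and Batch are hypothetical stand-ins that keep only two of the abstract methods.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Any, Callable, List, Optional


@dataclass
class Batch:
    # Hypothetical stand-in for agedi.data.AtomsGraph.
    x: List[float]
    representation: Optional[Any] = None


class TranslatorSketch(ABC):
    # Minimal stand-in for Translator; the dispatch in __call__ is an assumption.
    def __init__(self, input_modules: Optional[List[Callable]] = None):
        self.input_modules = input_modules or []

    @abstractmethod
    def _translate(self, batch: Batch) -> Batch: ...

    @abstractmethod
    def _translate_representation(self, rep: Any, translated_batch: Any) -> Any: ...

    def __call__(self, batch: Batch) -> Batch:
        translated = self._translate(batch)
        if batch.representation is not None:
            # Only inject the representation when one has been stored.
            translated = self._translate_representation(batch.representation, translated)
        for module in self.input_modules:
            translated = module(translated)
        return translated


class ShiftTranslator(TranslatorSketch):
    # Toy subclass: shifts x by +1 and stores the representation on the copy.
    def _translate(self, batch):
        return replace(batch, x=[v + 1 for v in batch.x], representation=None)

    def _translate_representation(self, rep, translated_batch):
        return replace(translated_batch, representation=rep)


t = ShiftTranslator()
print(t(Batch(x=[0.0])))                         # Batch(x=[1.0], representation=None)
print(t(Batch(x=[0.0], representation="feat")))  # Batch(x=[1.0], representation='feat')
```

Because the base class derives from abc.ABC, a subclass that omits any abstract method cannot be instantiated; Python raises TypeError at construction time.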

translate_input(batch: agedi.data.AtomsGraph) agedi.data.AtomsGraph

Translate the batch without injecting any stored representation.

Unlike __call__(), this method always skips the _translate_representation() step regardless of whether batch.representation is set. Use this for the first forward pass through the backbone (before the representation has been computed).

Parameters:

batch (AtomsGraph) – The batch of data to translate.

Returns:

The translated batch.

Return type:

AtomsGraph

translate_with_representation(batch: agedi.data.AtomsGraph) agedi.data.AtomsGraph

Translate the batch and inject the stored representation.

Like translate_input() but always calls _translate_representation() to inject the representation that was previously attached via add_representation(). Use this for the second forward pass through the backbone (after the representation has been computed and stored on batch).

Parameters:

batch (AtomsGraph) – The batch of data to translate. batch.representation must not be None when this method is called.

Returns:

The translated batch with the representation injected.

Return type:

AtomsGraph
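The two-pass flow described for translate_input() and translate_with_representation() can be sketched with a toy backbone. The dict layout, the "rep" key, the backbone function and raising ValueError when no representation is stored are all assumptions of this sketch; dataclasses.replace stands in for add_representation().

```python
from dataclasses import dataclass, replace
from typing import Any, List, Optional


@dataclass
class Batch:
    # Hypothetical stand-in for agedi.data.AtomsGraph.
    x: List[float]
    representation: Optional[Any] = None


def translate_input(batch: Batch) -> dict:
    # First pass: translate only; any stored representation is ignored.
    return {"x": list(batch.x)}


def translate_with_representation(batch: Batch) -> dict:
    # Second pass: translate and inject the stored representation.
    if batch.representation is None:
        raise ValueError("batch.representation must be set before the second pass")
    translated = translate_input(batch)
    translated["rep"] = batch.representation
    return translated


def backbone(inputs: dict) -> dict:
    # Toy backbone: features are x squared; reuses "rep" if it was injected.
    feats = inputs.get("rep") or [v * v for v in inputs["x"]]
    return {"features": feats, "reused": "rep" in inputs}


batch = Batch(x=[1.0, 2.0])
out1 = backbone(translate_input(batch))                  # pass 1: compute representation
batch = replace(batch, representation=out1["features"])  # stands in for add_representation()
out2 = backbone(translate_with_representation(batch))    # pass 2: representation injected
print(out1["reused"], out2["reused"], out2["features"])  # False True [1.0, 4.0]
```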

add_representation(batch: agedi.data.AtomsGraph, out: Any) agedi.data.AtomsGraph

Adds the representation given by the model to the original batch of data.

Parameters:
  • batch (AtomsGraph) – The original batch of data.

  • out (Any) – The output of the model.

Returns:

The original batch of data with the representation added.

Return type:

AtomsGraph
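A minimal sketch of add_representation(), assuming it delegates to _get_representation() and stores the result on batch.representation in place before returning the same batch. Batch and get_representation are hypothetical stand-ins, and the dict-shaped model output is assumed.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Batch:
    # Hypothetical stand-in for agedi.data.AtomsGraph.
    numbers: List[int]
    representation: Optional[Any] = None


def get_representation(batch: Batch, out: Dict[str, Any]) -> Any:
    # Stand-in for a subclass's _get_representation(); assumes a dict output.
    return out["node_features"]


def add_representation(batch: Batch, out: Dict[str, Any]) -> Batch:
    # Assumed behaviour: extract the representation and attach it to the
    # original batch in place, then return that same batch.
    batch.representation = get_representation(batch, out)
    return batch


batch = Batch(numbers=[1, 8, 1])
returned = add_representation(batch, {"node_features": [[0.1], [0.5], [0.1]]})
print(returned is batch, batch.representation)  # True [[0.1], [0.5], [0.1]]
```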

add_scores(batch: agedi.data.AtomsGraph, scores: Dict[str, torch.Tensor]) agedi.data.AtomsGraph

Adds the scores given by the model to the original batch of data.

Parameters:
  • batch (AtomsGraph) – The original batch of data.

  • scores (Dict[str, torch.Tensor]) – The scores predicted by the model, formatted as {head key: head predicted scores}.

Returns:

The original batch of data with the scores added.

Return type:

AtomsGraph

add_prediction(batch: agedi.data.AtomsGraph, targets: Dict[str, torch.Tensor], type: str | None = None) agedi.data.AtomsGraph

Adds the targets given by the model to the original batch of data.

Parameters:
  • batch (AtomsGraph) – The original batch of data.

  • targets (Dict[str, torch.Tensor]) – The targets predicted by the model, formatted as {head key: head predicted target}.

Returns:

The original batch of data with the predicted targets added.

Return type:

AtomsGraph