agedi.models.head

Classes

Head

Abstract base class for any score model heads.

Module Contents

class agedi.models.head.Head(score_clip: float | None = None, **kwargs)

Bases: abc.ABC, torch.nn.Module

Abstract base class for any score model heads.

The head takes a translated batch with a precalculated representation and returns a score tensor.

The score tensor should have the same shape as the tensor in the original batch that the head's key refers to.
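The shape contract above can be checked with a small helper like the one below. It is a sketch, not part of agedi: the name check_score_shape is hypothetical, and it assumes that the original batch exposes the noised quantity as an attribute named after the head's key ("pos", "x" or "cell"). The toy batch uses types.SimpleNamespace in place of a real batch object.

```python
from types import SimpleNamespace

import torch


def check_score_shape(score: torch.Tensor, batch, key: str) -> None:
    """Raise ValueError if `score` does not match the shape of `batch.<key>`.

    Hypothetical helper: assumes the original batch stores the noised
    quantity as an attribute named after the head's key.
    """
    target = getattr(batch, key)
    if score.shape != target.shape:
        raise ValueError(
            f"score for key {key!r} has shape {tuple(score.shape)}, "
            f"expected {tuple(target.shape)}"
        )


# Toy batch: 5 atoms with 3D positions and a single 3x3 cell.
batch = SimpleNamespace(pos=torch.randn(5, 3), cell=torch.randn(1, 3, 3))
check_score_shape(torch.zeros(5, 3), batch, "pos")  # shapes match, no error
```

A head for "cell" would be checked the same way against batch.cell; only the key changes.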

Return type:

Head

_key: str
_score_clip = None
property key: str

The key of the attribute to be noised and denoised.

get_hparams() Dict

Return hyperparameters sufficient to reconstruct this head.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus score_clip from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
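The subclassing pattern described for get_hparams() can be sketched in plain Python. HeadBase below is a hypothetical stand-in reduced to the documented get_hparams() contract (it is not the agedi class), and MLPPositionHead with its hidden_dim and num_layers parameters is invented for illustration. The subclass calls super().get_hparams() and merges in its own constructor parameters, so the resulting dictionary is enough to rebuild an equivalent head.

```python
from typing import Any, Dict, Optional


class HeadBase:
    """Hypothetical stand-in for Head, limited to the get_hparams contract."""

    def __init__(self, score_clip: Optional[float] = None, **kwargs: Any) -> None:
        self._score_clip = score_clip

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        # _target_ holds the fully-qualified class name; score_clip comes from the base.
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "score_clip": self._score_clip,
        }


class MLPPositionHead(HeadBase):
    """Hypothetical subclass with its own constructor parameters."""

    def __init__(self, hidden_dim: int = 64, num_layers: int = 2, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary, then merge this class's parameters.
        hparams = super().get_hparams()
        hparams.update(hidden_dim=self.hidden_dim, num_layers=self.num_layers)
        return hparams


head = MLPPositionHead(hidden_dim=128, score_clip=10.0)
hparams = head.get_hparams()
# Rebuild an equivalent head from every entry except _target_.
clone = MLPPositionHead(**{k: v for k, v in hparams.items() if k != "_target_"})
```

The _target_ key follows the same convention as Hydra's hydra.utils.instantiate, which resolves that key to a class and passes the remaining entries as keyword arguments; this page does not say whether agedi rebuilds heads that way.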

forward(translated_batch: Any) torch.Tensor

Forward pass of the head using a translated batch.

The output shape must match either the positions (pos), types (x) or cell (cell) of the original batch.

Parameters:

translated_batch (Any) – The translated batch to be used in the forward pass

Returns:

The output of the forward pass. The shape of the tensor depends on the key of the head.

Return type:

torch.Tensor
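One plausible way for forward() to relate to the abstract _score() and the score_clip constructor argument is sketched below. This is an assumption, not the agedi source: the sketch supposes that forward() calls _score() and, when score_clip is set, clamps every element of the score to [-score_clip, score_clip]. ClippedHeadSketch and the toy ConstantHead are hypothetical names. The bases mirror the documented ones (abc.ABC, torch.nn.Module).

```python
import abc
from types import SimpleNamespace
from typing import Any, Optional

import torch


class ClippedHeadSketch(abc.ABC, torch.nn.Module):
    """Hypothetical stand-in for Head; NOT the agedi implementation.

    Assumes forward() delegates to _score() and then applies element-wise
    clipping when score_clip is not None.
    """

    def __init__(self, score_clip: Optional[float] = None) -> None:
        super().__init__()
        self._score_clip = score_clip

    def forward(self, translated_batch: Any) -> torch.Tensor:
        score = self._score(translated_batch)
        if self._score_clip is not None:
            # Assumed behaviour: clamp each entry to [-score_clip, score_clip].
            score = score.clamp(min=-self._score_clip, max=self._score_clip)
        return score

    @abc.abstractmethod
    def _score(self, translated_batch: Any) -> torch.Tensor:
        ...


class ConstantHead(ClippedHeadSketch):
    """Toy head that returns a constant score shaped like batch.pos."""

    def _score(self, translated_batch: Any) -> torch.Tensor:
        return torch.full_like(translated_batch.pos, 5.0)


batch = SimpleNamespace(pos=torch.zeros(4, 3))
clipped = ConstantHead(score_clip=1.0)(batch)  # every entry clamped to 1.0
unclipped = ConstantHead()(batch)              # every entry stays 5.0
```

Keeping the clipping in the base class's forward() means subclasses only implement _score() and cannot forget to respect score_clip; if agedi clips by norm rather than per element, only the clamp line would change.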

abstractmethod _score(translated_batch: Any) torch.Tensor

Abstract method for the forward pass of the head.

Must be implemented by the subclass.

Parameters:

translated_batch (Any) – The translated batch to be used in the forward pass

Returns:

The output of the forward pass. The shape of the tensor depends on the key of the head.

Return type:

torch.Tensor
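A concrete _score() might map the precalculated representation to a per-atom score, as in the sketch below. Everything specific here is hypothetical: the class name LinearPositionHead, the feature_dim parameter, and the translated_batch.representation attribute (assumed to hold per-atom features of shape (num_atoms, feature_dim)). In agedi the class would subclass agedi.models.head.Head; a plain torch.nn.Module is used so the sketch runs on its own.

```python
from types import SimpleNamespace
from typing import Any

import torch


class LinearPositionHead(torch.nn.Module):
    """Hypothetical position head. In agedi this would subclass Head."""

    _key = "pos"  # score has the same shape as batch.pos

    def __init__(self, feature_dim: int) -> None:
        super().__init__()
        # One linear map from per-atom features to a 3D score vector.
        self.linear = torch.nn.Linear(feature_dim, 3)

    @property
    def key(self) -> str:
        return self._key

    def _score(self, translated_batch: Any) -> torch.Tensor:
        # Hypothetical attribute: (num_atoms, feature_dim) per-atom features.
        return self.linear(translated_batch.representation)


batch = SimpleNamespace(representation=torch.randn(6, 16), pos=torch.randn(6, 3))
head = LinearPositionHead(feature_dim=16)
score = head._score(batch)  # shape (6, 3), matching batch.pos
```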