agedi.models.head
Classes
Head | Abstract base class for any score model heads.
Module Contents
- class agedi.models.head.Head(score_clip: float | None = None, **kwargs)
Bases: abc.ABC, torch.nn.Module
Abstract base class for any score model heads.
The head takes the translated batch, including its precalculated representation, and returns a score tensor.
The score tensor should have the same shape as the tensor in the original batch that is selected by the head's key.
- _key: str
- _score_clip = None
- property key: str
The key of the attribute to be noised and denoised.
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this head.
Returns a dictionary with a _target_ key (the fully-qualified class name) plus score_clip from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
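The get_hparams convention above can be sketched in plain Python, so that it runs without agedi or torch installed. BaseHeadSketch, PositionsHeadSketch, its hidden_dim parameter and the build_from_hparams helper are hypothetical illustrations; only the dictionary layout (a _target_ class path plus score_clip) and the super().get_hparams() merge pattern come from the description above. Rebuilding the object with importlib is an assumption about how such a dictionary might be consumed, not agedi's documented loader.

```python
import importlib
from typing import Any, Dict, Optional


class BaseHeadSketch:
    """Hypothetical stand-in for the documented base class (no torch needed)."""

    def __init__(self, score_clip: Optional[float] = None, **kwargs: Any) -> None:
        self._score_clip = score_clip

    def get_hparams(self) -> Dict[str, Any]:
        # Fully-qualified name of the concrete class, plus the base-class parameter.
        cls = type(self)
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "score_clip": self._score_clip,
        }


class PositionsHeadSketch(BaseHeadSketch):
    """Hypothetical subclass with one extra constructor parameter."""

    def __init__(self, hidden_dim: int = 64, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dictionary and merge in this class's own parameters.
        hparams = super().get_hparams()
        hparams.update({"hidden_dim": self.hidden_dim})
        return hparams


def build_from_hparams(hparams: Dict[str, Any]) -> Any:
    """Hypothetical helper: import the class named by _target_ and call it with the rest."""
    params = dict(hparams)
    module_name, _, class_name = params.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**params)


head = PositionsHeadSketch(hidden_dim=128, score_clip=10.0)
hparams = head.get_hparams()
clone = build_from_hparams(hparams)
print(hparams)
```

Because the subclass merges into the base dictionary instead of rebuilding it, a parameter later added to the base class is carried along without touching the subclasses.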
- forward(translated_batch: Any) → torch.Tensor
Forward pass of the head using a translated batch.
The output shape must match either the positions (pos), the types (x), or the cell (cell) of the original batch.
- Parameters:
translated_batch (Any) – The translated batch to be used in the forward pass
- Returns:
The output of the forward pass. The shape of the tensor depends on the key of the head.
- Return type:
torch.Tensor
- abstractmethod _score(translated_batch: Any) → torch.Tensor
Abstract method for the forward pass of the head.
Must be implemented by the subclass.
- Parameters:
translated_batch (Any) – The translated batch to be used in the forward pass
- Returns:
The output of the forward pass. The shape of the tensor depends on the key of the head.
- Return type:
torch.Tensor
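The split between forward and _score can be sketched as follows. This is a self-contained stand-in, not agedi's implementation: HeadSketch, PositionsHeadSketch, the SimpleNamespace batch and its h attribute (a per-atom feature tensor) are hypothetical. That forward clamps the output of _score to [-score_clip, score_clip] is an assumption inferred from the parameter name; check the agedi source before relying on it.

```python
import abc
from types import SimpleNamespace
from typing import Any, Optional

import torch


class HeadSketch(abc.ABC, torch.nn.Module):
    """Hypothetical stand-in mirroring the documented Head interface."""

    _key: str

    def __init__(self, score_clip: Optional[float] = None, **kwargs: Any) -> None:
        super().__init__()
        self._score_clip = score_clip

    @property
    def key(self) -> str:
        return self._key

    def forward(self, translated_batch: Any) -> torch.Tensor:
        score = self._score(translated_batch)
        # Assumption: score_clip bounds every component to [-score_clip, score_clip].
        if self._score_clip is not None:
            score = score.clamp(-self._score_clip, self._score_clip)
        return score

    @abc.abstractmethod
    def _score(self, translated_batch: Any) -> torch.Tensor:
        ...


class PositionsHeadSketch(HeadSketch):
    """Hypothetical head that scores atomic positions with a linear readout."""

    _key = "pos"

    def __init__(self, feature_dim: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Maps each per-atom feature vector to a 3D score vector.
        self.readout = torch.nn.Linear(feature_dim, 3)

    def _score(self, translated_batch: Any) -> torch.Tensor:
        # batch.h: hypothetical representation of shape (num_atoms, feature_dim).
        return self.readout(translated_batch.h)


torch.manual_seed(0)
batch = SimpleNamespace(pos=torch.randn(5, 3), h=torch.randn(5, 16))
head = PositionsHeadSketch(feature_dim=16, score_clip=0.1)
score = head(batch)
# The score has the same shape as the batch attribute selected by head.key.
assert score.shape == getattr(batch, head.key).shape
```

With this split, each subclass only has to return an unclipped tensor of the right shape from _score, while the shared post-processing lives once in forward.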