agedi.models.score

Classes

ScoreModel

Class that defines the score model.

Module Contents

class agedi.models.score.ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: List[agedi.models.conditionings.Conditioning] | None = None, heads: List[agedi.models.head.Head] | None = None, w: float = -1.0, **kwargs)

Bases: lightning.LightningModule

Class that defines the score model.

It is a combination of a translator, a representation, a list of conditionings and a list of heads.

Parameters:
  • translator (Translator) – The translator that will be used to translate the input batch.

  • representation (Representation) – The representation that will be used to represent the translated batch.

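  • conditionings (List[Conditioning], optional) – The list of conditionings that will be applied to the representation. Defaults to None.

  • heads (List[Head], optional) – The list of heads that will be used to compute scores. Defaults to None.

  • w (float) – The classifier-free guidance weight used when sampling. Defaults to -1.0.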

translator
representation
conditionings
heads
w
guidance = True
get_hparams() Dict

Return hyperparameters sufficient to reconstruct this score model.

Collects hyperparameters from the translator, representation (via get_representation_hparams()), conditionings, and heads, as well as the guidance weight w.

Returns:

Hyperparameter dictionary with a _target_ key and nested translator, representation, conditionings, heads, and w entries.

Return type:

dict
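The _target_ key follows the convention used by configuration tools such as Hydra, where a dotted import path names the class to build and the remaining keys become constructor arguments. Whether agedi itself reconstructs models this way is an assumption. The sketch below uses a hypothetical, stdlib-only instantiate helper (not part of agedi) and stdlib stand-in classes to show how a nested dictionary shaped like the documented output could be turned back into objects. The cutoff and key fields are invented for illustration.

```python
import importlib
from typing import Any


def instantiate(config: Any) -> Any:
    """Hypothetical helper: recursively build objects from dicts with a _target_ key.

    Dicts carrying "_target_" are resolved to a callable via its dotted import
    path and called with the remaining entries as keyword arguments. Lists and
    plain dicts are processed element-wise; other values pass through.
    """
    if isinstance(config, list):
        return [instantiate(item) for item in config]
    if isinstance(config, dict):
        kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
        if "_target_" not in config:
            return kwargs
        module_path, _, attr = config["_target_"].rpartition(".")
        target = getattr(importlib.import_module(module_path), attr)
        return target(**kwargs)
    return config


# Toy hparams shaped like the documented return value. types.SimpleNamespace
# stands in for the real translator, representation, and head classes.
hparams = {
    "_target_": "types.SimpleNamespace",
    "translator": {"_target_": "types.SimpleNamespace", "name": "toy"},
    "representation": {"_target_": "types.SimpleNamespace", "cutoff": 5.0},
    "conditionings": [],
    "heads": [{"_target_": "types.SimpleNamespace", "key": "pos"}],
    "w": -1.0,
}

model = instantiate(hparams)
print(model.representation.cutoff, model.heads[0].key, model.w)
```

Nested entries are built before their parent, so the outer constructor receives fully constructed sub-objects, which is what makes a single dictionary sufficient to rebuild the whole model.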

forward(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Forward pass of the model.

Dispatches to forward_sample() when the model is in sampling mode, and to forward_train() otherwise. This keeps self.sample a compile-time constant within each compiled subgraph, which avoids retracing when the mode changes and removes the Python-level branch from the compiled region.

Parameters:

batch (Batch) – The input batch that will be used to compute the scores.

Returns:

The output batch containing the scores.

Return type:

Batch
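A minimal pure-Python sketch of the dispatch described above, with plain dicts standing in for torch_geometric batches. Only the method names and the self.sample flag come from this page; the class and the method bodies are hypothetical.

```python
class DispatchingModel:
    """Toy stand-in for the documented forward() dispatch (not agedi code)."""

    def __init__(self) -> None:
        self.sample = False  # mirrors the documented self.sample flag

    def forward(self, batch: dict) -> dict:
        # The mode check lives here, so neither per-mode method below
        # contains a branch on self.sample.
        if self.sample:
            return self.forward_sample(batch)
        return self.forward_train(batch)

    def forward_train(self, batch: dict) -> dict:
        return {**batch, "mode": "train"}

    def forward_sample(self, batch: dict) -> dict:
        return {**batch, "mode": "sample"}


model = DispatchingModel()
print(model.forward({"x": 1})["mode"])  # train
model.sample = True
print(model.forward({"x": 1})["mode"])  # sample
```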

forward_train(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Training-mode forward pass.

Computes the backbone representation, applies all conditionings unconditionally, translates the conditioned batch, and evaluates every score head.

Parameters:

batch (Batch) – The input batch.

Returns:

Batch with score tensors attached.

Return type:

Batch
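The stage order described above can be sketched as a chain of batch-to-batch functions. This is a hypothetical illustration using dicts of floats and lambdas in place of the real representation, conditioning, translator, and head modules; it shows only the documented ordering, not agedi's actual tensor operations.

```python
from typing import Callable, Dict, List

Batch = Dict[str, float]
Stage = Callable[[Batch], Batch]


def forward_train(
    batch: Batch,
    representation: Stage,
    conditionings: List[Stage],
    translator: Stage,
    heads: List[Stage],
) -> Batch:
    """Hypothetical sketch of the documented training-mode stage order."""
    batch = representation(batch)          # 1. backbone representation
    for conditioning in conditionings:     # 2. every conditioning, always applied
        batch = conditioning(batch)
    batch = translator(batch)              # 3. translate the conditioned batch
    for head in heads:                     # 4. each head attaches its score
        batch = head(batch)
    return batch


out = forward_train(
    {"x": 1.0},
    representation=lambda b: {**b, "h": b["x"] * 2},
    conditionings=[lambda b: {**b, "h": b["h"] + 10}],
    translator=lambda b: {**b, "translated": 1.0},
    heads=[lambda b: {**b, "score": -b["h"]}],
)
print(out["score"])  # -12.0
```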

forward_sample(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Sampling-mode forward pass.

Computes the backbone representation, applies classifier-free guidance (when self.guidance is True), translates the conditioned batch, and evaluates every score head with optional guidance mixing.

Parameters:

batch (Batch) – The input batch.

Returns:

Batch with score tensors attached.

Return type:

Batch
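This page does not give the formula used for guidance mixing. As an assumption, the sketch below uses one common classifier-free guidance convention, (1 + w) * cond - w * uncond, applied element-wise to plain lists standing in for score tensors. How agedi obtains the conditional and unconditional scores, and which convention it applies to w, should be checked against the source.

```python
from typing import List


def guided_score(cond: List[float], uncond: List[float], w: float) -> List[float]:
    """Mix conditional and unconditional scores with guidance weight w.

    Assumed convention: (1 + w) * cond - w * uncond.
    w = 0 returns the conditional score unchanged; larger w extrapolates
    further away from the unconditional score.
    """
    return [(1.0 + w) * c - w * u for c, u in zip(cond, uncond)]


print(guided_score([2.0, 4.0], [1.0, 1.0], w=1.0))  # [3.0, 7.0]
```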

sample_mode() None

Switch the model to sampling mode.

Sets self.sample = True and calls sample_mode() on all conditioning modules so that classifier-free guidance is applied during inference.

training_mode() None

Switch the model to training mode.

Sets self.sample = False and calls training_mode() on all conditioning modules so that conditioning is applied unconditionally during the forward pass.
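A hypothetical sketch of how the two mode switches fan out from the model to its conditionings. ToyConditioning and ToyScoreModel are invented stand-ins; in particular, whether real conditioning modules track their mode in a sample attribute is an assumption.

```python
class ToyConditioning:
    """Hypothetical conditioning exposing the documented mode-switch methods."""

    def __init__(self) -> None:
        self.sample = False  # assumed per-conditioning mode flag

    def sample_mode(self) -> None:
        self.sample = True

    def training_mode(self) -> None:
        self.sample = False


class ToyScoreModel:
    """Stand-in showing the model-level switch propagating to conditionings."""

    def __init__(self, conditionings: list) -> None:
        self.conditionings = conditionings
        self.sample = False

    def sample_mode(self) -> None:
        self.sample = True
        for conditioning in self.conditionings:
            conditioning.sample_mode()

    def training_mode(self) -> None:
        self.sample = False
        for conditioning in self.conditionings:
            conditioning.training_mode()


model = ToyScoreModel([ToyConditioning(), ToyConditioning()])
model.sample_mode()
print(model.sample, [c.sample for c in model.conditionings])  # True [True, True]
model.training_mode()
print(model.sample, [c.sample for c in model.conditionings])  # False [False, False]
```

These switches are separate from PyTorch's Module.train() and Module.eval(); this page does not say whether calling one also triggers the other.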