agedi.models.score
Classes
ScoreModel | Class that defines the score model.
Module Contents
- class agedi.models.score.ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: List[agedi.models.conditionings.Conditioning] | None = None, heads: List[agedi.models.head.Head] | None = None, w: float = -1.0, **kwargs)
Bases: lightning.LightningModule
Class that defines the score model.
It is a combination of a translator, a representation, a list of conditionings and a list of heads.
- Parameters:
translator (Translator) – The translator that will be used to translate the input batch.
representation (Representation) – The representation that will be used to represent the translated batch.
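conditionings (List[Conditioning], optional) – The list of conditionings that will be applied to the representation. Defaults to None.
heads (List[Head], optional) – The list of heads that will be used to compute the scores. Defaults to None.
w (float, optional) – The guidance weight. Defaults to -1.0.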
- translator
- representation
- conditionings
- heads
- w
- guidance = True
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this score model.
Collects hyperparameters from the translator, the representation (via get_representation_hparams()), the conditionings, and the heads, as well as the guidance weight w.
- Returns:
Hyperparameter dictionary with a _target_ key and nested translator, representation, conditionings, heads, and w entries.
- Return type:
dict
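A nested dictionary with a _target_ key is enough to rebuild the whole model recursively. The sketch below illustrates that round trip with hypothetical stand-in classes (ToyTranslator, ToyRepresentation, ToyHead, ToyScoreModel), not the real agedi components. Two assumptions are worth flagging: the real class collects the representation entry via get_representation_hparams(), whereas the stand-in simply exposes its own get_hparams(); and a plain name registry replaces the dotted-path import that Hydra-style instantiate() would perform.

```python
from typing import Any, Dict


# Hypothetical stand-ins; the real agedi classes have their own constructors.
class ToyTranslator:
    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def get_hparams(self) -> Dict[str, Any]:
        return {"_target_": "ToyTranslator", "scale": self.scale}


class ToyRepresentation:
    def __init__(self, cutoff: float = 5.0):
        self.cutoff = cutoff

    def get_hparams(self) -> Dict[str, Any]:
        return {"_target_": "ToyRepresentation", "cutoff": self.cutoff}


class ToyHead:
    def __init__(self, key: str = "score"):
        self.key = key

    def get_hparams(self) -> Dict[str, Any]:
        return {"_target_": "ToyHead", "key": self.key}


class ToyScoreModel:
    def __init__(self, translator, representation, conditionings=None, heads=None, w: float = -1.0):
        self.translator = translator
        self.representation = representation
        self.conditionings = conditionings or []  # None -> empty list
        self.heads = heads or []
        self.w = w

    def get_hparams(self) -> Dict[str, Any]:
        # Same layout as documented: a _target_ key plus nested entries.
        return {
            "_target_": "ToyScoreModel",
            "translator": self.translator.get_hparams(),
            "representation": self.representation.get_hparams(),
            "conditionings": [c.get_hparams() for c in self.conditionings],
            "heads": [h.get_hparams() for h in self.heads],
            "w": self.w,
        }


# Name registry standing in for an import-by-dotted-path lookup (assumption).
REGISTRY = {cls.__name__: cls for cls in (ToyTranslator, ToyRepresentation, ToyHead, ToyScoreModel)}


def instantiate(cfg: Any) -> Any:
    """Recursively rebuild objects from nested dicts that carry a _target_ key."""
    if isinstance(cfg, list):
        return [instantiate(item) for item in cfg]
    if isinstance(cfg, dict) and "_target_" in cfg:
        cls = REGISTRY[cfg["_target_"]]
        kwargs = {k: instantiate(v) for k, v in cfg.items() if k != "_target_"}
        return cls(**kwargs)
    return cfg  # plain values such as w pass through unchanged
```

Calling instantiate(model.get_hparams()) on a ToyScoreModel yields an equivalent model whose own get_hparams() returns the same dictionary, which is the property "sufficient to reconstruct" implies.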
- forward(batch: torch_geometric.data.Batch) → torch_geometric.data.Batch
Forward pass of the model.
Dispatches to forward_sample() when the model is in sampling mode, and to forward_train() otherwise. This keeps self.sample a compile-time constant within each compiled subgraph, which avoids retracing when the mode changes and removes the Python-level branch from the compiled region.
- Parameters:
batch (Batch) – The input batch that will be used to compute the scores.
- Returns:
The output batch containing the scores.
- Return type:
Batch
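The dispatch pattern can be sketched in plain Python. DispatchingModel is a hypothetical toy, a plain dict stands in for torch_geometric.data.Batch, and the default of sample = False is an assumption. The point is that the only mode check sits in forward(), so each mode-specific method is branch-free; in a real PyTorch module those two methods would be the pieces handed to a compiler such as torch.compile, which this sketch omits.

```python
from typing import Any, Dict

Batch = Dict[str, Any]  # plain dict standing in for torch_geometric Batch


class DispatchingModel:
    """Toy model whose forward() only chooses between two mode-specific paths."""

    def __init__(self) -> None:
        self.sample = False  # assumed default: training mode

    def forward(self, batch: Batch) -> Batch:
        # The single Python-level branch lives here; neither path below
        # inspects self.sample, so each can be traced/compiled on its own.
        if self.sample:
            return self.forward_sample(batch)
        return self.forward_train(batch)

    def forward_train(self, batch: Batch) -> Batch:
        return {**batch, "path": "train"}

    def forward_sample(self, batch: Batch) -> Batch:
        return {**batch, "path": "sample"}
```

Flipping model.sample changes which method forward() calls without touching the code inside either method.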
- forward_train(batch: torch_geometric.data.Batch) → torch_geometric.data.Batch
Training-mode forward pass.
Computes the backbone representation, applies all conditionings unconditionally, translates the conditioned batch, and evaluates every score head.
- Parameters:
batch (Batch) – The input batch.
- Returns:
Batch with score tensors attached.
- Return type:
Batch
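The stage order described for forward_train() can be sketched as a plain function pipeline. Everything here is hypothetical: batches are dicts of Python lists, and each stage (toy_representation, add_label, toy_translator, pos_head) is assumed to take a batch and return an updated one. The real components operate on torch_geometric batches with their own interfaces.

```python
from typing import Any, Callable, Dict, List

Batch = Dict[str, Any]
Stage = Callable[[Batch], Batch]


def forward_train(batch: Batch, representation: Stage, conditionings: List[Stage],
                  translator: Stage, heads: List[Stage]) -> Batch:
    # 1. Backbone representation of the input batch.
    batch = representation(batch)
    # 2. Every conditioning is applied unconditionally (no guidance logic).
    for conditioning in conditionings:
        batch = conditioning(batch)
    # 3. Translate the conditioned batch.
    batch = translator(batch)
    # 4. Each head attaches its score to the batch.
    for head in heads:
        batch = head(batch)
    return batch


# Hypothetical toy stages for illustration only.
def toy_representation(b: Batch) -> Batch:
    return {**b, "h": [2.0 * v for v in b["x"]]}


def add_label(b: Batch) -> Batch:
    return {**b, "h": [v + b["label"] for v in b["h"]]}


def toy_translator(b: Batch) -> Batch:
    return {**b, "t": list(b["h"])}


def pos_head(b: Batch) -> Batch:
    return {**b, "pos_score": [-v for v in b["t"]]}
```

With x = [1.0, 2.0] and label = 0.5, the representation gives [2.0, 4.0], the conditioning shifts it to [2.5, 4.5], and the head returns pos_score = [-2.5, -4.5], while the original x entry is left untouched.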
- forward_sample(batch: torch_geometric.data.Batch) → torch_geometric.data.Batch
Sampling-mode forward pass.
Computes the backbone representation, applies classifier-free guidance (when self.guidance is True), translates the conditioned batch, and evaluates every score head with optional guidance mixing.
- Parameters:
batch (Batch) – The input batch.
- Returns:
Batch with score tensors attached.
- Return type:
Batch
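Guidance mixing combines a conditional and an unconditional score. The sketch below assumes the common classifier-free guidance form score = (1 + w) · s_cond − w · s_uncond; other conventions, such as s_uncond + w · (s_cond − s_uncond), also exist. This page does not state which form agedi uses or how its default w = -1.0 is interpreted. toy_head is a hypothetical score head in which the label shifts the score, and passing None plays the role of the dropped condition.

```python
from typing import List, Optional, Sequence


def guided_score(s_cond: Sequence[float], s_uncond: Sequence[float], w: float) -> List[float]:
    """Assumed CFG mix: (1 + w) * s_cond - w * s_uncond, elementwise."""
    if len(s_cond) != len(s_uncond):
        raise ValueError("conditional and unconditional scores must align")
    return [(1.0 + w) * c - w * u for c, u in zip(s_cond, s_uncond)]


def toy_head(x: float, label: Optional[float]) -> float:
    # Hypothetical head: the label shifts the score when it is present.
    return -x + (label if label is not None else 0.0)


def forward_sample(xs: Sequence[float], label: float, w: float, guidance: bool = True) -> List[float]:
    s_cond = [toy_head(x, label) for x in xs]
    if not guidance:
        return s_cond  # no mixing: purely conditional score
    s_uncond = [toy_head(x, None) for x in xs]  # condition dropped
    return guided_score(s_cond, s_uncond, w)
```

Under this convention, w = 0 returns the conditional score, w = -1 returns the unconditional score, and w > 0 extrapolates beyond the conditional score. For xs = [1.0, 2.0], label = 1.0 and w = 2.0, the result is [2.0, 1.0].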
- sample_mode() → None
Switch the model to sampling mode.
Sets self.sample = True and calls sample_mode() on all conditioning modules so that classifier-free guidance is applied during inference.
- training_mode() → None
Switch the model to training mode.
Sets self.sample = False and calls training_mode() on all conditioning modules so that conditioning is applied unconditionally during the forward pass.
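Both mode switches follow the same propagation pattern: set the model's own flag, then forward the call to every conditioning. The sketch below uses hypothetical ToyConditioning and ToyScoreModel classes, and the assumption that each conditioning tracks its mode with its own sample attribute; the real conditioning modules may store their state differently.

```python
from typing import List


class ToyConditioning:
    """Hypothetical conditioning that records whether it is in sampling mode."""

    def __init__(self) -> None:
        self.sample = False

    def sample_mode(self) -> None:
        self.sample = True

    def training_mode(self) -> None:
        self.sample = False


class ToyScoreModel:
    def __init__(self, conditionings: List[ToyConditioning]) -> None:
        self.conditionings = conditionings
        self.sample = False

    def sample_mode(self) -> None:
        # Set the model flag, then propagate to every conditioning.
        self.sample = True
        for conditioning in self.conditionings:
            conditioning.sample_mode()

    def training_mode(self) -> None:
        self.sample = False
        for conditioning in self.conditionings:
            conditioning.training_mode()
```

Because the flag lives on both the model and its conditionings, a single call keeps them consistent; in this sketch the flag is independent of PyTorch's own train()/eval() switch.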