agedi.models¶
Submodules¶
Classes¶
ScoreModel | Class that defines the score model.
TimeConditioning | Condition the model on the time t.
Conditioning | Conditioning Base Class
Package Contents¶
- class agedi.models.ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: List[agedi.models.conditionings.Conditioning] | None = None, heads: List[agedi.models.head.Head] | None = None, w: float = -1.0, **kwargs)¶
Bases: lightning.LightningModule
Class that defines the score model.
It is a combination of a translator, a representation, a list of conditionings, and a list of heads.
- Parameters:
translator (Translator) – The translator that will be used to translate the input batch.
representation (Representation) – The representation that will be used to represent the translated batch.
conditionings (List[Conditioning] | None) – The list of conditionings that will be applied to the representation. Default is None.
heads (List[Head] | None) – The list of heads that will be used to compute scores. Default is None.
w (float) – The guidance weight. Default is -1.0.
- translator¶
- representation¶
- conditionings¶
- heads¶
- w¶
- guidance = True¶
- get_hparams() Dict¶
Return hyperparameters sufficient to reconstruct this score model.
Collects hyperparameters from the translator, representation (via get_representation_hparams()), conditionings, and heads, as well as the guidance weight w.
- Returns:
Hyperparameter dictionary with a _target_ key and nested translator, representation, conditionings, heads, and w entries.
- Return type:
dict
- forward(batch: torch_geometric.data.Batch) torch_geometric.data.Batch¶
Forward pass of the model.
Dispatches to forward_sample() when the model is in sampling mode, and to forward_train() otherwise. This keeps self.sample as a compile-time constant for each compiled subgraph, avoiding retracing on mode changes and eliminating the Python-level branch from the compiled region.
- Parameters:
batch (Batch) – The input batch that will be used to compute the scores.
- Returns:
The output batch containing the scores.
- Return type:
Batch
- forward_train(batch: torch_geometric.data.Batch) torch_geometric.data.Batch¶
Training-mode forward pass.
Computes the backbone representation, applies all conditionings unconditionally, translates the conditioned batch, and evaluates every score head.
- Parameters:
batch (Batch) – The input batch.
- Returns:
Batch with score tensors attached.
- Return type:
Batch
- forward_sample(batch: torch_geometric.data.Batch) torch_geometric.data.Batch¶
Sampling-mode forward pass.
Computes the backbone representation, applies classifier-free guidance (when self.guidance is True), translates the conditioned batch, and evaluates every score head with optional guidance mixing.
- Parameters:
batch (Batch) – The input batch.
- Returns:
Batch with score tensors attached.
- Return type:
Batch
- sample_mode() None¶
Switch the model to sampling mode.
Sets self.sample = True and calls sample_mode() on all conditioning modules so that classifier-free guidance is applied during inference.
- training_mode() None¶
Switch the model to training mode.
Sets self.sample = False and calls training_mode() on all conditioning modules so that conditioning is applied unconditionally during the forward pass.
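The mode switching and dispatch described for forward(), sample_mode(), and training_mode() can be illustrated with a minimal, dependency-free sketch. ToyScoreModel and ToyConditioning are hypothetical stand-ins, not the agedi classes, and the dictionary "batch" is an assumption made only to keep the example self-contained.

```python
class ToyConditioning:
    """Hypothetical stand-in for a conditioning module."""

    def __init__(self):
        self.sample = False

    def sample_mode(self):
        self.sample = True

    def training_mode(self):
        self.sample = False


class ToyScoreModel:
    """Hypothetical stand-in; not the agedi ScoreModel implementation."""

    def __init__(self, conditionings):
        self.conditionings = conditionings
        self.sample = False  # assumption: start in training mode

    def sample_mode(self):
        # Flip the model flag and propagate to every conditioning module.
        self.sample = True
        for c in self.conditionings:
            c.sample_mode()

    def training_mode(self):
        self.sample = False
        for c in self.conditionings:
            c.training_mode()

    def forward_train(self, batch):
        return {**batch, "path": "train"}

    def forward_sample(self, batch):
        return {**batch, "path": "sample"}

    def forward(self, batch):
        # Single dispatch point; each branch is a separate code path.
        if self.sample:
            return self.forward_sample(batch)
        return self.forward_train(batch)


model = ToyScoreModel([ToyConditioning(), ToyConditioning()])
train_out = model.forward({"x": 1})
model.sample_mode()
sample_out = model.forward({"x": 1})
```

Keeping the branch in one small method means the two forward paths never mix inside a single function body, which is the property the forward() description relies on.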
- class agedi.models.TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)¶
Bases: agedi.models.conditionings.base.Conditioning
Condition the model on the time t.
- Parameters:
input_dim (int) – The dimension of the input conditioning. Default is 1.
output_dim (int) – The dimension of the output conditioning. Default is 2.
- omega¶
- get_hparams() Dict¶
Return hyperparameters for this time conditioning module.
- get_conditioning(t: torch.Tensor) torch.Tensor¶
Get the conditioning tensor for the time t.
\begin{align*} \mathbf{c} = \begin{bmatrix} \sin(\omega t) & \cos(\omega t) \end{bmatrix} \end{align*}
- Parameters:
t (torch.Tensor) – Time tensor of shape (Nodes, 1).
- Returns:
Conditioning tensor of shape (Nodes, 2).
- Return type:
torch.Tensor
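As a rough illustration of the mapping from a time column of shape (Nodes, 1) to a conditioning of shape (Nodes, 2), the following sketch uses plain Python lists instead of torch tensors. The function name time_conditioning and the value of omega are assumptions for illustration; the source does not state how omega is set.

```python
import math


def time_conditioning(t, omega=1.0):
    """Map each row [t_i] to [sin(omega * t_i), cos(omega * t_i)].

    `t` is a list of one-element rows, mimicking shape (Nodes, 1).
    `omega` is a hypothetical frequency; its real value is not given here.
    """
    return [[math.sin(omega * row[0]), math.cos(omega * row[0])] for row in t]


t = [[0.0], [0.5], [1.0]]  # three nodes, one time value each
c = time_conditioning(t, omega=math.pi)
```

Each output row lies on the unit circle, so the conditioning has a fixed norm regardless of t.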
- get_empty_conditioning(n: int) torch.Tensor¶
Get an empty conditioning tensor.
- Parameters:
n (int) – The number of nodes in the batch.
- Returns:
Empty conditioning tensor of shape (1, 2).
- Return type:
torch.Tensor
- forward(batch: AtomsGraph, empty: bool = False) AtomsGraph¶
Forward method to get the conditioning from the input.
This ignores the training and empty flags.
- Parameters:
batch (AtomsGraph) – The input batch
empty (bool) – If True, return an empty conditioning tensor
- Returns:
The batch with the conditioning added to the representation
- Return type:
AtomsGraph
- class agedi.models.Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)¶
Bases: abc.ABC, lightning.LightningModule
Conditioning Base Class
- Parameters:
property (str) – The property of the batch to condition on
input_dim (int) – The dimension of the input conditioning
output_dim (int) – The dimension of the output conditioning
concatenation_type (str) – The type of concatenation to use. Default is “scalar”
probability (float) – The probability of conditioning. Default is 0.8. Only used in training mode
- property¶
- input_dim¶
- output_dim¶
- concatenation_type = 'scalar'¶
- probability = 0.8¶
- get_hparams() Dict¶
Return hyperparameters sufficient to reconstruct this conditioning module.
Returns a dictionary with a _target_ key (the fully-qualified class name) plus property, probability, input_dim, and output_dim from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.
- Returns:
Hyperparameter dictionary.
- Return type:
dict
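The subclass pattern described for get_hparams() (call the parent, then merge extra constructor parameters) might look like the sketch below. ToyConditioning and ToyScaledConditioning are hypothetical stand-ins written only for this example, and the scale parameter is invented for illustration; none of them are part of agedi.

```python
class ToyConditioning:
    """Hypothetical base class mirroring the documented hyperparameter keys."""

    def __init__(self, property, input_dim, output_dim, probability=0.8):
        self.property = property
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.probability = probability

    def get_hparams(self):
        cls = type(self)
        return {
            # Fully-qualified class name, used to re-instantiate the object.
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "property": self.property,
            "probability": self.probability,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        }


class ToyScaledConditioning(ToyConditioning):
    """Hypothetical subclass with one extra constructor parameter."""

    def __init__(self, scale=2.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def get_hparams(self):
        hparams = super().get_hparams()  # base-class entries first
        hparams.update({"scale": self.scale})  # then subclass parameters
        return hparams


cond = ToyScaledConditioning(property="energy", input_dim=1, output_dim=4)
hp = cond.get_hparams()
```

Because _target_ is derived from type(self), the subclass reports its own class name without overriding that entry.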
- abstractmethod get_conditioning(x: torch.Tensor) torch.Tensor¶
Abstract method to get the conditioning from the input.
Must be implemented by the subclass.
- Parameters:
x (torch.Tensor) – The input tensor
- Returns:
The conditioning tensor
- Return type:
torch.Tensor
- abstractmethod get_empty_conditioning(n: int) torch.Tensor¶
Abstract method to get an empty conditioning tensor.
Must be implemented by the subclass.
- Parameters:
n (int) – The number of nodes in the batch
- Returns:
The empty conditioning tensor
- Return type:
torch.Tensor
- forward(batch: AtomsGraph, empty: bool = False) AtomsGraph¶
Forward method to get the conditioning from the input
- Parameters:
batch (AtomsGraph) – The input batch
empty (bool) – If True, return an empty conditioning tensor
- Returns:
The batch with the conditioning added to the representation
- Return type:
AtomsGraph
- concatenate(batch: AtomsGraph, c: torch.Tensor) None¶
Concatenate the conditioning to the batch.
c must already be expanded to node-level (c.shape[0] == batch.pos.shape[0]) before this method is called. The pre-expansion is performed by forward() to keep this method free of dynamic shape comparisons, which would otherwise cause graph breaks under torch.compile.
- Parameters:
batch (AtomsGraph) – The input batch
c (torch.Tensor) – Node-level conditioning tensor of shape (n_nodes, output_dim) or (n_nodes, output_dim, 1).
- Return type:
None
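The source does not show how forward() performs the node-level pre-expansion. One plausible approach, sketched here with plain lists, is to index a per-graph conditioning with a node-to-graph assignment vector. The helper expand_to_nodes and the batch_index vector are assumptions for illustration, not part of the documented API.

```python
def expand_to_nodes(c_graph, batch_index):
    """Repeat each graph's conditioning row once per node in that graph.

    c_graph: list of rows, one per graph, each of length output_dim.
    batch_index: list mapping every node to its graph index (assumed input).
    """
    return [list(c_graph[g]) for g in batch_index]


c_graph = [[0.1, 0.2], [0.3, 0.4]]  # two graphs, output_dim = 2
batch_index = [0, 0, 0, 1, 1]       # three nodes in graph 0, two in graph 1
c_node = expand_to_nodes(c_graph, batch_index)
```

After this step the first dimension of the conditioning equals the number of nodes, so the concatenation itself needs no shape-dependent branching.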
- sample_mode() None¶
Set the model to sample mode
- Return type:
None
- training_mode() None¶
Set the model to train mode
- Return type:
None