agedi.models

Submodules

Classes

ScoreModel

Class that defines the score model.

TimeConditioning

Condition the model on the time t.

Conditioning

Conditioning Base Class

Package Contents

class agedi.models.ScoreModel(translator: agedi.models.translator.Translator, representation: agedi.data.Representation, conditionings: List[agedi.models.conditionings.Conditioning] | None = None, heads: List[agedi.models.head.Head] | None = None, w: float = -1.0, **kwargs)

Bases: lightning.LightningModule

Class that defines the score model.

It is a combination of a translator, a representation, a list of conditionings and a list of heads.

Parameters:
  • translator (Translator) – The translator that will be used to translate the input batch.

  • representation (Representation) – The representation that will be used to represent the translated batch.

  • conditionings (List[Conditioning]) – The list of conditionings that will be applied to the representation.

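  • heads (List[Head]) – The list of heads that will be used to compute scores.

  • w (float) – The guidance weight used for classifier-free guidance during sampling. Default is -1.0.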

translator
representation
conditionings
heads
w
guidance = True
get_hparams() Dict

Return hyperparameters sufficient to reconstruct this score model.

Collects hyperparameters from the translator, representation (via get_representation_hparams()), conditionings, and heads, as well as the guidance weight w.

Returns:

Hyperparameter dictionary with a _target_ key and nested translator, representation, conditionings, heads, and w entries.

Return type:

dict
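The _target_ key matches the convention used by Hydra's hydra.utils.instantiate. This section does not say whether agedi rebuilds models through Hydra or through its own helper, so the following is a minimal, dependency-free sketch of how such a dictionary can be turned back into objects. The helper name instantiate and the example target datetime.timedelta are illustrative assumptions, not part of agedi.

```python
import importlib
from typing import Any


def instantiate(config: Any) -> Any:
    """Recursively build objects from dicts carrying a ``_target_`` key.

    Hypothetical helper for illustration only; it follows the ``_target_``
    convention but is not agedi's (or Hydra's) implementation.
    """
    if isinstance(config, list):
        # Lists (e.g. conditionings, heads) are instantiated element-wise.
        return [instantiate(item) for item in config]
    if isinstance(config, dict):
        # Build nested entries first so children exist before the parent.
        kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
        target = config.get("_target_")
        if target is None:
            return kwargs  # plain dict without a target
        module_name, _, attr = target.rpartition(".")
        cls = getattr(importlib.import_module(module_name), attr)
        return cls(**kwargs)
    return config  # scalars (such as w) pass through unchanged


hparams = {"_target_": "datetime.timedelta", "days": 2, "hours": 3}
print(instantiate(hparams))  # 2 days, 3:00:00
```

Instantiating nested entries before their parent mirrors the nested translator, representation, conditionings, and heads entries described above.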

forward(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Forward pass of the model.

Dispatches to forward_sample() when the model is in sampling mode, and to forward_train() otherwise. This keeps self.sample as a compile-time constant for each compiled subgraph, avoiding retracing on mode changes and eliminating the Python-level branch from the compiled region.

Parameters:

batch (Batch) – The input batch that will be used to compute the scores.

Returns:

The output batch containing the scores.

Return type:

Batch
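The dispatch described above can be sketched with a plain Python class. ToyScoreModel and its string return values are hypothetical stand-ins for the real batch-processing methods. The point is that the mode check lives only in forward(), so neither per-mode method branches on self.sample.

```python
class ToyScoreModel:
    """Hypothetical stand-in illustrating the forward() dispatch pattern."""

    def __init__(self) -> None:
        self.sample = False  # start in training mode

    def forward(self, batch):
        # The only mode-dependent branch lives here, outside the two
        # per-mode methods, so neither of them checks self.sample.
        if self.sample:
            return self.forward_sample(batch)
        return self.forward_train(batch)

    def forward_train(self, batch):
        return f"train:{batch}"

    def forward_sample(self, batch):
        return f"sample:{batch}"


model = ToyScoreModel()
print(model.forward("b0"))  # train:b0
model.sample = True
print(model.forward("b0"))  # sample:b0
```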

forward_train(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Training-mode forward pass.

Computes the backbone representation, applies all conditionings unconditionally, translates the conditioned batch, and evaluates every score head.

Parameters:

batch (Batch) – The input batch.

Returns:

Batch with score tensors attached.

Return type:

Batch

forward_sample(batch: torch_geometric.data.Batch) torch_geometric.data.Batch

Sampling-mode forward pass.

Computes the backbone representation, applies classifier-free guidance (when self.guidance is True), translates the conditioned batch, and evaluates every score head with optional guidance mixing.

Parameters:

batch (Batch) – The input batch.

Returns:

Batch with score tensors attached.

Return type:

Batch
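forward_sample() mixes conditional and unconditional scores when guidance is enabled, but the exact mixing rule is not given here. A common classifier-free guidance convention (Ho & Salimans) is s = (1 + w) s_cond - w s_uncond, under which w = 0 returns the conditional score and w = -1 returns the unconditional one. The sketch below assumes that convention and uses NumPy arrays in place of score tensors. Whether agedi uses the same form, and how it interprets its default w = -1.0, is an assumption to check against the source.

```python
import numpy as np


def guided_score(s_cond: np.ndarray, s_uncond: np.ndarray, w: float) -> np.ndarray:
    """Classifier-free guidance mixing, assuming s = (1 + w) * s_cond - w * s_uncond.

    w = 0 gives the conditional score, w = -1 the unconditional one,
    and w > 0 pushes further along the conditional direction.
    """
    return (1.0 + w) * s_cond - w * s_uncond


s_cond = np.array([[1.0, 0.0], [0.0, 2.0]])    # per-node conditional scores
s_uncond = np.array([[0.5, 0.5], [0.0, 1.0]])  # per-node unconditional scores

print(guided_score(s_cond, s_uncond, w=0.0))   # equals s_cond
print(guided_score(s_cond, s_uncond, w=-1.0))  # equals s_uncond
print(guided_score(s_cond, s_uncond, w=1.0))   # 2 * s_cond - s_uncond
```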

sample_mode() None

Switch the model to sampling mode.

Sets self.sample = True and calls sample_mode() on all conditioning modules so that classifier-free guidance is applied during inference.

training_mode() None

Switch the model to training mode.

Sets self.sample = False and calls training_mode() on all conditioning modules so that conditioning is applied unconditionally during the forward pass.
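The two mode switches propagate a flag from the score model to every conditioning module. The following is a minimal sketch with hypothetical stand-in classes (ToyConditioning, ToyModel). Storing the mode as a sample attribute on each conditioning is an assumption, since only the model's self.sample is documented here.

```python
from typing import List


class ToyConditioning:
    """Hypothetical conditioning stand-in that records its mode."""

    def __init__(self) -> None:
        self.sample = False

    def sample_mode(self) -> None:
        self.sample = True

    def training_mode(self) -> None:
        self.sample = False


class ToyModel:
    """Hypothetical score-model stand-in that forwards mode changes."""

    def __init__(self, conditionings: List[ToyConditioning]) -> None:
        self.conditionings = conditionings
        self.sample = False

    def sample_mode(self) -> None:
        self.sample = True
        for conditioning in self.conditionings:
            conditioning.sample_mode()

    def training_mode(self) -> None:
        self.sample = False
        for conditioning in self.conditionings:
            conditioning.training_mode()


model = ToyModel([ToyConditioning(), ToyConditioning()])
model.sample_mode()
print(model.sample, [c.sample for c in model.conditionings])  # True [True, True]
model.training_mode()
print(model.sample, [c.sample for c in model.conditionings])  # False [False, False]
```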

class agedi.models.TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Condition the model on the time t.

Parameters:
  • input_dim (int) – The dimension of the input time tensor. Default is 1.

  • output_dim (int) – The dimension of the output conditioning. Default is 2.

omega
get_hparams() Dict

Return hyperparameters for this time conditioning module.

get_conditioning(t: torch.Tensor) torch.Tensor

Get the conditioning tensor for the time t.

\mathbf{c} = \begin{bmatrix} \sin(\omega t) \\ \cos(\omega t) \end{bmatrix}

Parameters:

t (torch.Tensor) – Time tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, 2).

Return type:

torch.Tensor
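A NumPy sketch of the sine/cosine time conditioning. The value of omega is an assumption (it is documented only as an attribute and may be fixed or learnable), and the per-node times are made up for illustration.

```python
import numpy as np


def time_conditioning(t: np.ndarray, omega: float) -> np.ndarray:
    """Map per-node times of shape (Nodes, 1) to [sin(omega t), cos(omega t)], shape (Nodes, 2)."""
    return np.concatenate([np.sin(omega * t), np.cos(omega * t)], axis=1)


t = np.array([[0.0], [0.25], [0.5]])  # three nodes, one time value each
c = time_conditioning(t, omega=np.pi)  # omega = pi is an illustrative choice
print(c.shape)  # (3, 2)
```

Each row lies on the unit circle, so the conditioning is bounded regardless of the range of t.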

get_empty_conditioning(n: int) torch.Tensor

Get an empty conditioning tensor.

Parameters:

n (int) – The number of nodes in the batch.

Returns:

Empty conditioning tensor of shape (1, 2).

Return type:

torch.Tensor

forward(batch: AtomsGraph, empty: bool = False) AtomsGraph

Forward method to get the conditioning from the input.

This implementation ignores both the training mode and the empty flag, so the time conditioning is always applied.

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – If True, return an empty conditioning tensor

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph

class agedi.models.Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

Bases: abc.ABC, lightning.LightningModule

Conditioning Base Class

Parameters:
  • property (str) – The property of the batch to condition on

  • input_dim (int) – The dimension of the input conditioning

  • output_dim (int) – The dimension of the output conditioning

  • concatenation_type (str) – The type of concatenation to use. Default is “scalar”

  • probability (float) – The probability of applying the conditioning. Default is 0.8. Only used in training mode.

Return type:

Conditioning

property
input_dim
output_dim
concatenation_type = 'scalar'
probability = 0.8
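The probability parameter controls how often the real conditioning is used during training. The sketch below assumes the usual classifier-free guidance recipe: each graph keeps its conditioning with probability `probability` and otherwise receives the empty conditioning. The zero-valued empty conditioning, the per-graph (rather than per-batch) draw, and the helper name drop_conditioning are assumptions for illustration.

```python
import numpy as np


def drop_conditioning(c: np.ndarray, empty: np.ndarray, probability: float,
                      rng: np.random.Generator) -> np.ndarray:
    """Keep each row of c with the given probability, else replace it by `empty`.

    c has shape (n_graphs, dim); empty has shape (1, dim) and broadcasts per row.
    """
    keep = rng.random(c.shape[0]) < probability  # one draw per graph
    return np.where(keep[:, None], c, empty)


rng = np.random.default_rng(0)
c = np.ones((10_000, 2))
empty = np.zeros((1, 2))  # hypothetical empty conditioning
out = drop_conditioning(c, empty, probability=0.8, rng=rng)
print(out[:, 0].mean())  # close to 0.8
```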
get_hparams() Dict

Return hyperparameters sufficient to reconstruct this conditioning module.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus property, probability, input_dim, and output_dim from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict

abstractmethod get_conditioning(x: torch.Tensor) torch.Tensor

Abstract method to get the conditioning from the input.

Must be implemented by the subclass.

Parameters:

x (torch.Tensor) – The input tensor

Returns:

The conditioning tensor

Return type:

torch.Tensor

abstractmethod get_empty_conditioning(n: int) torch.Tensor

Abstract method to get an empty conditioning tensor.

Must be implemented by the subclass.

Parameters:

n (int) – The number of nodes in the batch.

Returns:

The empty conditioning tensor.

Return type:

torch.Tensor
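The subclassing contract (implement both abstract methods, and extend get_hparams() by calling super().get_hparams() and merging in the subclass's own constructor parameters) can be sketched with NumPy stand-ins. BaseConditioning is a simplified, hypothetical mirror of the base class, and ScaledConditioning with its scale parameter is an invented example subclass; neither is part of agedi.

```python
from abc import ABC, abstractmethod
from typing import Dict

import numpy as np


class BaseConditioning(ABC):
    """Simplified, hypothetical stand-in for agedi.models.Conditioning."""

    def __init__(self, property: str, input_dim: int, output_dim: int,
                 probability: float = 0.8) -> None:
        self.property = property  # name mirrors the documented parameter
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.probability = probability

    def get_hparams(self) -> Dict:
        cls = type(self)
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "property": self.property,
            "probability": self.probability,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        }

    @abstractmethod
    def get_conditioning(self, x: np.ndarray) -> np.ndarray:
        """Map raw property values to a conditioning array."""

    @abstractmethod
    def get_empty_conditioning(self, n: int) -> np.ndarray:
        """Return the placeholder used when no conditioning is applied."""


class ScaledConditioning(BaseConditioning):
    """Hypothetical subclass that scales a scalar property by a constant."""

    def __init__(self, property: str, scale: float = 1.0,
                 input_dim: int = 1, output_dim: int = 1, **kwargs) -> None:
        super().__init__(property, input_dim, output_dim, **kwargs)
        self.scale = scale

    def get_hparams(self) -> Dict:
        # Start from the base-class entries, then merge in this class's own argument.
        hparams = super().get_hparams()
        hparams["scale"] = self.scale
        return hparams

    def get_conditioning(self, x: np.ndarray) -> np.ndarray:
        return (self.scale * x).reshape(-1, self.output_dim)

    def get_empty_conditioning(self, n: int) -> np.ndarray:
        # Zeros are an illustrative choice for the empty conditioning.
        return np.zeros((n, self.output_dim))


cond = ScaledConditioning("energy", scale=2.0)
print(cond.get_hparams()["scale"])  # 2.0
print(cond.get_conditioning(np.array([1.0, 3.0])).ravel())  # [2. 6.]
```

Accepting input_dim and output_dim as keyword arguments in the subclass keeps the merged dictionary usable for reconstruction, since every key it contains is a valid constructor argument.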

forward(batch: AtomsGraph, empty: bool = False) AtomsGraph

Forward method to get the conditioning from the input.

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – If True, return an empty conditioning tensor

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph

concatenate(batch: AtomsGraph, c: torch.Tensor) None

Concatenate the conditioning to the batch representation.

c must already be expanded to node-level (c.shape[0] == batch.pos.shape[0]) before this method is called. The pre-expansion is performed by forward() to keep this method free of dynamic shape comparisons, which would otherwise cause graph breaks under torch.compile.

Parameters:
  • batch (AtomsGraph) – The input batch

  • c (torch.Tensor) – Node-level conditioning tensor of shape (n_nodes, output_dim) or (n_nodes, output_dim, 1).

Return type:

None
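The node-level pre-expansion that forward() performs before calling concatenate() can be sketched with NumPy. In PyTorch Geometric, batch.batch maps each node to the index of its graph, so indexing a graph-level tensor with it yields one row per node. The example assumes a graph-level conditioning of shape (n_graphs, output_dim) and appends it to hypothetical node features along the feature axis. Which representation field agedi actually extends, and how concatenation_type changes this, is not shown here.

```python
import numpy as np

# Hypothetical batch: 2 graphs with 3 and 2 nodes respectively.
batch_index = np.array([0, 0, 0, 1, 1])  # node -> graph, like PyG's batch.batch
node_features = np.arange(10, dtype=float).reshape(5, 2)  # (n_nodes, n_features)
c_graph = np.array([[0.1, 0.2], [0.3, 0.4]])  # (n_graphs, output_dim)

# Pre-expansion done before concatenate(): one conditioning row per node.
c_node = c_graph[batch_index]  # (n_nodes, output_dim)
assert c_node.shape[0] == node_features.shape[0]

# Concatenate along the feature axis.
combined = np.concatenate([node_features, c_node], axis=1)
print(combined.shape)  # (5, 4)
```

Doing the expansion once, outside the concatenation step, keeps the shape check out of the concatenation logic itself, which matches the motivation given for concatenate() above.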

sample_mode() None

Set the conditioning module to sampling mode.

Return type:

None

training_mode() None

Set the conditioning module to training mode.

Return type:

None