agedi.models.conditionings

Submodules

Classes

Conditioning

Conditioning Base Class

TimeConditioning

Condition the model on the time t.

ScalarConditioning

Conditioning module for continuous scalar properties.

IntegerConditioning

Conditioning module for integer-valued properties.

Package Contents

class agedi.models.conditionings.Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

Bases: abc.ABC, lightning.LightningModule

Conditioning Base Class

Parameters:
  • property (str) – The property of the batch to condition on

  • input_dim (int) – The dimension of the input conditioning

  • output_dim (int) – The dimension of the output conditioning

  • concatenation_type (str) – The type of concatenation to use. Default is “scalar”

  • probability (float) – The probability of applying the conditioning. Default is 0.8. Only used in training mode

Return type:

Conditioning

property
input_dim
output_dim
concatenation_type = 'scalar'
probability = 0.8
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this conditioning module.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus property, probability, input_dim, and output_dim from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
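The merge pattern described above can be sketched without agedi itself. `BaseSketch`, `IntegerSketch` and `instantiate` below are hypothetical stand-ins, not the library's classes. Resolving `_target_` with `importlib` is an assumption about how the dictionary is consumed (a Hydra-style convention) and is not confirmed by this page.

```python
import importlib
from typing import Any, Dict


class BaseSketch:
    """Hypothetical stand-in for Conditioning; only the hyperparameter logic is shown."""

    def __init__(self, property: str, input_dim: int, output_dim: int, probability: float = 0.8):
        self.property = property
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.probability = probability

    def get_hparams(self) -> Dict[str, Any]:
        cls = type(self)
        return {
            # Fully-qualified class name, e.g. "mymodule.IntegerSketch"
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "property": self.property,
            "probability": self.probability,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        }


class IntegerSketch(BaseSketch):
    """Hypothetical subclass that adds one constructor parameter."""

    def __init__(self, max_int: int = 200, **kwargs: Any):
        super().__init__(**kwargs)
        self.max_int = max_int

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class dict and merge in this subclass's own parameters.
        return {**super().get_hparams(), "max_int": self.max_int}


def instantiate(hparams: Dict[str, Any]) -> Any:
    """Hypothetical helper: import the _target_ class and call it with the remaining keys."""
    params = dict(hparams)
    module_name, _, class_name = params.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**params)


hp = IntegerSketch(property="n_atoms", input_dim=1, output_dim=64).get_hparams()
```

When the module that defines the class is importable, `instantiate(hp)` rebuilds an equivalent `IntegerSketch`, which is the reason each subclass must add its own constructor arguments to the dictionary.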

abstractmethod get_conditioning(x: torch.Tensor) → torch.Tensor

Abstract method to get the conditioning from the input

Must be implemented by the subclass

Parameters:

x (torch.Tensor) – The input tensor

Returns:

The conditioning tensor

Return type:

torch.Tensor

abstractmethod get_empty_conditioning(n: int) → torch.Tensor

Abstract method to get an empty conditioning tensor

Must be implemented by the subclass

Parameters:

n (int) – The number of nodes in the batch

Returns:

The empty conditioning tensor

Return type:

torch.Tensor
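The abstract contract can be illustrated with a stripped-down sketch. `ConditioningSketch` and `IdentityConditioning` are hypothetical names, the real base class also inherits from `lightning.LightningModule`, and using zeros as the empty conditioning is an assumption made only for illustration.

```python
from abc import ABC, abstractmethod

import torch


class ConditioningSketch(ABC):
    """Hypothetical, stripped-down version of the abstract interface (no Lightning)."""

    def __init__(self, output_dim: int):
        self.output_dim = output_dim

    @abstractmethod
    def get_conditioning(self, x: torch.Tensor) -> torch.Tensor:
        """Map the property tensor x to a conditioning tensor."""

    @abstractmethod
    def get_empty_conditioning(self, n: int) -> torch.Tensor:
        """Return a placeholder conditioning tensor for n nodes."""


class IdentityConditioning(ConditioningSketch):
    """Hypothetical subclass: passes x through and uses zeros as the empty conditioning."""

    def get_conditioning(self, x: torch.Tensor) -> torch.Tensor:
        # Expect x of shape (n_nodes, output_dim) and return it as float features.
        return x.float()

    def get_empty_conditioning(self, n: int) -> torch.Tensor:
        # Assumption: "empty" means an all-zero tensor of the conditioning width.
        return torch.zeros(n, self.output_dim)


cond = IdentityConditioning(output_dim=2)
```

Because both methods are declared with `@abstractmethod`, Python refuses to instantiate `ConditioningSketch` directly; only a subclass that implements both can be created.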

forward(batch: AtomsGraph, empty: bool = False) → AtomsGraph

Forward method to get the conditioning from the input

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – If True, return an empty conditioning tensor

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph

concatenate(batch: AtomsGraph, c: torch.Tensor) → None

Concatenate the conditioning to the batch

c must already be expanded to node-level (c.shape[0] == batch.pos.shape[0]) before this method is called. The pre-expansion is performed by forward() to keep this method free of dynamic shape comparisons, which would otherwise cause graph breaks under torch.compile.

Parameters:
  • batch (AtomsGraph) – The input batch

  • c (torch.Tensor) – Node-level conditioning tensor of shape (n_nodes, output_dim) or (n_nodes, output_dim, 1).

Return type:

None
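The node-level pre-expansion can be sketched as a gather with a per-node graph index. The name `batch_index` is hypothetical (this page does not say which `AtomsGraph` attribute holds it), and appending `c` to plain scalar node features along the last axis is an assumption: the real `concatenate()` depends on `concatenation_type`, and only the 2-D `(n_nodes, output_dim)` layout is shown here.

```python
import torch


def expand_to_nodes(c_graph: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
    """Gather per-graph conditioning rows onto nodes.

    c_graph: (n_graphs, output_dim); batch_index: (n_nodes,) graph id of each node.
    """
    return c_graph[batch_index]


def concatenate_scalar(h: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Append node-level conditioning c to scalar node features h along the feature axis."""
    return torch.cat([h, c], dim=-1)


c_graph = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # two graphs, output_dim = 2
batch_index = torch.tensor([0, 0, 0, 1, 1])        # five nodes: three in graph 0, two in graph 1
c_nodes = expand_to_nodes(c_graph, batch_index)    # (5, 2)
h = torch.zeros(5, 8)                              # hypothetical scalar node features
out = concatenate_scalar(h, c_nodes)               # (5, 10)
```

Doing the gather up front means the concatenation step only ever sees tensors whose first dimension already equals the number of nodes.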

sample_mode() → None

Set the model to sample mode

Return type:

None

training_mode() → None

Set the model to train mode

Return type:

None
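The `probability` parameter, which is only used in training mode, can be read as a random choice between the real and the empty conditioning. The sketch below is hypothetical: `choose_conditioning` is not a library function, and drawing once per row (rather than per graph or once per batch) is an assumption, since this page does not specify the granularity.

```python
import torch


def choose_conditioning(
    c: torch.Tensor, empty: torch.Tensor, probability: float, training: bool
) -> torch.Tensor:
    """Hypothetical helper: keep each row of c with the given probability in training mode."""
    if not training:
        # Outside training mode the probability is not used; keep the real conditioning.
        return c
    # One Bernoulli(probability) draw per row; rows that fail get the empty conditioning.
    keep = torch.rand(c.shape[0], 1, device=c.device) < probability
    return torch.where(keep, c, empty)


c = torch.ones(6, 2)
empty = torch.zeros(6, 2)
mixed = choose_conditioning(c, empty, probability=0.8, training=True)
```

Under this reading, the default of 0.8 keeps the real conditioning for roughly 80% of rows during training, so the model also learns to operate without it.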

class agedi.models.conditionings.TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Condition the model on the time t.

Parameters:

t (torch.Tensor) – Time tensor of shape (Nodes, 1).

omega
get_hparams() → Dict

Return hyperparameters for this time conditioning module.

get_conditioning(t: torch.Tensor) → torch.Tensor

Get the conditioning tensor for the time t.

$$
\mathbf{c} = \begin{bmatrix} \sin(\omega t) \\ \cos(\omega t) \end{bmatrix}
$$

Parameters:

t (torch.Tensor) – Time tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, 2).

Return type:

torch.Tensor
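The formula above can be sketched directly in PyTorch. `sinusoidal_time_conditioning` is a hypothetical name, and the value of `omega` used here is a placeholder: this page lists `omega` as an attribute but does not document its value or whether it is learned.

```python
import math

import torch


def sinusoidal_time_conditioning(t: torch.Tensor, omega: float) -> torch.Tensor:
    """Map t of shape (n_nodes, 1) to [sin(omega t), cos(omega t)] of shape (n_nodes, 2)."""
    return torch.cat([torch.sin(omega * t), torch.cos(omega * t)], dim=-1)


t = torch.tensor([[0.0], [0.5], [1.0]])  # one time value per node
c = sinusoidal_time_conditioning(t, omega=math.pi / 2)  # placeholder omega
```

Each row lies on the unit circle, so the encoding stays bounded for any `t` while remaining a smooth function of time.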

get_empty_conditioning(n: int) → torch.Tensor

Get an empty conditioning tensor.

Returns:

Empty conditioning tensor of shape (1, 2).

Return type:

torch.Tensor

forward(batch: AtomsGraph, empty: bool = False) → AtomsGraph

Forward method to get the conditioning from the input

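Unlike the base class, this method ignores both the training-mode probability and the empty flag, so the time conditioning is always applied.

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – Ignored by this class; accepted for compatibility with the base-class signature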

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph

class agedi.models.conditionings.ScalarConditioning(*args, input_dim: int = 1, output_dim: int = 2, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Conditioning module for continuous scalar properties.

Projects a scalar property through a learned linear layer and encodes it with sinusoidal features (cos and sin), producing a 2-dimensional conditioning vector.
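A minimal sketch of this projection-plus-sinusoid scheme follows. `ScalarConditioningSketch` is hypothetical; projecting to a single scalar with `nn.Linear(input_dim, 1)` and placing cos before sin in the output are assumptions consistent with the description above, not confirmed details of `embedder`.

```python
import torch
from torch import nn


class ScalarConditioningSketch(nn.Module):
    """Hypothetical re-creation: learned linear projection followed by cos/sin features."""

    def __init__(self, input_dim: int = 1):
        super().__init__()
        # Assumption: the projection maps input_dim features to a single scalar.
        self.embedder = nn.Linear(input_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.embedder(x)  # (n_nodes, 1)
        return torch.cat([torch.cos(z), torch.sin(z)], dim=-1)  # (n_nodes, 2)


module = ScalarConditioningSketch()
x = torch.randn(4, 1)  # e.g. one scalar property value per node
c = module(x)
```

The learned linear layer lets the model choose the frequency and phase of the encoding, while the cos/sin pair keeps the output bounded regardless of the property's scale.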

embedder
get_conditioning(x: torch.Tensor) → torch.Tensor

Get the conditioning tensor for x

Parameters:

x (torch.Tensor) – Scalar property tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, 2).

Return type:

torch.Tensor

get_empty_conditioning(n: int) → torch.Tensor

Get an empty conditioning tensor.

Returns:

Empty conditioning tensor of shape (n, 2).

Return type:

torch.Tensor

class agedi.models.conditionings.IntegerConditioning(max_int: int = 200, input_dim: int = 1, output_dim: int = 64, *args, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Conditioning module for integer-valued properties.

Embeds an integer property (e.g. number of atoms) into a fixed-size representation using torch.nn.Embedding.
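The embedding lookup can be sketched with `torch.nn.Embedding` using the defaults from the signature. Whether the real table has `max_int` or `max_int + 1` rows, and whether the trailing dimension of `x` is squeezed, are assumptions; with this sketch, any value outside `[0, max_int)` raises an `IndexError`.

```python
import torch
from torch import nn

max_int, output_dim = 200, 64  # defaults from the signature above
embedder = nn.Embedding(num_embeddings=max_int, embedding_dim=output_dim)


def integer_conditioning(x: torch.Tensor) -> torch.Tensor:
    """Look up one learned row per node; x has shape (n_nodes, 1) with values in [0, max_int)."""
    return embedder(x.long().squeeze(-1))  # (n_nodes, output_dim)


n_atoms = torch.tensor([[12], [12], [30]])  # e.g. number of atoms, repeated per node
c = integer_conditioning(n_atoms)
```

Nodes that share the same integer value receive identical conditioning vectors, since each value simply selects one row of the learned table.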

max_int = 200
embedder
get_hparams() → Dict

Return hyperparameters for this integer conditioning module.

get_conditioning(x: torch.Tensor) → torch.Tensor

Get the conditioning tensor for x

Parameters:

x (torch.Tensor) – Integer property tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, output_dim).

Return type:

torch.Tensor

get_empty_conditioning(n: int) → torch.Tensor

Get an empty conditioning tensor.

Returns:

Empty conditioning tensor of shape (n, output_dim).

Return type:

torch.Tensor