agedi.models.conditionings.integer

Classes

IntegerConditioning

Conditioning module for integer-valued properties.

Module Contents

class agedi.models.conditionings.integer.IntegerConditioning(max_int: int = 200, input_dim: int = 1, output_dim: int = 64, *args, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Conditioning module for integer-valued properties.

Embeds an integer property (e.g. number of atoms) into a fixed-size representation using torch.nn.Embedding.
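The lookup described above can be sketched with plain PyTorch. This is an illustrative sketch, not the module's actual implementation: the sizes are taken from the default arguments in the signature, and giving the table exactly `max_int` rows (so valid inputs are 0 to `max_int - 1`) is an assumption.

```python
import torch

# Sizes taken from the default arguments in the class signature.
max_int = 200
output_dim = 64

# Lookup table with one learnable row per integer value.
# Assumption: the table has exactly `max_int` rows, so valid inputs are 0..max_int-1.
embedder = torch.nn.Embedding(num_embeddings=max_int, embedding_dim=output_dim)

# Three example integer properties, e.g. the atom counts of three structures.
n_atoms = torch.tensor([8, 24, 199])
embedded = embedder(n_atoms)
print(embedded.shape)  # torch.Size([3, 64])

# Values outside the table raise an IndexError on CPU; they are not clipped.
try:
    embedder(torch.tensor([max_int]))
    out_of_range_ok = True
except IndexError:
    out_of_range_ok = False
```

Because the table is indexed directly, any integer property larger than the table size has to be handled before the lookup, for example by choosing a larger `max_int`.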

max_int = 200
embedder
get_hparams() → Dict

Return hyperparameters for this integer conditioning module.

get_conditioning(x: torch.Tensor) → torch.Tensor

Get the conditioning tensor for x.

Parameters:

x (torch.Tensor) – Integer property tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, output_dim).

Return type:

torch.Tensor
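The shape handling for a per-node input can be sketched as follows. This is a hedged, standalone illustration using `torch.nn.Embedding` directly; the batch contents, the squeeze step, and the embedding width of 64 are assumptions made for the sketch, not confirmed details of `get_conditioning`.

```python
import torch

# Standalone stand-in for the module's embedder (sizes are assumptions).
embedder = torch.nn.Embedding(num_embeddings=200, embedding_dim=64)

# Hypothetical batch: two structures with 3 and 2 atoms, so 5 nodes in total.
# Each node carries the atom count of the structure it belongs to.
x = torch.tensor([[3], [3], [3], [2], [2]])  # shape (Nodes, 1)

# Embedding appends the feature axis to the index shape, so passing x
# directly would give (Nodes, 1, 64); dropping the last axis first gives (Nodes, 64).
cond = embedder(x.squeeze(-1).long())
print(cond.shape)  # torch.Size([5, 64])

# Nodes that share the same integer receive identical embedding rows.
same = torch.equal(cond[0], cond[1])
```

Indices passed to `torch.nn.Embedding` must be an integer dtype, which is why the sketch casts with `.long()` before the lookup.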

get_empty_conditioning(n: int) → torch.Tensor

Get an empty conditioning tensor with n rows.

Returns:

Empty conditioning tensor of shape (n, output_dim).

Return type:

torch.Tensor