agedi.models.conditionings.base

Classes

Conditioning

Conditioning Base Class

Module Contents

class agedi.models.conditionings.base.Conditioning(property: str, input_dim: int, output_dim: int, concatenation_type: str = 'scalar', probability: float = 0.8, **kwargs)

Bases: abc.ABC, lightning.LightningModule

Conditioning Base Class

Parameters:
  • property (str) – The property of the batch to condition on

  • input_dim (int) – The dimension of the input conditioning

  • output_dim (int) – The dimension of the output conditioning

  • concatenation_type (str) – The type of concatenation to use. Default is “scalar”

  • probability (float) – The probability of applying the conditioning. Default is 0.8. Only used in training mode

Return type:

Conditioning

property
input_dim
output_dim
concatenation_type = 'scalar'
probability = 0.8
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this conditioning module.

Returns a dictionary with a _target_ key (the fully-qualified class name) plus property, probability, input_dim, and output_dim from the base class. Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
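The subclass pattern described above can be sketched as follows. This is a minimal illustration, not the agedi implementation: BaseSketch and ScaledSketch are hypothetical stand-ins written in plain Python, and the scale parameter is invented for the example. Only the key names (_target_, property, probability, input_dim, output_dim) come from the description above.

```python
from typing import Any, Dict


class BaseSketch:
    """Hypothetical stand-in for Conditioning (not the agedi class)."""

    def __init__(self, property: str, input_dim: int, output_dim: int,
                 probability: float = 0.8) -> None:
        self.property = property
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.probability = probability

    def get_hparams(self) -> Dict[str, Any]:
        # _target_ holds the fully-qualified name of the concrete class,
        # so a subclass instance reports its own name, not the base name.
        cls = type(self)
        return {
            "_target_": f"{cls.__module__}.{cls.__qualname__}",
            "property": self.property,
            "probability": self.probability,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        }


class ScaledSketch(BaseSketch):
    """Hypothetical subclass with one extra constructor parameter."""

    def __init__(self, scale: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.scale = scale

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the base-class entries, then merge in our own.
        hparams = super().get_hparams()
        hparams.update(scale=self.scale)
        return hparams


hparams = ScaledSketch(
    scale=2.0, property="energy", input_dim=1, output_dim=16
).get_hparams()
```

Merging into the dictionary returned by super().get_hparams() keeps the base-class keys in one place, so a subclass only has to list the parameters it introduces itself.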

abstractmethod get_conditioning(x: torch.Tensor) → torch.Tensor

Abstract method to get the conditioning from the input

Must be implemented by the subclass

Parameters:

x (torch.Tensor) – The input tensor

Returns:

The conditioning tensor

Return type:

torch.Tensor

abstractmethod get_empty_conditioning(n: int) → torch.Tensor

Abstract method to get an empty conditioning tensor

Must be implemented by the subclass

Parameters:

n (int) – The number of nodes in the batch

Returns:

The empty conditioning tensor

Return type:

torch.Tensor
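Because Conditioning derives from abc.ABC, a subclass that leaves either abstract method unimplemented cannot be instantiated. The sketch below shows that contract with a hypothetical, torch-free stand-in: ConditioningSketch, MissingEmpty and ZeroEmpty are invented names, nested lists stand in for tensors, and using zeros as the "empty" value is an assumption, not something this page specifies.

```python
from abc import ABC, abstractmethod
from typing import List


class ConditioningSketch(ABC):
    """Hypothetical stand-in exposing the two abstract methods."""

    @abstractmethod
    def get_conditioning(self, x: List[float]) -> List[List[float]]:
        ...

    @abstractmethod
    def get_empty_conditioning(self, n: int) -> List[List[float]]:
        ...


class MissingEmpty(ConditioningSketch):
    """Implements only one of the two required methods."""

    def get_conditioning(self, x: List[float]) -> List[List[float]]:
        return [[float(v)] for v in x]


class ZeroEmpty(ConditioningSketch):
    """Implements both methods; zeros as 'empty' is an assumption."""

    def __init__(self, output_dim: int) -> None:
        self.output_dim = output_dim

    def get_conditioning(self, x: List[float]) -> List[List[float]]:
        # Repeat each scalar output_dim times as a placeholder embedding.
        return [[float(v)] * self.output_dim for v in x]

    def get_empty_conditioning(self, n: int) -> List[List[float]]:
        # One row per node, each of width output_dim.
        return [[0.0] * self.output_dim for _ in range(n)]


# Instantiating the incomplete subclass raises TypeError.
try:
    MissingEmpty()
    incomplete_error = None
except TypeError as exc:
    incomplete_error = exc

complete = ZeroEmpty(output_dim=3)
empty_rows = complete.get_empty_conditioning(2)
```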

forward(batch: AtomsGraph, empty: bool = False) → AtomsGraph

Forward method to get the conditioning from the input

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – If True, use the empty conditioning tensor instead of one computed from the batch

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph
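One plausible reading of how empty and probability interact is sketched below. This page only states that probability is the probability of conditioning, that it applies in training mode, and that empty=True yields the empty conditioning. The control flow, the helper name choose_conditioning, and the choice to draw once per call (rather than once per graph) are all assumptions for illustration. Nested lists stand in for tensors.

```python
import random
from typing import List, Optional


def choose_conditioning(
    real: List[List[float]],
    empty_row: List[float],
    probability: float,
    training: bool,
    empty: bool = False,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Hypothetical helper: pick real or empty conditioning rows."""
    rng = rng or random.Random()
    n = len(real)
    if empty:
        # Caller explicitly asked for the unconditional (empty) case.
        return [list(empty_row) for _ in range(n)]
    if training and rng.random() >= probability:
        # Assumed behaviour: in training, drop the conditioning
        # with probability 1 - probability.
        return [list(empty_row) for _ in range(n)]
    # Outside training (or when the draw keeps it), use the real values.
    return [list(row) for row in real]


real_rows = [[1.0, 2.0], [3.0, 4.0]]
zero_row = [0.0, 0.0]

kept = choose_conditioning(real_rows, zero_row, probability=0.8, training=False)
forced_empty = choose_conditioning(real_rows, zero_row, 0.8, training=False, empty=True)
always_dropped = choose_conditioning(real_rows, zero_row, probability=0.0, training=True)
```

Under this reading, a model trained with probability below 1 sees both conditioned and unconditioned batches, while sampling always uses the real conditioning unless empty=True is passed.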

concatenate(batch: AtomsGraph, c: torch.Tensor) → None

Concatenate the conditioning to the batch

c must already be expanded to node-level (c.shape[0] == batch.pos.shape[0]) before this method is called. The pre-expansion is performed by forward() to keep this method free of dynamic shape comparisons, which would otherwise cause graph breaks under torch.compile.

Parameters:
  • batch (AtomsGraph) – The input batch

  • c (torch.Tensor) – Node-level conditioning tensor of shape (n_nodes, output_dim) or (n_nodes, output_dim, 1).

Return type:

None
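The node-level requirement on c can be illustrated with a small sketch of the pre-expansion step. How forward() actually performs the expansion is not shown on this page; the helper expand_to_nodes and the node_to_graph index are hypothetical, and nested lists stand in for tensors. With torch, the same operation is commonly written as an index lookup such as c_graph[node_to_graph].

```python
from typing import List


def expand_to_nodes(
    c_graph: List[List[float]], node_to_graph: List[int]
) -> List[List[float]]:
    """Hypothetical helper: repeat each graph's row once per node.

    c_graph has one row per graph; node_to_graph[i] is the index of the
    graph that node i belongs to. The result has one row per node.
    """
    return [list(c_graph[g]) for g in node_to_graph]


# Two graphs: the first has three nodes, the second has two.
c_graph = [[1.0, 2.0], [3.0, 4.0]]
node_to_graph = [0, 0, 0, 1, 1]

c_nodes = expand_to_nodes(c_graph, node_to_graph)
# len(c_nodes) now equals the number of nodes, which is the shape
# condition (c.shape[0] == batch.pos.shape[0]) that concatenate() expects.
```

Doing the expansion before the call means concatenate() never has to inspect c.shape[0] to decide between graph-level and node-level input, which is the dynamic shape comparison the description above says to avoid.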

sample_mode() → None

Set the model to sample mode

Return type:

None

training_mode() → None

Set the model to train mode

Return type:

None