agedi.models.conditionings.time

Classes

TimeConditioning

Condition the model on the time t.

Module Contents

class agedi.models.conditionings.time.TimeConditioning(input_dim: int = 1, output_dim: int = 2, **kwargs)

Bases: agedi.models.conditionings.base.Conditioning

Condition the model on the time t.

Parameters:

t (torch.Tensor) – Time tensor of shape (Nodes, 1).

omega

get_hparams() → Dict

Return hyperparameters for this time conditioning module.

get_conditioning(t: torch.Tensor) → torch.Tensor

Get the conditioning tensor for the time t.

\begin{align*} \mathbf{c} = \begin{bmatrix} \sin(\omega t) & \cos(\omega t) \end{bmatrix} \end{align*}

Parameters:

t (torch.Tensor) – Time tensor of shape (Nodes, 1).

Returns:

Conditioning tensor of shape (Nodes, 2).

Return type:

torch.Tensor
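The mapping from a time column to the two-column conditioning can be illustrated with a minimal sketch. It uses plain Python lists in place of torch tensors so that it runs without dependencies. The function name `time_conditioning` and the value chosen for `omega` are assumptions made for illustration; they are not taken from the agedi source.

```python
import math


def time_conditioning(t, omega=1.0):
    """Hypothetical stand-in for get_conditioning.

    t is a list of single-element rows, i.e. shape (Nodes, 1).
    Returns rows [sin(omega * t), cos(omega * t)], i.e. shape (Nodes, 2).
    """
    return [[math.sin(omega * row[0]), math.cos(omega * row[0])] for row in t]


# Three nodes, each with its own time value.
t = [[0.0], [0.25], [0.5]]

# omega = 2*pi is an assumed frequency, not the library's value.
c = time_conditioning(t, omega=2.0 * math.pi)

for row in c:
    print(row)
```

Each output row lies on the unit circle, so the encoding has constant norm and varies smoothly with t.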

get_empty_conditioning(n: int) → torch.Tensor

Get an empty conditioning tensor.

Returns:

Empty conditioning tensor of shape (1, 2).

Return type:

torch.Tensor

forward(batch: AtomsGraph, empty: bool = False) → AtomsGraph

Forward method to get the conditioning from the input.

This ignores training and empty flags.

Parameters:
  • batch (AtomsGraph) – The input batch

  • empty (bool) – If True, return an empty conditioning tensor

Returns:

The batch with the conditioning added to the representation

Return type:

AtomsGraph