agedi.data

Submodules

Classes

AtomsGraph

Atomistic Graph Class

Representation

Representation class

Dataset

Defines a custom dataset for AtomsGraph data

Package Contents

class agedi.data.AtomsGraph

Bases: torch_geometric.data.Data

Atomistic Graph Class

Class defining a graph with atoms as nodes and edges formed between all atoms within a finite cutoff radius.

Parameters:
  • pos (torch.Tensor) – The positions of the atoms with shape (n_atoms, 3).

  • x (torch.Tensor) – The node features, i.e. the atomic types, with shape (n_nodes, 1).

  • edge_index (torch.Tensor) – The edge index tensor of the graph with shape (2, n_edges).

  • edge_attr (torch.Tensor) – The edge attributes of the graph with shape (n_edges, n_edge_features).

  • y (Optional[torch.Tensor]) – The target tensor of the graph with shape (n_targets,).

  • representation (Optional[Representation]) – The representation of the atoms in the graph.

  • confinement (Optional[torch.Tensor]) – z-directional confinement of the atoms with shape (1, 2).

  • kwargs (Dict[str, torch.Tensor]) – Additional tensor attributes to store on the graph.

classmethod from_atoms(atoms: ase.Atoms, cutoff: float = 6.0, dtype: torch.dtype = torch.float, initialize_mask: bool | None = None, confinement: Tuple[float, float] | None = None, canonical_cell: bool = False) → AtomsGraph

Create a graph from an ASE Atoms object.

Parameters:
  • atoms (Atoms) – The ASE Atoms object.

  • cutoff (float) – The cutoff radius for the edges.

  • dtype (torch.dtype) – The data type of the tensors.

  • initialize_mask (Optional[bool]) – Whether to initialize the mask tensor. When None (the default), the mask is initialised only when confinement is not provided (i.e. initialize_mask defaults to False for template / confinement graphs).

  • confinement (Optional[Tuple[float, float]]) – Optional z-directional confinement bounds (z_min, z_max) to attach to the graph. When provided, a confinement tensor of shape (1, 2) is stored on the graph. When None (the default), no confinement attribute is added.

  • canonical_cell (bool) – When True, the cell is stored in canonical lower-triangular form. If the input cell is not already canonical, Cartesian positions are recomputed to preserve fractional coordinates and a warning is raised. Set to False (the default) to store the cell exactly as provided by ASE (no rotation or recomputation is performed).

Returns:

graph – The graph object.

Return type:

AtomsGraph

classmethod empty(cutoff: float = 6.0) → AtomsGraph

Create an empty graph.

Parameters:

cutoff (float) – The cutoff radius for the edges.

Returns:

graph – The graph object.

Return type:

AtomsGraph

add_batch_attr(key: str, value: torch.Tensor, type: str = 'node') → None

Add a batch attribute to the graph.

Parameters:
  • key (str) – The key of the attribute.

  • value (torch.Tensor) – The value of the attribute.

  • type (str) – The type of the attribute: either “node” or “graph”. Defaults to “node”.

Return type:

None

to_atoms() → ase.Atoms

Convert the graph to an ASE Atoms object.

Only works on unbatched graphs.

Returns:

atoms – The atoms object.

Return type:

ase.Atoms

_get_scalar_attr(key: str) → float | None

prepare_for_compile(cutoff: float) → None

Pre-allocate neighbor-list buffers for torch.compile compatibility.

Estimates the maximum number of neighbors per atom using estimate_max_neighbors() and the cell-list dimensions using estimate_cell_list_sizes(), then allocates the cell list and all output buffers with fixed shapes. Fixed shapes are required for torch.compile to trace the reverse diffusion step once without retracing on subsequent iterations.

Must be called on a Batch before the first update_graph() call.

Requires the nvalchemiops package.

Parameters:

cutoff (float) – Neighbor-list cutoff radius (Å).

Raises:
  • RuntimeError – When nvalchemiops is not installed.

  • TypeError – When called on an unbatched AtomsGraph instead of a Batch.
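For intuition about the cell-list sizing described above, the sketch below shows one standard way to choose bin counts: each perpendicular width of the cell is split into bins no narrower than the cutoff, so all neighbors of an atom fall in its own bin or an adjacent one. The helper `cell_list_bins` is hypothetical and is not the package's `estimate_cell_list_sizes()`; the cell is assumed to store lattice vectors as rows.

```python
import math


def cross(u, v):
    """Cross product of two 3-vectors."""
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cell_list_bins(cell, cutoff):
    """Number of cell-list bins along each lattice vector (hypothetical helper).

    cell: 3x3 nested list whose rows are the lattice vectors.
    """
    volume = abs(dot(cell[0], cross(cell[1], cell[2])))
    bins = []
    for i in range(3):
        j, k = (i + 1) % 3, (i + 2) % 3
        normal = cross(cell[j], cell[k])
        # Height of the cell along axis i, perpendicular to the other two vectors
        height = volume / math.sqrt(dot(normal, normal))
        # Bins must be at least `cutoff` wide; always keep at least one bin
        bins.append(max(1, int(height // cutoff)))
    return bins


print(cell_list_bins([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]], 3.0))  # [3, 3, 3]
```

Because the bin count depends only on the cell and the cutoff, it can be fixed ahead of time, which is what makes pre-allocated, fixed-shape buffers possible.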

static _cell_list_to_graph(neighbor_matrix: torch.Tensor, neighbor_shifts: torch.Tensor, cell: torch.Tensor, dtype: torch.dtype, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]

Convert cell-list query output to (edge_index, shift_vectors).

update_graph() → bool

Update the graph edges.

Call this after changing the positions or the cell.

Returns:

rebuilt – True if the neighbor list was fully recomputed, False otherwise.

Return type:

bool

static _make_graph_matscipy(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype | None = None, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]

static make_graph(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype = None, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]

Create the graph edges from the positions and cell.

Parameters:
  • positions (torch.Tensor) – The positions of the atoms.

  • cell (torch.Tensor) – The cell of the system.

  • cutoff (float) – The cutoff radius for the edges.

  • pbc (torch.Tensor) – The periodic boundary conditions.

  • dtype (torch.dtype) – The data type of the output.

Returns:

  • edge_index (torch.Tensor) – The edge index tensor.

  • shift_vectors (torch.Tensor) – The shift vectors tensor.
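A brute-force O(N²) reference for what make_graph() computes can help when validating the fast neighbor-list backends. This is a sketch under stated assumptions, not the package's implementation: the function name is hypothetical, only the nearest periodic images (offsets -1, 0, 1) are searched, which suffices only when positions are wrapped into the cell and the cutoff does not exceed any perpendicular cell width, edges are stored as (source j, target i), and shift vectors are Cartesian.

```python
import itertools
import math


def make_graph_bruteforce(positions, cell, cutoff, pbc):
    """O(N^2) radius graph including periodic images (illustrative only).

    positions: list of [x, y, z]; cell: 3x3 nested list, rows are lattice vectors;
    pbc: three booleans. Returns ([sources, targets], shift_vectors).
    """
    # Search neighboring images only along periodic directions
    ranges = [(-1, 0, 1) if periodic else (0,) for periodic in pbc]
    src, dst, shifts = [], [], []
    for n in itertools.product(*ranges):
        # Cartesian shift of this image: n0 * a + n1 * b + n2 * c
        shift = [sum(n[k] * cell[k][d] for k in range(3)) for d in range(3)]
        for i, ri in enumerate(positions):
            for j, rj in enumerate(positions):
                if i == j and n == (0, 0, 0):
                    continue  # no self-loop within the home cell
                image = [rj[d] + shift[d] for d in range(3)]
                if math.dist(ri, image) < cutoff:
                    src.append(j)
                    dst.append(i)
                    shifts.append(shift)
    return [src, dst], shifts
```

With two atoms near opposite faces of a 10 Å cubic cell, the only edges within a 3 Å cutoff cross the periodic boundary, so disabling pbc removes them.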

clear_graph() → None

Clear the graph, removing all edges.

Return type:

None

__len__() → int

Return the number of atoms in the graph.

Returns:

n_atoms – The number of atoms in the graph.

Return type:

int

property cell: torch.Tensor

Return the canonical cell matrix of the graph.

Returns:

cell – The cell matrix of shape (3, 3).

Return type:

torch.Tensor

property frac: torch.Tensor

Return the fractional coordinates of the atom positions.

Returns:

frac – The fractional coordinates of the atoms.

Return type:

torch.Tensor

frac_to_pos(f: torch.Tensor) → torch.Tensor

Fractional -> Cartesian coordinates.

Convert fractional coordinates to Cartesian coordinates.

Parameters:

f (torch.Tensor) – The fractional coordinates.

Returns:

r – The Cartesian coordinates.

Return type:

torch.Tensor

pos_to_frac(r: torch.Tensor) → torch.Tensor

Cartesian -> Fractional coordinates.

Convert Cartesian coordinates to fractional coordinates.

Parameters:

r (torch.Tensor) – The Cartesian coordinates.

Returns:

f – The fractional coordinates.

Return type:

torch.Tensor
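With lattice vectors stored as rows (the layout stated for vector_to_cell()), the two conversions are r = f @ cell and f = r @ cell^-1. A minimal pure-Python sketch, using hypothetical standalone names rather than the methods above:

```python
def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]


def inverse3(m):
    """Inverse of a 3x3 matrix with rows a, b, c.

    Its columns are cross(b, c), cross(c, a), cross(a, b), each divided by det.
    """
    a, b, c = m
    cols = [cross(b, c), cross(c, a), cross(a, b)]
    det = sum(x * y for x, y in zip(a, cols[0]))
    return [[cols[k][r] / det for k in range(3)] for r in range(3)]


def row_times_matrix(rows, m):
    """Multiply each row vector by the 3x3 matrix m (i.e. rows @ m)."""
    return [[sum(v[k] * m[k][d] for k in range(3)) for d in range(3)] for v in rows]


def frac_to_cart(f, cell):
    return row_times_matrix(f, cell)            # r = f @ cell


def cart_to_frac(r, cell):
    return row_times_matrix(r, inverse3(cell))  # f = r @ cell^-1
```

With tensors the same operations are `f @ cell` and `r @ torch.linalg.inv(cell)`.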

property positions_mask: torch.Tensor

Return the mask of the positions that are fixed.

True for fixed atom positions, False otherwise.

Returns:

mask – The mask of the positions that are fixed.

Return type:

torch.Tensor

property time: torch.Tensor

Return the time of the graph.

Returns:

time – The time of the graph.

Return type:

torch.Tensor

property representation: Representation | None

Return the representation of the graph.

Returns:

representation – The representation of the graph, or None if not set.

Return type:

Optional[Representation]

wrap_positions() None

Wrap the positions of the atoms to the unit cell.

Return type:

None
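Wrapping is simplest in fractional space: take each coordinate modulo 1 and convert back with frac_to_pos(). The sketch below assumes that only periodic directions are wrapped; whether wrap_positions() honours the pbc flags this way is an assumption, and `wrap_frac` is a hypothetical helper.

```python
def wrap_frac(frac, pbc):
    """Map fractional coordinates into [0, 1) along periodic directions only."""
    return [[f % 1.0 if periodic else f for f, periodic in zip(row, pbc)]
            for row in frac]


print(wrap_frac([[1.25, -0.25, 0.5]], [True, True, False]))  # [[0.25, 0.75, 0.5]]
```

In floating point, a tiny negative value such as -1e-20 maps to exactly 1.0 under `% 1.0`, so robust code may apply the modulo a second time.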

apply_mask(x: torch.Tensor, val: float = 0.0) → torch.Tensor

Apply the mask to the tensor x.

Parameters:
  • x (torch.Tensor) – The tensor to apply the mask to.

  • val (float) – The value assigned to the masked entries.

Returns:

x – The tensor with the mask applied.

Return type:

torch.Tensor
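As a sketch of the masking semantics, assume the mask is positions_mask (True for fixed atoms) and that apply_mask() overwrites the rows of fixed atoms with val, for example to zero the noise or update applied to them. Both points are assumptions, and the helper below is hypothetical and works on nested lists:

```python
def apply_mask_rows(x, fixed, val=0.0):
    """Replace the per-atom rows of fixed atoms with `val`; keep the others."""
    return [[val] * len(row) if is_fixed else list(row)
            for row, is_fixed in zip(x, fixed)]


print(apply_mask_rows([[1, 2, 3], [4, 5, 6]], [True, False]))  # [[0.0, 0.0, 0.0], [4, 5, 6]]
```

For a tensor x of shape (n_atoms, 3) and a boolean mask of shape (n_atoms,), the equivalent PyTorch expression is `x.masked_fill(mask.unsqueeze(-1), val)`.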

property confinement: torch.Tensor

Return the confinement of the graph.

Returns:

confinement – The confinement of the graph.

Return type:

torch.Tensor

property cellpar: torch.Tensor

Return the cell parameters of the graph.

static _is_lower_triangular(cell: torch.Tensor) → bool

Return True if cell is in canonical lower-triangular form.

A cell matrix is considered canonical when the three strictly upper-triangular entries (cell[0,1], cell[0,2], cell[1,2]) are all zero (within a tight floating-point tolerance of 1e-10).

Parameters:

cell (torch.Tensor) – The cell matrix.

Returns:

True if the cell is already lower-triangular.

Return type:

bool
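The test described above translates directly into code. This standalone version (hypothetical name, nested lists instead of tensors) checks the three strictly upper-triangular entries against the same 1e-10 tolerance:

```python
def is_lower_triangular(cell, tol=1e-10):
    """True when cell[0][1], cell[0][2] and cell[1][2] are all zero within tol."""
    return all(abs(cell[i][j]) <= tol for i, j in ((0, 1), (0, 2), (1, 2)))
```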

static cell_to_vectors(cell: torch.Tensor) → torch.Tensor

Convert cell matrix to cell parameters.

Parameters:

cell (torch.Tensor) – The cell matrix of shape (N, 3) or (N, 3, 3).

Returns:

The cell parameters of shape (N, 6).

Return type:

torch.Tensor

static vector_to_cell(cellpar: torch.Tensor) → torch.Tensor

Convert cell parameters to cell matrix.

Parameters:

cellpar (torch.Tensor) – The cell parameters of shape (N, 6).

Returns:

The cell matrix of shape (N, 3, 3) where each row is a lattice vector.

Return type:

torch.Tensor
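For reference, here is a sketch of both conversions for a single cell. It assumes the common convention that cell parameters are (a, b, c, alpha, beta, gamma) with angles in degrees, alpha between b and c, beta between a and c, and gamma between a and b. The reconstruction places a along x and b in the xy-plane, which yields the canonical lower-triangular form. Function names are hypothetical, and the batched (N, ...) shapes of the real methods are not reproduced:

```python
import math


def cellpar_to_cell(cellpar):
    """(a, b, c, alpha, beta, gamma) -> lower-triangular cell, rows = lattice vectors."""
    a, b, c, alpha, beta, gamma = cellpar
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    cx = c * cb                        # c . a_hat = c cos(beta)
    cy = c * (ca - cb * cg) / sg       # chosen so that b . c = b c cos(alpha)
    cz = math.sqrt(c * c - cx * cx - cy * cy)
    return [[a, 0.0, 0.0], [b * cg, b * sg, 0.0], [cx, cy, cz]]


def cell_to_cellpar(cell):
    """Lattice vectors (rows) -> [a, b, c, alpha, beta, gamma] with angles in degrees."""
    lengths = [math.sqrt(sum(x * x for x in v)) for v in cell]

    def angle(i, j):
        d = sum(x * y for x, y in zip(cell[i], cell[j]))
        return math.degrees(math.acos(d / (lengths[i] * lengths[j])))

    # alpha = angle(b, c), beta = angle(a, c), gamma = angle(a, b)
    return lengths + [angle(1, 2), angle(0, 2), angle(0, 1)]
```

A round trip cellpar -> cell -> cellpar recovers the input up to floating-point error, and the intermediate cell has zero upper-triangular entries by construction.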

class agedi.data.Representation

Representation class

A simple container holding the scalar (l=0) and vector (l=1) equivariant representations produced by the backbone network. Both fields are optional so that the class can also be used for partial representations.

Registered as a torch.utils._pytree node so that torch.compile can traverse instances transparently without introducing graph breaks.

Parameters:
  • scalar (Optional[torch.Tensor]) – Per-node scalar features of shape (n_nodes, n_features, 1). Default is None.

  • vector (Optional[torch.Tensor]) – Per-node vector features of shape (n_nodes, n_features, 3). Default is None.

scalar: torch.Tensor | None = None
vector: torch.Tensor | None = None
to_tensor(n_graphs: int) → Tuple[torch.Tensor, torch.Tensor]

Serialise scalar and vector tensors into a single flat representation.

Concatenates scalar and vector (when present) along the feature dimension. Returns the concatenated tensor together with per-graph slice boundaries and degree values so that from_tensor() can reconstruct the original fields.

Parameters:

n_graphs (int) – The number of graphs in the batch. The slice and degree tensors are repeated once per graph so they can be stored as graph-level attributes.

Returns:

  • tensor (torch.Tensor) – Concatenated representation of shape (n_nodes, total_features).

  • slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).

  • ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).

classmethod from_tensor(tensor: torch.Tensor, slices: torch.Tensor, ls: torch.Tensor) → Representation

Reconstruct a Representation from a flat serialised form.

Parameters:
  • tensor (torch.Tensor) – Flat representation of shape (n_nodes, total_features).

  • slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).

  • ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).

Return type:

Representation
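The round trip between to_tensor() and from_tensor() can be sketched with nested lists. The following are assumptions, not taken from the package: each part is flattened feature-major, so a degree-l part occupies n_features * (2l + 1) columns. slices holds the cumulative column boundaries, and ls holds the degree of each part (0 for scalar, 1 for vector). Every graph in the batch shares the same layout. The function names are hypothetical:

```python
def representation_to_rows(scalar, vector, n_graphs):
    """Flatten (n_nodes, F, 1) scalar and (n_nodes, F, 3) vector features per node.

    At least one of scalar/vector must be given. Returns (rows, slices, ls).
    """
    parts = [(l, p) for l, p in ((0, scalar), (1, vector)) if p is not None]
    n_nodes = len(parts[0][1])
    # Concatenate the flattened parts of each node along the feature axis
    rows = [[x for _, p in parts for feat in p[node] for x in feat]
            for node in range(n_nodes)]
    bounds, ls = [0], []
    for l, p in parts:
        bounds.append(bounds[-1] + len(p[0]) * (2 * l + 1))  # F * (2l + 1) columns
        ls.append(l)
    # Repeat the layout once per graph so it can be stored as a graph-level attribute
    return rows, [list(bounds) for _ in range(n_graphs)], [list(ls) for _ in range(n_graphs)]


def representation_from_rows(rows, slices, ls):
    """Inverse of representation_to_rows; returns (scalar, vector), None if absent."""
    bounds, degrees = slices[0], ls[0]  # every graph shares the same layout
    fields = {0: None, 1: None}
    for k, l in enumerate(degrees):
        width = 2 * l + 1
        fields[l] = [[row[s:s + width] for s in range(bounds[k], bounds[k + 1], width)]
                     for row in rows]
    return fields[0], fields[1]
```

Storing slices and ls per graph is what lets the layout survive batching, since graph-level attributes are stacked along the first dimension.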

class agedi.data.Dataset(batch_size: int = 32, n_train: float | int = 0.9, n_val: float | int = 0.1, n_test: float | int = 0.0, shuffle: bool = True, properties: List[str] = ['energy', 'forces'], cutoff: float = 6.0, phase_transforms: List[List[torch_geometric.transforms.BaseTransform]] | None = None, num_workers: int = 0, **kwargs)

Bases: lightning.LightningDataModule

Defines a custom dataset for AtomsGraph data

Parameters:
  • batch_size (int) – The batch size for the DataLoader.

  • n_train (Union[float, int]) – The number of training samples. If float, it is interpreted as a fraction of the dataset size.

  • n_val (Union[float, int]) – The number of validation samples. If float, it is interpreted as a fraction of the dataset size.

  • n_test (Union[float, int]) – The number of test samples. If float, it is interpreted as a fraction of the dataset size.

  • shuffle (bool) – Whether to shuffle the dataset.

  • properties (List[str]) – The properties to include in the dataset. Can be “energy”, “forces”, or both.

  • cutoff (float) – The cutoff radius for the neighbor list.

  • phase_transforms (Optional[List[List[BaseTransform]]]) – The data augmentation transforms to apply in each training phase.

  • num_workers (int) – The number of worker processes used by the DataLoaders.

Return type:

Dataset
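The float-or-int convention for the split sizes can be sketched as follows. The rounding rule, the overflow check, and the fixed seed are assumptions for illustration; `split_indices` is a hypothetical helper, not part of the Dataset API.

```python
import random


def split_indices(n_total, n_train=0.9, n_val=0.1, n_test=0.0, shuffle=True, seed=0):
    """Resolve float (fraction) or int (count) split sizes and return index lists."""
    def resolve(n):
        # A float is a fraction of the dataset size; an int is an absolute count
        return round(n * n_total) if isinstance(n, float) else n

    sizes = [resolve(n_train), resolve(n_val), resolve(n_test)]
    if sum(sizes) > n_total:
        raise ValueError(f"split sizes {sizes} exceed dataset size {n_total}")
    idx = list(range(n_total))
    if shuffle:
        random.Random(seed).shuffle(idx)
    a, b = sizes[0], sizes[0] + sizes[1]
    return idx[:a], idx[a:b], idx[b:b + sizes[2]]
```

With the defaults and 100 structures this yields 90 training, 10 validation and 0 test indices.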

batch_size = 32
n_train = 0.9
n_val = 0.1
n_test = 0.0
properties = ['energy', 'forces']
cutoff = 6.0
dataset = None
train_idx = None
val_idx = None
test_idx = None
phase_transforms = None
num_workers = 0
regressor_dataset = None
regressor_train_loader = None
add_atoms_data(data: List[ase.Atoms], mask_method: str | None = None, confinement: Tuple[float, float] | None = None, properties: List[Dict] | None = None, canonical_cell: bool = False) → None

Add ASE data to the dataset.

Converts a list of ASE Atoms objects to AtomsGraph objects and adds them to the dataset.

Parameters:
  • data (List[Atoms]) – A list of ASE Atoms objects

  • mask_method (str, optional) – Method for computing the atom mask (e.g. "MaskFixed").

  • confinement (Tuple[float, float], optional) – Z-axis confinement bounds (z_min, z_max) applied to every structure.

  • properties (List[Dict], optional) – Per-structure property dictionaries; each entry is mapped to the corresponding graph via setattr().

  • canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. Defaults to False, which stores cells exactly as provided by ASE.

Return type:

None

add_graph_data(data: List[agedi.data.atoms_graph.AtomsGraph]) → None

Add AtomsGraph data to the dataset.

Adds a list of AtomsGraph objects to the dataset.

Parameters:

data (List[AtomsGraph]) – A list of AtomsGraph objects

Return type:

None

add_regressor_data(data: List[ase.Atoms], canonical_cell: bool = False) → None

Add atoms data that will be used exclusively for regressor training.

Structures in this dataset are only used to train the regressor model (e.g. force-field heads) and are never passed through the diffusion loss. This allows the regressor to learn from non-equilibrium structures that would be unsuitable as diffusion training targets.

Energy and forces are read from the ASE calculator attached to each Atoms object when available.

Parameters:
  • data (List[Atoms]) – A list of ASE Atoms objects, each with an attached calculator that provides energy and forces.

  • canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. Defaults to False.

Return type:

None

setup(stage: str | None = None) → None

Set up train/validation/test splits and initialise data loaders.

Performs a random split of the dataset (if not already split) and calls set_phase() to create the initial data loaders.

Parameters:

stage (str, optional) – Lightning stage identifier ("fit", "test", etc.). Not used internally; present for API compatibility.

train_dataloader() → torch_geometric.loader.DataLoader

Get the training DataLoader.

Returns a DataLoader for the training dataset. When a separate regressor dataset has been added via add_regressor_data(), a CombinedLoader is returned so that each training step receives both a regular batch (key "main") and a regressor-only batch (key "regressor").

Return type:

DataLoader or CombinedLoader

val_dataloader() → torch_geometric.loader.DataLoader

Get the validation DataLoader.

Returns a DataLoader for the validation dataset.

Return type:

DataLoader

test_dataloader() → torch_geometric.loader.DataLoader

Get the test DataLoader.

Returns a DataLoader for the test dataset.

Return type:

DataLoader

set_phase(phase: int) → None

Switch the dataset to the given training phase.

Applies the phase-specific transforms to the dataset splits and re-creates the data loaders with the augmented data.

Parameters:

phase (int) – Zero-based phase index. Phase 0 uses the original data; subsequent phases append transformed copies according to phase_transforms[phase].
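Here is a sketch of the phase logic described above, treating each transform as a plain callable (torch_geometric BaseTransform objects are callable). Whether later phases also keep the transformed copies of earlier phases is not specified here, so this hypothetical `build_phase` helper appends only the copies for the requested phase:

```python
def build_phase(original, phase_transforms, phase):
    """Phase 0 keeps the original data; later phases append transformed copies."""
    data = list(original)
    if phase == 0 or phase_transforms is None:
        return data
    for transform in phase_transforms[phase]:
        # One transformed copy of every original sample per transform
        data.extend(transform(sample) for sample in original)
    return data
```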

_check_confinement(dataset: List[agedi.data.atoms_graph.AtomsGraph], confinement: Tuple[float, float]) → None

Check that all unmasked atoms in dataset lie within confinement.

Parameters:
  • dataset (List[AtomsGraph]) – The list of graphs to validate.

  • confinement (Tuple[float, float]) – The (z_min, z_max) confinement bounds.

Raises:

ValueError – If any unmasked atom has a z position outside the confinement. The error message includes a suggested confinement that covers all unmasked atoms.
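The validation can be sketched with plain lists of z coordinates and fixed-atom masks. The helper name, the (z_positions, fixed_mask) input layout, and inclusive bounds are assumptions:

```python
def check_confinement(dataset, confinement):
    """dataset: list of (z_positions, fixed_mask) pairs, one per structure."""
    z_min, z_max = confinement
    # Only unmasked (free) atoms are required to lie inside the bounds
    free_z = [z for zs, fixed in dataset for z, f in zip(zs, fixed) if not f]
    if any(z < z_min or z > z_max for z in free_z):
        raise ValueError(
            f"Unmasked atoms lie outside confinement {confinement}; "
            f"a confinement covering them is ({min(free_z)}, {max(free_z)})"
        )
```

Fixed atoms, such as a frozen slab, are ignored, so they may sit outside the bounds without raising.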

_has_energy_forces(atoms)

Check whether the given ASE Atoms object has energy and forces information available. The method checks whether a calculator is attached to the Atoms object and whether its results contain the ‘energy’ and ‘forces’ properties. It does not trigger a calculation when a calculator is attached but has not yet been run.

Parameters:

atoms (Atoms) – The ASE Atoms object to check for energy and forces information.

Returns:

A tuple indicating whether energy and forces information is available, respectively.

Return type:

Tuple[bool, bool]
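The key point is to inspect the calculator's cached `results` dictionary instead of calling `atoms.get_potential_energy()`, which can trigger a new calculation. A sketch, with `SimpleNamespace` objects standing in for ase.Atoms and its calculator; the function name is hypothetical:

```python
from types import SimpleNamespace


def has_energy_forces(atoms):
    """Return (has_energy, has_forces) using only cached calculator results."""
    calc = getattr(atoms, "calc", None)
    if calc is None:
        return False, False
    # Read the cache directly; never call a getter that could start a calculation
    results = getattr(calc, "results", None) or {}
    return "energy" in results, "forces" in results


# Stand-ins for ase.Atoms objects (hypothetical, for illustration only)
no_calc = SimpleNamespace(calc=None)
energy_only = SimpleNamespace(calc=SimpleNamespace(results={"energy": -1.0}))
print(has_energy_forces(no_calc), has_energy_forces(energy_only))  # (False, False) (True, False)
```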