agedi.data.atoms_graph

Attributes

NVIDIA_NEIGHBOR_IMPORT_ERROR

NVIDIA_CELL_LIST_IMPORT_ERROR

NEIGHBOR_CACHE_KEYS

Classes

Representation

Representation class

AtomsGraph

Atomistic Graph Class

Functions

batched(update_keys, return_batch) → Callable

Batched decorator

Module Contents

agedi.data.atoms_graph.NVIDIA_NEIGHBOR_IMPORT_ERROR = None
agedi.data.atoms_graph.NVIDIA_CELL_LIST_IMPORT_ERROR = None
agedi.data.atoms_graph.NEIGHBOR_CACHE_KEYS = ('edge_index', 'shift_vectors')
agedi.data.atoms_graph.batched(update_keys: Sequence[str] | None = None, return_batch: bool = False) → Callable

Batched decorator

Decorator that lets a function written for a single Data object also be called with a batched input. The function is called for each element in the batch, and the results are concatenated into a single Data object.

If called with a single Data object as input, the function behaves as if it were not decorated.

Parameters:
  • update_keys (Optional[Sequence[str]]) – The keys in the Batch object that should be updated. If None, no keys will be updated.

  • return_batch (bool) – If True, the function will return a Batch object instead of None.

Return type:

Callable
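The decorator's contract can be illustrated with a minimal, framework-free sketch. It assumes a hypothetical ToyBatch (per-node tensors concatenated along dim 0 plus a ptr list of graph offsets) in place of torch_geometric's Batch, and a plain dict in place of a Data object; batched_sketch, ToyBatch and shift_z are illustrative names, not part of agedi. Following the parameter descriptions above, it also assumes the batch is updated in place for update_keys and returned only when return_batch=True.

```python
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Dict, List, Optional, Sequence

import torch

Item = Dict[str, torch.Tensor]  # hypothetical stand-in for a single Data object


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a Batch: per-node tensors concatenated
    along dim 0, with `ptr` holding the node offset of each graph."""

    store: Dict[str, torch.Tensor]
    ptr: List[int]

    def to_list(self) -> List[Item]:
        # Slice every stored tensor back into its per-graph pieces.
        return [
            {k: v[self.ptr[i] : self.ptr[i + 1]] for k, v in self.store.items()}
            for i in range(len(self.ptr) - 1)
        ]


def batched_sketch(
    update_keys: Optional[Sequence[str]] = None, return_batch: bool = False
) -> Callable:
    def decorator(fn: Callable[..., Item]) -> Callable:
        @wraps(fn)
        def wrapper(data, *args, **kwargs):
            if not isinstance(data, ToyBatch):
                # Single item: behave exactly like the undecorated function.
                return fn(data, *args, **kwargs)
            outputs = [fn(item, *args, **kwargs) for item in data.to_list()]
            # Concatenate the per-item results back into the batch, in place.
            for key in update_keys or ():
                data.store[key] = torch.cat([out[key] for out in outputs], dim=0)
            return data if return_batch else None

        return wrapper

    return decorator


@batched_sketch(update_keys=["pos"], return_batch=True)
def shift_z(item: Item, dz: float) -> Item:
    # Example per-item function: translate every atom along z.
    return {**item, "pos": item["pos"] + torch.tensor([0.0, 0.0, dz])}
```

Calling shift_z(ToyBatch(store={"pos": torch.zeros(3, 3)}, ptr=[0, 1, 3]), 1.0) runs the function once per graph and writes the concatenated positions back into the batch, while shift_z({"pos": ...}, 1.0) simply returns the updated dict.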

class agedi.data.atoms_graph.Representation

Representation class

A simple container holding the scalar (l=0) and vector (l=1) equivariant representations produced by the backbone network. Both fields are optional so that the class can also be used for partial representations.

Registered as a torch.utils._pytree node so that torch.compile can traverse instances transparently without introducing graph breaks.

Parameters:
  • scalar (Optional[torch.Tensor]) – Per-node scalar features of shape (n_nodes, n_features, 1). Default is None.

  • vector (Optional[torch.Tensor]) – Per-node vector features of shape (n_nodes, n_features, 3). Default is None.

scalar: torch.Tensor | None = None
vector: torch.Tensor | None = None
to_tensor(n_graphs: int) → Tuple[torch.Tensor, torch.Tensor]

Serialise scalar and vector tensors into a single flat representation.

Concatenates scalar and vector (when present) along the feature dimension. Returns the concatenated tensor together with per-graph slice boundaries and degree values so that from_tensor() can reconstruct the original fields.

Parameters:

n_graphs (int) – The number of graphs in the batch. The slice and degree tensors are repeated once per graph so they can be stored as graph-level attributes.

Returns:

  • tensor (torch.Tensor) – Concatenated representation of shape (n_nodes, total_features).

  • slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).

  • ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).

classmethod from_tensor(tensor: torch.Tensor, slices: torch.Tensor, ls: torch.Tensor) → Representation

Reconstruct a Representation from a flat serialised form.

Parameters:
  • tensor (torch.Tensor) – Flat representation of shape (n_nodes, total_features).

  • slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).

  • ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).

Return type:

Representation
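The to_tensor()/from_tensor() round trip can be sketched with standalone functions. The flattening layout is an assumption: each part of shape (n_nodes, n_features, 2l + 1) is reshaped to (n_nodes, n_features * (2l + 1)), parts are ordered by degree, and every graph shares the same slices and ls rows. to_tensor_sketch and from_tensor_sketch are hypothetical, and the latter returns a plain dict rather than a Representation to keep the example self-contained.

```python
from typing import List, Optional, Tuple

import torch


def to_tensor_sketch(
    scalar: Optional[torch.Tensor], vector: Optional[torch.Tensor], n_graphs: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    parts: List[torch.Tensor] = []
    degrees: List[int] = []
    for deg, feats in ((0, scalar), (1, vector)):
        if feats is not None:
            # (n_nodes, n_features, 2l + 1) -> (n_nodes, n_features * (2l + 1))
            parts.append(feats.reshape(feats.shape[0], -1))
            degrees.append(deg)
    widths = torch.tensor([0] + [p.shape[1] for p in parts])
    slices = torch.cumsum(widths, dim=0)  # cumulative column boundaries
    # Repeat once per graph so both can be stored as graph-level attributes.
    return (
        torch.cat(parts, dim=1),
        slices.repeat(n_graphs, 1),
        torch.tensor(degrees).repeat(n_graphs, 1),
    )


def from_tensor_sketch(tensor: torch.Tensor, slices: torch.Tensor, ls: torch.Tensor) -> dict:
    out = {"scalar": None, "vector": None}
    bounds, degrees = slices[0].tolist(), ls[0].tolist()  # identical for every graph
    for i, deg in enumerate(degrees):
        block = tensor[:, bounds[i] : bounds[i + 1]]
        # Undo the flattening: (n_nodes, F * (2l + 1)) -> (n_nodes, F, 2l + 1)
        block = block.reshape(tensor.shape[0], -1, 2 * deg + 1)
        out["scalar" if deg == 0 else "vector"] = block
    return out
```

With 4 scalar and 4 vector features per node, the flat tensor has 4 + 12 = 16 columns and each row of slices is [0, 4, 16].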

class agedi.data.atoms_graph.AtomsGraph

Bases: torch_geometric.data.Data

Atomistic Graph Class

Class defining a graph with atoms as nodes and edges between all pairs of atoms within a finite cutoff radius.

Parameters:
  • pos (torch.Tensor) – The positions of the atoms with shape (n_atoms, 3).

  • x (torch.Tensor) – The node features, i.e. the atomic types, with shape (n_nodes, 1).

  • edge_index (torch.Tensor) – The edge index tensor of the graph with shape (2, n_edges).

  • edge_attr (torch.Tensor) – The edge attributes of the graph with shape (n_edges, n_edge_features).

  • y (Optional[torch.Tensor]) – The target tensor of the graph with shape (n_targets,).

  • representation (Optional[Representation]) – The representation of the atoms in the graph.

  • confinement (Optional[torch.Tensor]) – z-directional confinement of the atoms with shape (1, 2).

  • kwargs (Dict[str, torch.Tensor]) – Additional attributes passed on to torch_geometric.data.Data.

classmethod from_atoms(atoms: ase.Atoms, cutoff: float = 6.0, dtype: torch.dtype = torch.float, initialize_mask: bool | None = None, confinement: Tuple[float, float] | None = None, canonical_cell: bool = False) → AtomsGraph

Create a graph from an ASE Atoms object.

Parameters:
  • atoms (Atoms) – The ASE Atoms object.

  • cutoff (float) – The cutoff radius for the edges.

  • dtype (torch.dtype) – The data type of the tensors.

  • initialize_mask (Optional[bool]) – Whether to initialize the mask tensor. When None (the default), the mask is initialised only when confinement is not provided (i.e. initialize_mask defaults to False for template / confinement graphs).

  • confinement (Optional[Tuple[float, float]]) – Optional z-directional confinement bounds (z_min, z_max) to attach to the graph. When provided, a confinement tensor of shape (1, 2) is stored on the graph. When None (the default), no confinement attribute is added.

  • canonical_cell (bool) – When True, the cell is stored in canonical lower-triangular form. If the input cell is not already canonical, Cartesian positions are recomputed to preserve fractional coordinates and a warning is raised. Set to False (the default) to store the cell exactly as provided by ASE (no rotation or recomputation is performed).

Returns:

graph – The graph object.

Return type:

AtomsGraph
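One standard way to bring a cell into the lower-triangular form described for canonical_cell is a QR decomposition of its transpose; the sketch below uses that technique as an assumption, not as a statement of how from_atoms does it. It treats rows of cell as lattice vectors (as ASE does) and recomputes Cartesian positions so fractional coordinates are preserved. canonicalize_cell is a hypothetical name.

```python
import torch


def canonicalize_cell(cell: torch.Tensor, pos: torch.Tensor):
    """Rotate `cell` (rows are lattice vectors) to lower-triangular form and
    recompute Cartesian positions so fractional coordinates are unchanged."""
    # QR of cell^T: cell^T = Q R  =>  cell = R^T Q^T, and R^T is lower-triangular.
    q, r = torch.linalg.qr(cell.T)
    # Flip row signs of R so the new cell has a positive diagonal.
    signs = torch.sign(torch.diagonal(r))
    new_cell = (signs[:, None] * r).T
    # Preserve fractional coordinates: f solves f @ cell = pos, then new_pos = f @ new_cell.
    frac = torch.linalg.solve(cell.T, pos.T).T
    return new_cell, frac @ new_cell
```

Because new_cell equals cell multiplied on the right by an orthogonal matrix, all lattice lengths and angles are unchanged; only the orientation in Cartesian space differs.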

classmethod empty(cutoff: float = 6.0) → AtomsGraph

Create an empty graph.

Parameters:

cutoff (float) – The cutoff radius for the edges.

Returns:

graph – The graph object.

Return type:

AtomsGraph

add_batch_attr(key: str, value: torch.Tensor, type: str = 'node') → None

Add a batch attribute to the graph.

Parameters:
  • key (str) – The key of the attribute.

  • value (torch.Tensor) – The value of the attribute.

  • type (str) – The type of the attribute: either “node” or “graph”.

Return type:

None

to_atoms() → ase.Atoms

Convert the graph to an ASE Atoms object.

Only works on unbatched graphs.

Returns:

atoms – The atoms object.

Return type:

ase.Atoms

_get_scalar_attr(key: str) → float | None
prepare_for_compile(cutoff: float) → None

Pre-allocate neighbor-list buffers for torch.compile compatibility.

Estimates the maximum number of neighbors per atom using estimate_max_neighbors() and the cell-list dimensions using estimate_cell_list_sizes(), then allocates the cell list and all output buffers with fixed shapes. Fixed shapes are required for torch.compile to trace the reverse diffusion step once without retracing on subsequent iterations.

Must be called on a Batch before the first update_graph() call.

Requires the nvalchemiops package.

Parameters:

cutoff (float) – Neighbor-list cutoff radius (Å).

Raises:
  • RuntimeError – When nvalchemiops is not installed.

  • TypeError – When called on an unbatched AtomsGraph instead of a Batch.
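The documentation does not say how estimate_max_neighbors() computes its bound. A common heuristic, shown here purely as an assumption, multiplies the expected neighbor count of a uniform-density system, ρ·(4/3)·π·r_c³ with ρ = n_atoms / V, by a safety factor. estimate_max_neighbors_sketch and its default safety factor of 2 are hypothetical, and nvalchemiops is not used.

```python
import math

import torch


def estimate_max_neighbors_sketch(
    n_atoms: int, cell: torch.Tensor, cutoff: float, safety: float = 2.0
) -> int:
    """Hypothetical heuristic: expected neighbors inside a cutoff sphere for a
    uniform density, scaled by a safety factor and rounded up."""
    volume = torch.linalg.det(cell).abs().item()  # cell volume (Å^3)
    density = n_atoms / volume  # atoms per Å^3
    expected = density * (4.0 / 3.0) * math.pi * cutoff**3
    return max(1, math.ceil(safety * expected))
```

For 100 atoms in a 10 Å cubic box with a 3 Å cutoff this gives 23, so a buffer such as torch.full((100, 23), -1, dtype=torch.long) has a shape that stays fixed across diffusion steps as long as no atom exceeds the estimate.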

static _cell_list_to_graph(neighbor_matrix: torch.Tensor, neighbor_shifts: torch.Tensor, cell: torch.Tensor, dtype: torch.dtype, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]

Convert cell-list query output to (edge_index, shift_vectors).
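A conversion of this kind can be sketched for a single graph under assumed conventions: neighbor_matrix has shape (n_atoms, max_neighbors) and is padded with -1, and neighbor_shifts holds integer lattice-image offsets of shape (n_atoms, max_neighbors, 3). Both the padding value and the layout are assumptions; the actual nvalchemiops output format may differ, and batch_idx handling is omitted. cell_list_to_graph_sketch is a hypothetical name.

```python
from typing import Tuple

import torch


def cell_list_to_graph_sketch(
    neighbor_matrix: torch.Tensor,  # (n_atoms, max_neighbors), padded with `fill`
    neighbor_shifts: torch.Tensor,  # (n_atoms, max_neighbors, 3) integer image shifts
    cell: torch.Tensor,  # (3, 3), rows are lattice vectors
    dtype: torch.dtype,
    fill: int = -1,  # hypothetical padding value
) -> Tuple[torch.Tensor, torch.Tensor]:
    valid = neighbor_matrix != fill
    # Row index of each valid slot is the centre atom; the stored value is its neighbor.
    i, slot = valid.nonzero(as_tuple=True)
    j = neighbor_matrix[i, slot]
    edge_index = torch.stack([i, j])
    # Integer image shifts -> Cartesian shift vectors.
    shift_vectors = neighbor_shifts[i, slot].to(dtype) @ cell.to(dtype)
    return edge_index, shift_vectors
```

Padding slots are simply dropped, so the fixed-size neighbor buffer collapses into a variable-length edge list.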

update_graph() → bool

Update the graph with new edges.

This should be called after changing the positions or the cell.

Returns:

rebuilt – True when the neighbor list was fully recomputed.

Return type:

bool

static _make_graph_matscipy(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype | None = None, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]
static make_graph(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype = None, batch_idx: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor]

Create the graph-edges from the positions and cell.

Parameters:
  • positions (torch.Tensor) – The positions of the atoms.

  • cell (torch.Tensor) – The cell of the system.

  • cutoff (float) – The cutoff radius for the edges.

  • pbc (torch.Tensor) – The periodic boundary conditions.

  • dtype (torch.dtype) – The data type of the output.

Returns:

  • edge_index (torch.Tensor) – The edge index tensor.

  • shift_vectors (torch.Tensor) – The shift vectors tensor.
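For small systems, a brute-force neighbor search is a useful reference for checking make_graph() output. The sketch below is such a reference, not the method's actual backend. It assumes rows of cell are lattice vectors, that an edge (i, j) has displacement pos[j] - pos[i] + shift_vector, and that the zero-shift self-pair is excluded; brute_force_graph is a hypothetical name and batch_idx is not handled.

```python
import itertools
import math
from typing import Tuple

import torch


def brute_force_graph(
    positions: torch.Tensor,  # (n_atoms, 3)
    cell: torch.Tensor,  # (3, 3), rows are lattice vectors
    cutoff: float,
    pbc: torch.Tensor,  # (3,) bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    n = positions.shape[0]
    # Images needed per lattice direction: ceil(cutoff / plane spacing),
    # where the spacing is 1 / |reciprocal vector| (columns of inv(cell)).
    recip_norms = torch.linalg.inv(cell).norm(dim=0)
    n_img = [math.ceil(cutoff * recip_norms[k].item()) if pbc[k] else 0 for k in range(3)]
    shifts = torch.tensor(
        list(itertools.product(*(range(-m, m + 1) for m in n_img))),
        dtype=positions.dtype,
    )  # (n_shifts, 3) integer image offsets
    shift_vecs = shifts @ cell  # Cartesian offsets
    # disp[s, i, j] = pos[j] + shift_vecs[s] - pos[i]
    disp = positions[None, None] + shift_vecs[:, None, None] - positions[None, :, None]
    within = disp.norm(dim=-1) < cutoff
    # Exclude each atom paired with itself in the home cell.
    zero_shift = (shifts == 0).all(dim=1)
    within &= ~(zero_shift[:, None, None] & torch.eye(n).bool())
    s, i, j = within.nonzero(as_tuple=True)
    return torch.stack([i, j]), shift_vecs[s]
```

The cost grows as n_atoms² times the number of images, so this is only suitable for testing; a single atom in a periodic unit cube with a 1.1 cutoff yields exactly six edges, one per face-adjacent image.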

clear_graph() → None

Clear the graph, removing all edges.

Return type:

None

__len__() → int

Return the number of atoms in the graph.

Returns:

n_atoms – The number of atoms in the graph.

Return type:

int

property cell: torch.Tensor

Return the canonical cell matrix of the graph.

Returns:

cell – The cell matrix of shape (3, 3).

Return type:

torch.Tensor

property frac: torch.Tensor

Return the fractional coordinates of the positions.

Returns:

frac – The fractional coordinates of the atoms.

Return type:

torch.Tensor

frac_to_pos(f: torch.Tensor) → torch.Tensor

Fractional -> Cartesian coordinates.

Convert fractional coordinates to Cartesian coordinates.

Parameters:

f (torch.Tensor) – The fractional coordinates.

Returns:

r – The Cartesian coordinates.

Return type:

torch.Tensor

pos_to_frac(r: torch.Tensor) → torch.Tensor

Cartesian -> Fractional coordinates.

Convert Cartesian coordinates to fractional coordinates.

Parameters:

r (torch.Tensor) – The Cartesian coordinates.

Returns:

f – The fractional coordinates.

Return type:

torch.Tensor
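The two conversions are inverse linear maps. The standalone sketch below takes the cell explicitly (the methods read it from the graph) and assumes rows of cell are lattice vectors, matching the row convention stated for vector_to_cell(); frac_to_cart and cart_to_frac are hypothetical names.

```python
import torch


def frac_to_cart(f: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    # Rows of `cell` are lattice vectors, so r = f @ cell.
    return f @ cell


def cart_to_frac(r: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    # Solve f @ cell = r for f; equivalent to r @ inv(cell) without forming the inverse.
    return torch.linalg.solve(cell.T, r.T).T
```

Using a linear solve rather than an explicit inverse is the usual choice for numerical stability with strongly skewed cells.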

property positions_mask: torch.Tensor

Return the mask of the positions that are fixed.

True for fixed atom positions, False otherwise.

Returns:

mask – The mask of the positions that are fixed.

Return type:

torch.Tensor

property time: torch.Tensor

Return the time of the graph.

Returns:

time – The time of the graph.

Return type:

torch.Tensor

property representation: Representation | None

Return the representation of the graph.

Returns:

representation – The representation of the graph, or None if not set.

Return type:

Optional[Representation]

wrap_positions() → None

Wrap the positions of the atoms into the unit cell.

Return type:

None
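Wrapping is usually done in fractional space: reduce each fractional coordinate modulo 1 and map back to Cartesian. The sketch below assumes that approach and, as an additional assumption, wraps only periodic directions; the documentation does not say whether wrap_positions() consults pbc. wrap_positions_sketch is a hypothetical name.

```python
import torch


def wrap_positions_sketch(pos: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor) -> torch.Tensor:
    # Fractional coordinates (rows of `cell` are lattice vectors).
    frac = torch.linalg.solve(cell.T, pos.T).T
    # Map periodic components into [0, 1); leave non-periodic ones untouched.
    wrapped = torch.where(pbc, torch.remainder(frac, 1.0), frac)
    return wrapped @ cell
```

torch.remainder follows the sign of the divisor, so negative fractional coordinates such as -0.25 map to 0.75 rather than staying negative.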

apply_mask(x: torch.Tensor, val: float = 0.0) → torch.Tensor

Apply the mask to the tensor x.

Parameters:
  • x (torch.Tensor) – The tensor to apply the mask to.

  • val (float) – The value to set the masked values to.

Returns:

x – The tensor with the mask applied.

Return type:

torch.Tensor
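A minimal sketch of the masking operation, assuming the mask is the per-atom positions_mask (True for fixed atoms) and that "masked values" are the rows of x belonging to fixed atoms; for example, this would zero an update so that fixed atoms do not move. apply_mask_sketch is hypothetical and takes the mask explicitly instead of reading it from the graph.

```python
import torch


def apply_mask_sketch(x: torch.Tensor, mask: torch.Tensor, val: float = 0.0) -> torch.Tensor:
    # `mask` is (n_atoms,), True for fixed atoms; reshape it to broadcast over
    # the trailing dimensions of `x`, then replace the fixed atoms' entries with `val`.
    mask = mask.reshape(-1, *([1] * (x.dim() - 1)))
    return torch.where(mask, torch.full_like(x, val), x)
```

torch.where returns a new tensor, so the input x is left unchanged.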

property confinement: torch.Tensor

Return the confinement of the graph.

Returns:

confinement – The confinement of the graph.

Return type:

torch.Tensor

property cellpar: torch.Tensor

Return the cell parameters of the graph.

static _is_lower_triangular(cell: torch.Tensor) → bool

Return True if cell is in canonical lower-triangular form.

A cell matrix is considered canonical when the three strictly upper-triangular entries (cell[0,1], cell[0,2], cell[1,2]) are all zero (within a tight floating-point tolerance of 1e-10).

Parameters:

cell (torch.Tensor) – The cell matrix.

Returns:

True if the cell is already lower-triangular.

Return type:

bool
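The check described above reduces to comparing three entries against the tolerance. A standalone sketch (is_lower_triangular is a hypothetical name, not the private method itself):

```python
import torch


def is_lower_triangular(cell: torch.Tensor, tol: float = 1e-10) -> bool:
    # Strictly upper-triangular entries: cell[0, 1], cell[0, 2], cell[1, 2].
    upper = torch.stack([cell[0, 1], cell[0, 2], cell[1, 2]])
    return bool((upper.abs() <= tol).all())
```

With float32 cells a tolerance of 1e-10 effectively demands exact zeros, which is what a cell built in lower-triangular form produces.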

static cell_to_vectors(cell: torch.Tensor) → torch.Tensor

Convert cell matrix to cell parameters.

Parameters:

cell (torch.Tensor) – The cell matrix of shape (N, 3) or (N, 3, 3).

Returns:

The cell parameters of shape (N, 6).

Return type:

torch.Tensor
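Cell parameters are conventionally the three lattice lengths followed by the angles alpha = ∠(b, c), beta = ∠(a, c) and gamma = ∠(a, b). The sketch below assumes that ordering and, following ASE's convention, angles in degrees; the documentation does not state the unit. cell_to_cellpar is a hypothetical name and accepts a single (3, 3) cell or an (N, 3, 3) stack.

```python
import torch


def cell_to_cellpar(cell: torch.Tensor) -> torch.Tensor:
    """(3, 3) or (N, 3, 3) cell -> (N, 6) [a, b, c, alpha, beta, gamma], angles in degrees."""
    cell = cell.reshape(-1, 3, 3)
    lengths = cell.norm(dim=-1)  # (N, 3): |a|, |b|, |c|
    a, b, c = cell[:, 0], cell[:, 1], cell[:, 2]

    def angle(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        cos = (u * v).sum(-1) / (u.norm(dim=-1) * v.norm(dim=-1))
        # Clamp guards acos against rounding slightly outside [-1, 1].
        return torch.rad2deg(torch.acos(cos.clamp(-1.0, 1.0)))

    angles = torch.stack([angle(b, c), angle(a, c), angle(a, b)], dim=-1)
    return torch.cat([lengths, angles], dim=-1)
```

A hexagonal cell with a = b = 1 and c = 2, for example, maps to [1, 1, 2, 90, 90, 120].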

static vector_to_cell(cellpar: torch.Tensor) → torch.Tensor

Convert cell parameters to cell matrix.

Parameters:

cellpar (torch.Tensor) – The cell parameters of shape (N, 6).

Returns:

The cell matrix of shape (N, 3, 3) where each row is a lattice vector.

Return type:

torch.Tensor
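The inverse mapping can be sketched with the usual construction: a along x, b in the xy-plane, and c fixed by the remaining angles, which yields a lower-triangular cell whose rows are the lattice vectors. The degree unit and this particular orientation are assumptions, and cellpar_to_cell is a hypothetical name.

```python
import torch


def cellpar_to_cell(cellpar: torch.Tensor) -> torch.Tensor:
    """(N, 6) [a, b, c, alpha, beta, gamma] in degrees -> (N, 3, 3) lower-triangular cell."""
    a, b, c = cellpar[:, 0], cellpar[:, 1], cellpar[:, 2]
    alpha, beta, gamma = torch.deg2rad(cellpar[:, 3:]).unbind(dim=-1)
    cos_a, cos_b, cos_g = torch.cos(alpha), torch.cos(beta), torch.cos(gamma)
    sin_g = torch.sin(gamma)
    zeros = torch.zeros_like(a)
    # c is placed so that a.c = a*c*cos(beta) and b.c = b*c*cos(alpha).
    cx = c * cos_b
    cy = c * (cos_a - cos_b * cos_g) / sin_g
    cz = torch.sqrt(torch.clamp(c**2 - cx**2 - cy**2, min=0.0))
    rows = [
        torch.stack([a, zeros, zeros], dim=-1),
        torch.stack([b * cos_g, b * sin_g, zeros], dim=-1),
        torch.stack([cx, cy, cz], dim=-1),
    ]
    return torch.stack(rows, dim=1)  # (N, 3, 3), one lattice vector per row
```

The clamp inside the square root only guards against tiny negative values from rounding; physically invalid angle combinations would still need to be rejected separately.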