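agedi.data.atoms_graph
Attributes
- NVIDIA_NEIGHBOR_IMPORT_ERROR
- NVIDIA_CELL_LIST_IMPORT_ERROR
- NEIGHBOR_CACHE_KEYS
Classes
- Representation – Representation class
- AtomsGraph – Atomistic Graph Class
Functions
- batched – Batched decorator
Module Contents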
- agedi.data.atoms_graph.NVIDIA_NEIGHBOR_IMPORT_ERROR = None
- agedi.data.atoms_graph.NVIDIA_CELL_LIST_IMPORT_ERROR = None
- agedi.data.atoms_graph.NEIGHBOR_CACHE_KEYS = ('edge_index', 'shift_vectors')
- agedi.data.atoms_graph.batched(update_keys: Sequence[str] | None = None, return_batch: bool = False) -> Callable
Batched decorator
Decorator for functions that return Data objects, allowing them to also be called with batched inputs. The function is called once for each element in the batch, and the results are concatenated into a single Data object.
If called with a single Data object as input, the function behaves as if it were not decorated.
- Parameters:
update_keys (Optional[Sequence[str]]) – The keys in the Batch object that should be updated. If None, no keys will be updated.
return_batch (bool) – If True, the function will return a Batch object instead of None.
- Return type:
Callable
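The dispatch logic described above can be sketched without any dependencies. ToyData and ToyBatch are hypothetical stand-ins for torch_geometric's Data and Batch, and they keep the graphs in a plain list instead of concatenating tensors. The sketch also assumes that, when return_batch is False, the values for update_keys are written back into the batch in place. That behavior is inferred from the parameter descriptions and does not reproduce the agedi implementation.

```python
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Dict, List, Optional, Sequence


@dataclass
class ToyData:
    """Hypothetical stand-in for a single torch_geometric Data object."""
    attrs: Dict[str, list] = field(default_factory=dict)


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a Batch: a plain list of graphs."""
    graphs: List[ToyData]

    def to_data_list(self) -> List[ToyData]:
        return list(self.graphs)


def batched(update_keys: Optional[Sequence[str]] = None,
            return_batch: bool = False) -> Callable:
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(data, *args, **kwargs):
            if not isinstance(data, ToyBatch):
                # Single graph: behave exactly like the undecorated function.
                return fn(data, *args, **kwargs)
            # Batched input: call fn once per graph.
            results = [fn(g, *args, **kwargs) for g in data.to_data_list()]
            if return_batch:
                return ToyBatch(results)
            # Otherwise write the requested keys back into the batch in place.
            for key in update_keys or ():
                for graph, result in zip(data.graphs, results):
                    graph.attrs[key] = result.attrs[key]
            return None
        return wrapper
    return decorator


@batched(update_keys=["pos"])
def shift(data: ToyData, dx: float) -> ToyData:
    return ToyData({**data.attrs, "pos": [p + dx for p in data.attrs["pos"]]})


batch = ToyBatch([ToyData({"pos": [0.0, 1.0]}), ToyData({"pos": [2.0]})])
shift(batch, 1.0)                              # updates batch in place, returns None
single = shift(ToyData({"pos": [5.0]}), 1.0)   # plain call, as if undecorated
```

Because the wrapper checks the input type at call time, the same decorated function serves both the single-graph and the batched code paths.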
- class agedi.data.atoms_graph.Representation
Representation class
A simple container holding the scalar (l=0) and vector (l=1) equivariant representations produced by the backbone network. Both fields are optional so that the class can also be used for partial representations.
Registered as a torch.utils._pytree node so that torch.compile can traverse instances transparently without introducing graph breaks.
- Parameters:
scalar (Optional[torch.Tensor]) – Per-node scalar features of shape (n_nodes, n_features, 1). Default is None.
vector (Optional[torch.Tensor]) – Per-node vector features of shape (n_nodes, n_features, 3). Default is None.
- scalar: torch.Tensor | None = None
- vector: torch.Tensor | None = None
- to_tensor(n_graphs: int) -> Tuple[torch.Tensor, torch.Tensor]
Serialise scalar and vector tensors into a single flat representation.
Concatenates scalar and vector (when present) along the feature dimension. Returns the concatenated tensor together with per-graph slice boundaries and degree values so that from_tensor() can reconstruct the original fields.
- Parameters:
n_graphs (int) – The number of graphs in the batch. The slice and degree tensors are repeated once per graph so they can be stored as graph-level attributes.
- Returns:
tensor (torch.Tensor) – Concatenated representation of shape (n_nodes, total_features).
slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).
ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).
- classmethod from_tensor(tensor: torch.Tensor, slices: torch.Tensor, ls: torch.Tensor) -> Representation
Reconstruct a Representation from a flat serialised form.
- Parameters:
tensor (torch.Tensor) – Flat representation of shape (n_nodes, total_features).
slices (torch.Tensor) – Cumulative slice boundaries of shape (n_graphs, n_parts + 1).
ls (torch.Tensor) – Degree values of shape (n_graphs, n_parts).
- Return type:
Representation
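The round trip between to_tensor() and from_tensor() can be illustrated in NumPy. This is a minimal sketch under several assumptions. The hypothetical helpers to_flat and from_flat stand in for the two methods. Each degree-l part is flattened from (n_nodes, n_features, 2l + 1) to (n_nodes, n_features * (2l + 1)), scalars come before vectors, and at least one field is present. The real methods operate on torch tensors and may order or flatten the parts differently.

```python
from typing import Optional, Tuple

import numpy as np


def to_flat(scalar: Optional[np.ndarray], vector: Optional[np.ndarray],
            n_graphs: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical NumPy analogue of Representation.to_tensor."""
    parts, degrees = [], []
    for l, feats in ((0, scalar), (1, vector)):
        if feats is not None:
            # (n_nodes, n_features, 2l + 1) -> (n_nodes, n_features * (2l + 1))
            parts.append(feats.reshape(feats.shape[0], -1))
            degrees.append(l)
    bounds = np.concatenate([[0], np.cumsum([p.shape[1] for p in parts])])
    # Repeat the metadata once per graph so it can live as a graph-level attribute.
    slices = np.tile(bounds, (n_graphs, 1))
    ls = np.tile(np.array(degrees), (n_graphs, 1))
    return np.concatenate(parts, axis=1), slices, ls


def from_flat(tensor: np.ndarray, slices: np.ndarray, ls: np.ndarray):
    """Hypothetical inverse of to_flat; returns (scalar, vector)."""
    fields = {0: None, 1: None}
    bounds, degrees = slices[0], ls[0]  # every row holds the same metadata
    for k, l in enumerate(degrees):
        chunk = tensor[:, bounds[k]:bounds[k + 1]]
        fields[int(l)] = chunk.reshape(tensor.shape[0], -1, 2 * int(l) + 1)
    return fields[0], fields[1]


rng = np.random.default_rng(0)
scalar = rng.normal(size=(5, 4, 1))
vector = rng.normal(size=(5, 4, 3))
tensor, slices, ls = to_flat(scalar, vector, n_graphs=2)
s, v = from_flat(tensor, slices, ls)
```

With 4 scalar and 4 vector features per node, the flat tensor has 4 + 4 * 3 = 16 columns, and each row of slices is [0, 4, 16].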
- class agedi.data.atoms_graph.AtomsGraph
Bases: torch_geometric.data.Data
Atomistic Graph Class
Class defining a graph with atoms as nodes and edges formed between all atoms within a finite cutoff radius.
- Parameters:
pos (torch.Tensor) – The positions of the atoms with shape (n_atoms, 3).
x (torch.Tensor) – The node features, i.e. the atomic types of the graph, with shape (n_nodes, 1).
edge_index (torch.Tensor) – The edge index tensor of the graph with shape (2, n_edges).
edge_attr (torch.Tensor) – The edge attributes of the graph with shape (n_edges, n_edge_features).
y (Optional[torch.Tensor]) – The target tensor of the graph with shape (n_targets,).
representation (Optional[Representation]) – The representation of the atoms in the graph.
confinement (Optional[torch.Tensor]) – z-directional confinement of the atoms with shape (1, 2).
kwargs (Dict[str, torch.Tensor])
- classmethod from_atoms(atoms: ase.Atoms, cutoff: float = 6.0, dtype: torch.dtype = torch.float, initialize_mask: bool | None = None, confinement: Tuple[float, float] | None = None, canonical_cell: bool = False) -> AtomsGraph
Create a graph from an ASE Atoms object.
- Parameters:
atoms (Atoms) – The ASE Atoms object.
cutoff (float) – The cutoff radius for the edges.
dtype (torch.dtype) – The data type of the tensors.
initialize_mask (Optional[bool]) – Whether to initialize the mask tensor. When None (the default), the mask is initialized only when confinement is not provided (i.e. initialize_mask defaults to False for template / confinement graphs).
confinement (Optional[Tuple[float, float]]) – Optional z-directional confinement bounds (z_min, z_max) to attach to the graph. When provided, a confinement tensor of shape (1, 2) is stored on the graph. When None (the default), no confinement attribute is added.
canonical_cell (bool) – When True, the cell is stored in canonical lower-triangular form. If the input cell is not already canonical, Cartesian positions are recomputed to preserve fractional coordinates and a warning is emitted. Set to False (the default) to store the cell exactly as provided by ASE (no rotation or recomputation is performed).
- Returns:
graph – The graph object.
- Return type:
AtomsGraph
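One way to bring a cell into lower-triangular form while preserving fractional coordinates is a QR decomposition of the transposed cell, followed by rotating the cell and the positions by the same orthogonal matrix. The sketch below uses a hypothetical canonicalize helper. It assumes that the rows of the cell are lattice vectors and that the cell is right-handed. It does not show the warning, and agedi may use a different construction.

```python
import numpy as np


def canonicalize(cell: np.ndarray, positions: np.ndarray):
    """Hypothetical helper: rotate cell and positions so the cell is lower triangular."""
    # cell.T = Q R  implies  cell @ Q = R.T, which is lower triangular.
    q, r = np.linalg.qr(cell.T)
    # Flip column signs of Q so the canonical cell has a positive diagonal.
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0
    q = q * signs
    # Rotating positions by the same Q leaves fractional coordinates unchanged.
    new_cell = np.tril(cell @ q)  # tril removes round-off above the diagonal
    return new_cell, positions @ q


cell = np.array([[3.0, 1.0, 0.5], [0.2, 4.0, 0.3], [0.1, 0.4, 5.0]])
positions = np.array([[0.5, 0.5, 0.5], [1.0, 2.0, 3.0]])
new_cell, new_pos = canonicalize(cell, positions)
```

Since Q is orthogonal, the lattice vector lengths and angles are unchanged. Only the orientation of the cell in Cartesian space differs.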
- classmethod empty(cutoff: float = 6.0) -> AtomsGraph
Create an empty graph.
- Parameters:
cutoff (float) – The cutoff radius for the edges.
- Returns:
graph – The graph object.
- Return type:
AtomsGraph
- add_batch_attr(key: str, value: torch.Tensor, type: str = 'node') -> None
Add a batch attribute to the graph.
- Parameters:
key (str) – The key of the attribute.
value (torch.Tensor) – The value of the attribute.
type (str) – The type of the attribute. Can be either “node” or “graph”.
- Return type:
None
- to_atoms() -> ase.Atoms
Convert the graph to an ASE Atoms object.
Only works on unbatched graphs.
- Returns:
atoms – The atoms object.
- Return type:
ase.Atoms
- _get_scalar_attr(key: str) -> float | None
- prepare_for_compile(cutoff: float) -> None
Pre-allocate neighbor-list buffers for torch.compile compatibility.
Estimates the maximum number of neighbors per atom using estimate_max_neighbors() and the cell-list dimensions using estimate_cell_list_sizes(), then allocates the cell list and all output buffers with fixed shapes. Fixed shapes are required for torch.compile to trace the reverse diffusion step once without retracing on subsequent iterations.
Must be called on a Batch before the first update_graph() call.
Requires the nvalchemiops package.
- Parameters:
cutoff (float) – Neighbor-list cutoff radius (Å).
- Raises:
RuntimeError – When nvalchemiops is not installed.
TypeError – When called on an unbatched AtomsGraph instead of a Batch.
- static _cell_list_to_graph(neighbor_matrix: torch.Tensor, neighbor_shifts: torch.Tensor, cell: torch.Tensor, dtype: torch.dtype, batch_idx: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Convert cell-list query output to (edge_index, shift_vectors).
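Cell-list queries of this kind typically return a padded neighbor matrix of shape (n_atoms, max_neighbors) together with integer periodic-image offsets. The NumPy sketch below converts such output into COO edges. It makes the following assumptions:
- empty slots are marked with -1;
- neighbor_shifts holds integer lattice offsets of shape (n_atoms, max_neighbors, 3);
- shift vectors are Cartesian, computed as offsets times cell rows;
- the row atom is listed first in edge_index.

The fill value, the edge direction, and the helper name cell_list_to_graph_sketch are all assumptions of this illustration.

```python
import numpy as np


def cell_list_to_graph_sketch(neighbor_matrix, neighbor_shifts, cell, fill_value=-1):
    """Hypothetical conversion of a padded neighbor matrix to (edge_index, shift_vectors)."""
    src, slot = np.nonzero(neighbor_matrix != fill_value)  # skip padding slots
    dst = neighbor_matrix[src, slot]
    edge_index = np.stack([src, dst])                       # shape (2, n_edges)
    # Integer lattice offsets -> Cartesian shifts (rows of cell are lattice vectors).
    shift_vectors = neighbor_shifts[src, slot] @ cell
    return edge_index, shift_vectors


# Three atoms in a 5 Å cubic box; atom 0 sees atom 2 through the +a image.
neighbor_matrix = np.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]])
neighbor_shifts = np.zeros((3, 3, 3), dtype=np.int64)
neighbor_shifts[0, 1] = [1, 0, 0]
neighbor_shifts[2, 0] = [-1, 0, 0]
edge_index, shift_vectors = cell_list_to_graph_sketch(
    neighbor_matrix, neighbor_shifts, 5.0 * np.eye(3))
```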
- update_graph() -> bool
Update the graph with new edges.
This should be called after changing any of the positions or the cell.
- Returns:
rebuilt – True when the neighbor list was fully recomputed.
- Return type:
bool
- static _make_graph_matscipy(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype | None = None, batch_idx: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
- static make_graph(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype = None, batch_idx: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Create the graph-edges from the positions and cell.
- Parameters:
positions (torch.Tensor) – The positions of the atoms.
cell (torch.Tensor) – The cell of the system.
cutoff (float) – The cutoff radius for the edges.
pbc (torch.Tensor) – The periodic boundary conditions.
dtype (torch.dtype) – The data type of the output.
- Returns:
edge_index (torch.Tensor) – The edge index tensor.
shift_vectors (torch.Tensor) – The shift vectors tensor.
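For small systems, a brute-force loop over periodic images gives a reference result, which could be used to check faster neighbor lists such as the matscipy or cell-list paths. The hypothetical make_graph_bruteforce below is written in NumPy. It assumes a non-singular cell with lattice vectors as rows, and it assumes that the vector from atom i to atom j is positions[j] + shift - positions[i]. agedi's own shift convention may differ.

```python
import itertools

import numpy as np


def make_graph_bruteforce(positions, cell, cutoff, pbc):
    """Hypothetical O(N^2 * n_images) edge construction for small systems."""
    # 1 / |column k of inv(cell)| is the spacing between lattice planes k,
    # so this many images per side are enough to cover the cutoff sphere.
    plane_inv = np.linalg.norm(np.linalg.inv(cell), axis=0)
    reps = [int(np.ceil(cutoff * plane_inv[k])) if pbc[k] else 0 for k in range(3)]
    src, dst, shifts = [], [], []
    for image in itertools.product(*(range(-r, r + 1) for r in reps)):
        shift = np.array(image, dtype=float) @ cell  # Cartesian image offset
        # diff[i, j] = positions[j] + shift - positions[i]
        diff = positions[None, :, :] + shift - positions[:, None, :]
        within = np.linalg.norm(diff, axis=-1) < cutoff
        if not any(image):
            np.fill_diagonal(within, False)  # no self-loops inside the home cell
        i, j = np.nonzero(within)
        src.append(i)
        dst.append(j)
        shifts.append(np.tile(shift, (len(i), 1)))
    edge_index = np.stack([np.concatenate(src), np.concatenate(dst)])
    return edge_index, np.concatenate(shifts)


# One atom in a 3 Å cubic box: its six face-sharing images lie at 3 Å.
ei, sv = make_graph_bruteforce(np.zeros((1, 3)), 3.0 * np.eye(3), 3.5,
                               np.array([True, True, True]))
# Two atoms 1 Å apart without periodicity: a single bidirectional pair.
ei2, sv2 = make_graph_bruteforce(np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
                                 10.0 * np.eye(3), 2.0, np.array([False, False, False]))
```

This approach scales poorly, both quadratically in the number of atoms and with the number of images, which is why production code relies on dedicated neighbor-list libraries.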
- clear_graph() -> None
Clear the graph, removing all edges.
- Return type:
None
- __len__() -> int
Return the number of atoms in the graph.
- Returns:
n_atoms – The number of atoms in the graph.
- Return type:
int
- property cell: torch.Tensor
Return the canonical cell matrix of the graph.
- Returns:
cell – The cell matrix of shape (3, 3).
- Return type:
torch.Tensor
- property frac: torch.Tensor
Return the fractional coordinates of the positions.
- Returns:
frac – The fractional coordinates of the atoms.
- Return type:
torch.Tensor
- frac_to_pos(f: torch.Tensor) -> torch.Tensor
Fractional -> Cartesian coordinates.
Convert fractional coordinates to Cartesian coordinates.
- Parameters:
f (torch.Tensor) – The fractional coordinates.
- Returns:
r – The Cartesian coordinates.
- Return type:
torch.Tensor
- pos_to_frac(r: torch.Tensor) -> torch.Tensor
Cartesian -> Fractional coordinates.
Convert Cartesian coordinates to fractional coordinates.
- Parameters:
r (torch.Tensor) – The Cartesian coordinates.
- Returns:
f – The fractional coordinates.
- Return type:
torch.Tensor
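When lattice vectors are stored as rows of the cell, the two conversions reduce to a matrix product and a linear solve. In the NumPy sketch below, the function names mirror the methods, but each function takes the cell as an explicit argument. That signature is an assumption of this illustration.

```python
import numpy as np


def frac_to_pos(f: np.ndarray, cell: np.ndarray) -> np.ndarray:
    # Rows of cell are lattice vectors, so r = f @ cell.
    return f @ cell


def pos_to_frac(r: np.ndarray, cell: np.ndarray) -> np.ndarray:
    # Solve f @ cell = r for f (equivalently cell.T @ f.T = r.T) without forming inv(cell).
    return np.linalg.solve(cell.T, r.T).T


cell = np.array([[4.0, 0.0, 0.0], [2.0, 2.0 * np.sqrt(3.0), 0.0], [0.0, 0.0, 6.0]])
f = np.array([[0.5, 0.5, 0.5], [0.25, 0.0, 1.0]])
r = frac_to_pos(f, cell)
```

Using a linear solve instead of an explicit matrix inverse is the usual numerically preferable choice.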
- property positions_mask: torch.Tensor
Return the mask of the positions that are fixed.
True for fixed atom positions, False otherwise.
- Returns:
mask – The mask of the positions that are fixed.
- Return type:
torch.Tensor
- property time: torch.Tensor
Return the time of the graph.
- Returns:
time – The time of the graph.
- Return type:
torch.Tensor
- property representation: Representation | None
Return the representation of the graph.
- Returns:
representation – The representation of the graph, or None if not set.
- Return type:
Optional[Representation]
- wrap_positions() -> None
Wrap the positions of the atoms to the unit cell.
- Return type:
None
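Wrapping is commonly done in fractional space: each coordinate is reduced modulo 1 and then converted back to Cartesian coordinates. The NumPy sketch below assumes that only periodic axes are wrapped. The documentation above does not say whether agedi also wraps non-periodic directions, and the explicit cell and pbc arguments are specific to this sketch.

```python
import numpy as np


def wrap_positions(positions, cell, pbc):
    frac = np.linalg.solve(cell.T, positions.T).T   # Cartesian -> fractional
    # Reduce into [0, 1) only along periodic axes (assumption for this sketch).
    frac = np.where(pbc, frac % 1.0, frac)
    return frac @ cell                               # back to Cartesian


wrapped = wrap_positions(np.array([[6.0, -1.0, 7.0]]), 5.0 * np.eye(3),
                         np.array([True, True, False]))
```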
- apply_mask(x: torch.Tensor, val: float = 0.0) -> torch.Tensor
Apply the mask to the tensor x.
- Parameters:
x (torch.Tensor) – The tensor to apply the mask to.
val (float) – The value to set the masked values to.
- Returns:
x – The tensor with the mask applied.
- Return type:
torch.Tensor
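A NumPy sketch of masking follows, for example to prevent fixed atoms from being updated. It rests on two assumptions about apply_mask that the documentation does not state. First, the entries where positions_mask is True (the fixed atoms) are the ones overwritten with val. Second, a per-atom mask is broadcast over the trailing feature dimensions.

```python
import numpy as np


def apply_mask(x: np.ndarray, mask: np.ndarray, val: float = 0.0) -> np.ndarray:
    # Assumption: True in `mask` marks a fixed atom whose entries are replaced by `val`.
    # Append singleton axes so a per-atom mask broadcasts over feature dimensions.
    m = mask.reshape(mask.shape + (1,) * (x.ndim - mask.ndim))
    return np.where(m, val, x)


noise = np.ones((3, 3))
fixed = np.array([False, True, False])   # atom 1 is held fixed
masked = apply_mask(noise, fixed)
```

The sketch returns a new array and leaves the input unchanged.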
- property confinement: torch.Tensor
Return the confinement of the graph.
- Returns:
confinement – The confinement of the graph.
- Return type:
torch.Tensor
- property cellpar: torch.Tensor
Return the cell parameters of the graph.
- static _is_lower_triangular(cell: torch.Tensor) -> bool
Return True if cell is in canonical lower-triangular form.
A cell matrix is considered canonical when the three strictly upper-triangular entries (cell[0,1], cell[0,2], cell[1,2]) are all zero (within a tight floating-point tolerance of 1e-10).
- Parameters:
cell (torch.Tensor) – The cell matrix.
- Returns:
True if the cell is already lower-triangular.
- Return type:
bool
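The tolerance check described above only needs to inspect three matrix entries. A NumPy sketch follows. Whether an entry with absolute value exactly 1e-10 counts as zero is an assumption.

```python
import numpy as np


def is_lower_triangular(cell: np.ndarray, tol: float = 1e-10) -> bool:
    # Strictly upper-triangular entries: cell[0, 1], cell[0, 2], cell[1, 2].
    upper = cell[np.triu_indices(3, k=1)]
    return bool(np.all(np.abs(upper) <= tol))


canonical = np.array([[3.0, 0.0, 0.0], [1.0, 3.0, 0.0], [0.5, 0.5, 3.0]])
rotated = canonical.copy()
rotated[0, 1] = 1e-6          # a small but non-negligible upper entry
```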
- static cell_to_vectors(cell: torch.Tensor) -> torch.Tensor
Convert cell matrix to cell parameters.
- Parameters:
cell (torch.Tensor) – The cell matrix of shape (N, 3) or (N, 3, 3).
- Returns:
The cell parameters of shape (N, 6).
- Return type:
torch.Tensor
- static vector_to_cell(cellpar: torch.Tensor) -> torch.Tensor
Convert cell parameters to cell matrix.
- Parameters:
cellpar (torch.Tensor) – The cell parameters of shape (N, 6).
- Returns:
The cell matrix of shape (N, 3, 3), where each row is a lattice vector.
- Return type:
torch.Tensor
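Both conversions are sketched below in NumPy, assuming the ASE convention for cell parameters (a, b, c, alpha, beta, gamma). Under that convention, angles are in degrees, alpha is the angle between b and c, beta is the angle between a and c, and gamma is the angle between a and b. cellpar_to_cell builds the lower-triangular form with a along x and b in the xy plane. The angle unit and the exact construction used by agedi are assumptions, and the sketch accepts only batched (N, 3, 3) cells.

```python
import numpy as np


def cell_to_cellpar(cell: np.ndarray) -> np.ndarray:
    """(N, 3, 3) -> (N, 6): lengths a, b, c and angles alpha, beta, gamma in degrees."""
    a, b, c = cell[:, 0], cell[:, 1], cell[:, 2]
    la, lb, lc = (np.linalg.norm(v, axis=-1) for v in (a, b, c))

    def angle(u, v, lu, lv):
        cos = np.sum(u * v, axis=-1) / (lu * lv)
        return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

    return np.stack([la, lb, lc, angle(b, c, lb, lc), angle(a, c, la, lc),
                     angle(a, b, la, lb)], axis=-1)


def cellpar_to_cell(cellpar: np.ndarray) -> np.ndarray:
    """(N, 6) -> (N, 3, 3) lower-triangular cell with rows a, b, c."""
    la, lb, lc = cellpar[:, 0], cellpar[:, 1], cellpar[:, 2]
    alpha, beta, gamma = np.radians(cellpar[:, 3:6]).T
    cell = np.zeros((len(cellpar), 3, 3))
    cell[:, 0, 0] = la                                  # a along x
    cell[:, 1, 0] = lb * np.cos(gamma)                  # b in the xy plane
    cell[:, 1, 1] = lb * np.sin(gamma)
    cx = lc * np.cos(beta)                              # from c . a = |a||c| cos(beta)
    cy = lc * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma)
    cell[:, 2, 0] = cx
    cell[:, 2, 1] = cy
    cell[:, 2, 2] = np.sqrt(lc ** 2 - cx ** 2 - cy ** 2)  # fixes |c|
    return cell


par = np.array([[4.0, 5.0, 6.0, 80.0, 95.0, 110.0]])
cell = cellpar_to_cell(par)
```

Converting a general cell to parameters and back yields the lower-triangular form of the same lattice, so the composition of the two functions also canonicalizes a cell.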