agedi.data.atoms_graph
======================

.. py:module:: agedi.data.atoms_graph


Attributes
----------

.. autoapisummary::

   agedi.data.atoms_graph.NVIDIA_NEIGHBOR_IMPORT_ERROR
   agedi.data.atoms_graph.NVIDIA_CELL_LIST_IMPORT_ERROR
   agedi.data.atoms_graph.NEIGHBOR_CACHE_KEYS


Classes
-------

.. autoapisummary::

   agedi.data.atoms_graph.Representation
   agedi.data.atoms_graph.AtomsGraph


Functions
---------

.. autoapisummary::

   agedi.data.atoms_graph.batched


Module Contents
---------------

.. py:data:: NVIDIA_NEIGHBOR_IMPORT_ERROR
   :value: None


.. py:data:: NVIDIA_CELL_LIST_IMPORT_ERROR
   :value: None


.. py:data:: NEIGHBOR_CACHE_KEYS
   :value: ('edge_index', 'shift_vectors')


.. py:function:: batched(update_keys: Optional[Sequence[str]] = None, return_batch: bool = False) -> Callable

   Batched decorator

   Decorator for functions that return Data objects, allowing them to be called with batched inputs. The function is called for each element in the batch, and the results are concatenated into a single Data object. If called with a single Data object as input, the function is called as if it were not decorated.

   :param update_keys: The keys in the Batch object that should be updated. If None, no keys will be updated.
   :type update_keys: Optional[Sequence[str]]
   :param return_batch: If True, the function will return a Batch object instead of None.
   :type return_batch: bool

   :rtype: Callable


.. py:class:: Representation

   Representation class

   A simple container holding the scalar (l=0) and vector (l=1) equivariant representations produced by the backbone network. Both fields are optional so that the class can also be used for partial representations.

   Registered as a ``torch.utils._pytree`` node so that ``torch.compile`` can traverse instances transparently without introducing graph breaks.

   :param scalar: Per-node scalar features of shape ``(n_nodes, n_features, 1)``. Default is ``None``.
   :type scalar: Optional[torch.Tensor]
   :param vector: Per-node vector features of shape ``(n_nodes, n_features, 3)``. Default is ``None``.
   :type vector: Optional[torch.Tensor]


   .. py:attribute:: scalar
      :type: Optional[torch.Tensor]
      :value: None


   .. py:attribute:: vector
      :type: Optional[torch.Tensor]
      :value: None


   .. py:method:: to_tensor(n_graphs: int) -> Tuple[torch.Tensor, torch.Tensor]

      Serialise scalar and vector tensors into a single flat representation.

      Concatenates ``scalar`` and ``vector`` (when present) along the feature dimension. Returns the concatenated tensor together with per-graph slice boundaries and degree values so that :meth:`from_tensor` can reconstruct the original fields.

      :param n_graphs: The number of graphs in the batch. The slice and degree tensors are repeated once per graph so they can be stored as graph-level attributes.
      :type n_graphs: int

      :returns: * **tensor** (*torch.Tensor*) -- Concatenated representation of shape ``(n_nodes, total_features)``.
                * **slices** (*torch.Tensor*) -- Cumulative slice boundaries of shape ``(n_graphs, n_parts + 1)``.
                * **ls** (*torch.Tensor*) -- Degree values of shape ``(n_graphs, n_parts)``.


   .. py:method:: from_tensor(tensor: torch.Tensor, slices: torch.Tensor, ls: torch.Tensor) -> Representation
      :classmethod:


      Reconstruct a :class:`Representation` from a flat serialised form.

      :param tensor: Flat representation of shape ``(n_nodes, total_features)``.
      :type tensor: torch.Tensor
      :param slices: Cumulative slice boundaries of shape ``(n_graphs, n_parts + 1)``.
      :type slices: torch.Tensor
      :param ls: Degree values of shape ``(n_graphs, n_parts)``.
      :type ls: torch.Tensor

      :rtype: Representation


.. py:class:: AtomsGraph

   Bases: :py:obj:`torch_geometric.data.Data`


   Atomistic Graph Class

   Class defining a graph with atoms as nodes and edges formed between all atoms within a finite cutoff radius.

   :param pos: The positions of the atoms with shape (n_atoms, 3).
   :type pos: torch.Tensor
   :param x: The node features, i.e. the atomic types, of the graph with shape (n_nodes, 1).
   :type x: torch.Tensor
   :param edge_index: The edge index tensor of the graph with shape (2, n_edges).
   :type edge_index: torch.Tensor
   :param edge_attr: The edge attributes of the graph with shape (n_edges, n_edge_features).
   :type edge_attr: torch.Tensor
   :param y: The target tensor of the graph with shape (n_targets,).
   :type y: Optional[torch.Tensor]
   :param representation: The representation of the atoms in the graph.
   :type representation: Optional[Representation]
   :param confinement: z-directional confinement of the atoms with shape (1, 2).
   :type confinement: Optional[torch.Tensor]
   :param kwargs:
   :type kwargs: Dict[str, torch.Tensor]


   .. py:method:: from_atoms(atoms: ase.Atoms, cutoff: float = 6.0, dtype: torch.dtype = torch.float, initialize_mask: Optional[bool] = None, confinement: Optional[Tuple[float, float]] = None, canonical_cell: bool = False) -> AtomsGraph
      :classmethod:


      Create a graph from an ASE Atoms object.

      :param atoms: The ASE Atoms object.
      :type atoms: Atoms
      :param cutoff: The cutoff radius for the edges.
      :type cutoff: float
      :param dtype: The data type of the tensors.
      :type dtype: torch.dtype
      :param initialize_mask: Whether to initialize the mask tensor. When ``None`` (the default), the mask is initialised only when ``confinement`` is not provided (i.e. ``initialize_mask`` defaults to ``False`` for template / confinement graphs).
      :type initialize_mask: Optional[bool]
      :param confinement: Optional z-directional confinement bounds ``(z_min, z_max)`` to attach to the graph. When provided, a ``confinement`` tensor of shape ``(1, 2)`` is stored on the graph. When ``None`` (the default), no confinement attribute is added.
      :type confinement: Optional[Tuple[float, float]]
      :param canonical_cell: When ``True``, the cell is stored in canonical lower-triangular form.
                             If the input cell is not already canonical, Cartesian positions are recomputed to preserve fractional coordinates and a warning is issued. Set to ``False`` (the default) to store the cell exactly as provided by ASE (no rotation or recomputation is performed).
      :type canonical_cell: bool

      :returns: **graph** -- The graph object.
      :rtype: AtomsGraph


   .. py:method:: empty(cutoff: float = 6.0) -> AtomsGraph
      :classmethod:


      Create an empty graph.

      :param cutoff: The cutoff radius for the edges.
      :type cutoff: float

      :returns: **graph** -- The graph object.
      :rtype: AtomsGraph


   .. py:method:: add_batch_attr(key: str, value: torch.Tensor, type: str = 'node') -> None

      Add a batch attribute to the graph.

      :param key: The key of the attribute.
      :type key: str
      :param value: The value of the attribute.
      :type value: torch.Tensor
      :param type: The type of the attribute. Can be either ``"node"`` or ``"graph"``.
      :type type: str

      :rtype: None


   .. py:method:: to_atoms() -> ase.Atoms

      Convert the graph to an ASE Atoms object.

      Only works on unbatched graphs.

      :returns: **atoms** -- The atoms object.
      :rtype: ase.Atoms


   .. py:method:: _get_scalar_attr(key: str) -> Optional[float]


   .. py:method:: prepare_for_compile(cutoff: float) -> None

      Pre-allocate neighbor-list buffers for ``torch.compile`` compatibility.

      Estimates the maximum number of neighbors per atom using :func:`~nvalchemiops.torch.neighbors.neighbor_utils.estimate_max_neighbors` and the cell-list dimensions using :func:`~nvalchemiops.torch.neighbors.cell_list.estimate_cell_list_sizes`, then allocates the cell list and all output buffers with fixed shapes. Fixed shapes are required for ``torch.compile`` to trace the reverse diffusion step once without retracing on subsequent iterations.

      Must be called on a :class:`~torch_geometric.data.Batch` **before** the first :meth:`update_graph` call. Requires the ``nvalchemiops`` package.

      :param cutoff: Neighbor-list cutoff radius (Å).
      :type cutoff: float

      :raises RuntimeError: When ``nvalchemiops`` is not installed.
      :raises TypeError: When called on an unbatched :class:`AtomsGraph` instead of a :class:`~torch_geometric.data.Batch`.


   .. py:method:: _cell_list_to_graph(neighbor_matrix: torch.Tensor, neighbor_shifts: torch.Tensor, cell: torch.Tensor, dtype: torch.dtype, batch_idx: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]
      :staticmethod:


      Convert cell-list query output to ``(edge_index, shift_vectors)``.


   .. py:method:: update_graph() -> bool

      Update the graph with new edges.

      This should be called after changing the positions or the cell.

      :returns: **rebuilt** -- ``True`` when the neighbor list was fully recomputed.
      :rtype: bool


   .. py:method:: _make_graph_matscipy(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: Optional[torch.dtype] = None, batch_idx: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]
      :staticmethod:



   .. py:method:: make_graph(positions: torch.Tensor, cell: torch.Tensor, cutoff: float, pbc: torch.Tensor, dtype: torch.dtype = None, batch_idx: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]
      :staticmethod:


      Create the graph edges from the positions and cell.

      :param positions: The positions of the atoms.
      :type positions: torch.Tensor
      :param cell: The cell of the system.
      :type cell: torch.Tensor
      :param cutoff: The cutoff radius for the edges.
      :type cutoff: float
      :param pbc: The periodic boundary conditions.
      :type pbc: torch.Tensor
      :param dtype: The data type of the output.
      :type dtype: torch.dtype

      :returns: * **edge_index** (*torch.Tensor*) -- The edge index tensor.
                * **shift_vectors** (*torch.Tensor*) -- The shift vectors tensor.


   .. py:method:: clear_graph() -> None

      Clear the graph, removing all edges.

      :rtype: None


   .. py:method:: __len__() -> int

      Return the number of atoms in the graph.

      :returns: **n_atoms** -- The number of atoms in the graph.
      :rtype: int


   .. py:property:: cell
      :type: torch.Tensor


      Return the canonical cell matrix of the graph.

      :returns: **cell** -- The cell matrix of shape ``(3, 3)``.
      :rtype: torch.Tensor


   .. py:property:: frac
      :type: torch.Tensor


      Return the fractional coordinates of the positions.

      :returns: **frac** -- The fractional coordinates of the atoms.
      :rtype: torch.Tensor


   .. py:method:: frac_to_pos(f: torch.Tensor) -> torch.Tensor

      Fractional -> Cartesian coordinates.

      Convert fractional coordinates to Cartesian coordinates.

      :param f: The fractional coordinates.
      :type f: torch.Tensor

      :returns: **r** -- The Cartesian coordinates.
      :rtype: torch.Tensor


   .. py:method:: pos_to_frac(r: torch.Tensor) -> torch.Tensor

      Cartesian -> Fractional coordinates.

      Convert Cartesian coordinates to fractional coordinates.

      :param r: The Cartesian coordinates.
      :type r: torch.Tensor

      :returns: **f** -- The fractional coordinates.
      :rtype: torch.Tensor


   .. py:property:: positions_mask
      :type: torch.Tensor


      Return the mask of the positions that are fixed.

      ``True`` for fixed atom positions, ``False`` otherwise.

      :returns: **mask** -- The mask of the positions that are fixed.
      :rtype: torch.Tensor


   .. py:property:: time
      :type: torch.Tensor


      Return the time of the graph.

      :returns: **time** -- The time of the graph.
      :rtype: torch.Tensor


   .. py:property:: representation
      :type: Optional[Representation]


      Return the representation of the graph.

      :returns: **representation** -- The representation of the graph, or ``None`` if not set.
      :rtype: Optional[Representation]


   .. py:method:: wrap_positions() -> None

      Wrap the positions of the atoms to the unit cell.

      :rtype: None


   .. py:method:: apply_mask(x: torch.Tensor, val: float = 0.0) -> torch.Tensor

      Apply the mask to the tensor x.

      :param x: The tensor to apply the mask to.
      :type x: torch.Tensor
      :param val: The value to set the masked values to.
      :type val: float

      :returns: **x** -- The tensor with the mask applied.
      :rtype: torch.Tensor


   .. py:property:: confinement
      :type: torch.Tensor


      Return the confinement of the graph.

      :returns: **confinement** -- The confinement of the graph.
      :rtype: torch.Tensor


   .. py:property:: cellpar
      :type: torch.Tensor


      Return the cell parameters of the graph.


   .. py:method:: _is_lower_triangular(cell: torch.Tensor) -> bool
      :staticmethod:


      Return True if *cell* is in canonical lower-triangular form.

      A cell matrix is considered canonical when the three strictly upper-triangular entries (cell[0,1], cell[0,2], cell[1,2]) are all zero (within a tight floating-point tolerance of 1e-10).

      :param cell: The cell matrix.
      :type cell: torch.Tensor

      :returns: True if the cell is already lower-triangular.
      :rtype: bool


   .. py:method:: cell_to_vectors(cell: torch.Tensor) -> torch.Tensor
      :staticmethod:


      Convert cell matrix to cell parameters.

      :param cell: The cell matrix of shape ``(N, 3)`` or ``(N, 3, 3)``.
      :type cell: torch.Tensor

      :returns: The cell parameters of shape ``(N, 6)``.
      :rtype: torch.Tensor


   .. py:method:: vector_to_cell(cellpar: torch.Tensor) -> torch.Tensor
      :staticmethod:


      Convert cell parameters to cell matrix.

      :param cellpar: The cell parameters of shape ``(N, 6)``.
      :type cellpar: torch.Tensor

      :returns: The cell matrix of shape ``(N, 3, 3)`` where each row is a lattice vector.
      :rtype: torch.Tensor
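
The lower-triangular test described for ``_is_lower_triangular`` and the row-vector cell convention (each row is a lattice vector) can be illustrated with a minimal, dependency-free sketch. This is not the library's implementation. It handles a single cell stored as plain Python lists rather than ``torch`` tensors, and the helper names ``is_lower_triangular``, ``frac_to_cart`` and ``cart_to_frac`` are hypothetical. The ``<= tol`` comparison is an assumption, since the documentation only states a tolerance of 1e-10. The inverse conversion uses back substitution, which is valid only for a lower-triangular cell with a non-zero diagonal.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def is_lower_triangular(cell: Matrix, tol: float = 1e-10) -> bool:
    """True when cell[0][1], cell[0][2] and cell[1][2] are zero within tol."""
    return all(abs(cell[i][j]) <= tol for i, j in ((0, 1), (0, 2), (1, 2)))


def frac_to_cart(f: Sequence[float], cell: Matrix) -> List[float]:
    """Cartesian r = f @ cell, where each row of cell is a lattice vector."""
    return [sum(f[i] * cell[i][j] for i in range(3)) for j in range(3)]


def cart_to_frac(r: Sequence[float], cell: Matrix) -> List[float]:
    """Invert r = f @ cell by back substitution (lower-triangular cells only)."""
    if not is_lower_triangular(cell):
        raise ValueError("cell must be lower-triangular for this sketch")
    # The z component depends only on the third lattice vector, so solve it first.
    f2 = r[2] / cell[2][2]
    f1 = (r[1] - f2 * cell[2][1]) / cell[1][1]
    f0 = (r[0] - f1 * cell[1][0] - f2 * cell[2][0]) / cell[0][0]
    return [f0, f1, f2]


# Illustrative triclinic cell in lower-triangular form (values are made up).
cell = [[4.0, 0.0, 0.0],
        [2.0, 3.0, 0.0],
        [1.0, 0.5, 5.0]]
frac = [0.25, 0.5, 0.75]

pos = frac_to_cart(frac, cell)   # fractional -> Cartesian
back = cart_to_frac(pos, cell)   # Cartesian -> fractional round trip
```

In this sketch the round trip recovers the original fractional coordinates, which is the property that ``from_atoms`` with ``canonical_cell=True`` is documented to preserve when it recomputes Cartesian positions for a rotated cell.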