agedi.data.dataset
==================

.. py:module:: agedi.data.dataset


Classes
-------

.. autoapisummary::

   agedi.data.dataset.Dataset


Module Contents
---------------

.. py:class:: Dataset(batch_size: int = 32, n_train: Union[float, int] = 0.9, n_val: Union[float, int] = 0.1, n_test: Union[float, int] = 0.0, shuffle: bool = True, properties: List[str] = ['energy', 'forces'], cutoff: float = 6.0, phase_transforms: Optional[List[List[torch_geometric.transforms.BaseTransform]]] = None, num_workers: int = 0, **kwargs)

   Bases: :py:obj:`lightning.LightningDataModule`


   Defines a Lightning data module for AtomsGraph data.

   :param batch_size: The batch size for the DataLoader.
   :type batch_size: int
   :param n_train: The number of training samples. If a float, it is interpreted as a fraction of the dataset size.
   :type n_train: Union[float, int]
   :param n_val: The number of validation samples. If a float, it is interpreted as a fraction of the dataset size.
   :type n_val: Union[float, int]
   :param n_test: The number of test samples. If a float, it is interpreted as a fraction of the dataset size.
   :type n_test: Union[float, int]
   :param shuffle: Whether to shuffle the dataset.
   :type shuffle: bool
   :param properties: The properties to include in the dataset: ``"energy"``, ``"forces"``, or both.
   :type properties: List[str]
   :param cutoff: The cutoff radius for the neighbor list.
   :type cutoff: float
   :param phase_transforms: Data-augmentation transforms for each training phase; ``phase_transforms[i]`` holds the transforms used in phase ``i``.
   :type phase_transforms: Optional[List[List[BaseTransform]]]
   :rtype: Dataset


   .. py:attribute:: batch_size
      :value: 32


   .. py:attribute:: n_train
      :value: 0.9


   .. py:attribute:: n_val
      :value: 0.1


   .. py:attribute:: n_test
      :value: 0.0


   .. py:attribute:: properties
      :value: ['energy', 'forces']


   .. py:attribute:: cutoff
      :value: 6.0


   .. py:attribute:: dataset
      :value: None


   .. py:attribute:: train_idx
      :value: None


   .. py:attribute:: val_idx
      :value: None


   .. py:attribute:: test_idx
      :value: None


   .. py:attribute:: phase_transforms
      :value: None

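
   The ``n_train``, ``n_val`` and ``n_test`` arguments accept either sample counts (``int``) or fractions of the dataset (``float``). The snippet below is a minimal sketch of how such values could be turned into a random index split. It assumes fractions are multiplied by the dataset size and truncated; ``resolve_split`` is a hypothetical helper, not part of agedi, and the library's actual rounding and seeding may differ.

   .. code-block:: python

      import random
      from typing import List, Sequence, Tuple, Union


      def resolve_split(
          n_total: int,
          sizes: Sequence[Union[float, int]],
          shuffle: bool = True,
          seed: int = 0,
      ) -> Tuple[List[int], ...]:
          """Hypothetical helper: turn (n_train, n_val, n_test) into index lists."""
          counts = []
          for size in sizes:
              # Assumption: floats are fractions of the dataset, ints are counts.
              counts.append(int(size * n_total) if isinstance(size, float) else size)
          if sum(counts) > n_total:
              raise ValueError(f"split sizes {counts} exceed dataset size {n_total}")

          indices = list(range(n_total))
          if shuffle:
              # A seeded generator keeps the split reproducible.
              random.Random(seed).shuffle(indices)

          # Consecutive, non-overlapping slices of the (shuffled) indices.
          splits, start = [], 0
          for count in counts:
              splits.append(indices[start:start + count])
              start += count
          return tuple(splits)


      train_idx, val_idx, test_idx = resolve_split(100, (0.9, 0.1, 0.0))
      print(len(train_idx), len(val_idx), len(test_idx))  # 90 10 0
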
   .. py:attribute:: num_workers
      :value: 0


   .. py:attribute:: regressor_dataset
      :value: None


   .. py:attribute:: regressor_train_loader
      :value: None


   .. py:method:: add_atoms_data(data: List[ase.Atoms], mask_method: Optional[str] = None, confinement: Optional[Tuple[float, float]] = None, properties: Optional[List[Dict]] = None, canonical_cell: bool = False) -> None

      Add ASE data to the dataset.

      Converts a list of ASE Atoms objects to AtomsGraph objects and adds them to the dataset.

      :param data: A list of ASE Atoms objects.
      :type data: List[Atoms]
      :param mask_method: Method for computing the atom mask (e.g. ``"MaskFixed"``).
      :type mask_method: str, optional
      :param confinement: Z-axis confinement bounds ``(z_min, z_max)`` applied to every structure.
      :type confinement: Tuple[float, float], optional
      :param properties: Per-structure property dictionaries; the items of each entry are assigned to the corresponding graph with :func:`setattr`.
      :type properties: List[Dict], optional
      :param canonical_cell: When ``True``, cells are stored in canonical lower-triangular form. When ``False`` (the default), cells are stored exactly as provided by ASE.
      :type canonical_cell: bool, optional
      :rtype: None


   .. py:method:: add_graph_data(data: List[agedi.data.atoms_graph.AtomsGraph]) -> None

      Add AtomsGraph data to the dataset.

      Adds a list of AtomsGraph objects to the dataset.

      :param data: A list of AtomsGraph objects.
      :type data: List[AtomsGraph]
      :rtype: None


   .. py:method:: add_regressor_data(data: List[ase.Atoms], canonical_cell: bool = False) -> None

      Add atoms data that will be used exclusively for regressor training.

      Structures in this dataset are only used to train the regressor model (e.g. force-field heads) and are never passed through the diffusion loss. This allows the regressor to learn from non-equilibrium structures that would be unsuitable as diffusion training targets.

      Energy and forces are read from the ASE calculator attached to each :class:`~ase.Atoms` object when available.
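
      Once regressor data has been added, the training loader yields dictionaries with ``"main"`` and ``"regressor"`` keys (see :meth:`train_dataloader`), so a training step has to route each part to the right loss. The sketch below illustrates that routing with plain callables. ``combine_losses``, ``fake_diffusion_loss`` and ``fake_regressor_loss`` are hypothetical names, and the sketch assumes (agedi does not state this here) that the main batch contributes to both losses while the regressor batch contributes only to the regressor loss.

      .. code-block:: python

         from typing import Any, Callable, Dict, Union


         def combine_losses(
             batch: Union[Any, Dict[str, Any]],
             diffusion_loss: Callable[[Any], float],
             regressor_loss: Callable[[Any], float],
         ) -> float:
             """Hypothetical routing of a (possibly combined) training batch."""
             if isinstance(batch, dict) and "main" in batch:
                 main_batch = batch["main"]
                 regressor_batch = batch.get("regressor")
             else:
                 # No regressor data was added: the loader yields plain batches.
                 main_batch, regressor_batch = batch, None

             # Assumption: regular data trains both the diffusion model and the regressor.
             total = diffusion_loss(main_batch) + regressor_loss(main_batch)
             if regressor_batch is not None:
                 # Regressor-only structures never reach the diffusion loss.
                 total += regressor_loss(regressor_batch)
             return total


         calls = []


         def fake_diffusion_loss(b):
             calls.append(("diffusion", b))
             return 1.0


         def fake_regressor_loss(b):
             calls.append(("regressor", b))
             return 0.5


         loss = combine_losses(
             {"main": "equilibrium", "regressor": "non-equilibrium"},
             fake_diffusion_loss,
             fake_regressor_loss,
         )
         print(loss)  # 2.0
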
      :param data: A list of ASE :class:`~ase.Atoms` objects, each with an attached calculator that provides energy and forces.
      :type data: List[Atoms]
      :param canonical_cell: When ``True``, cells are stored in canonical lower-triangular form. Defaults to ``False``.
      :type canonical_cell: bool, optional
      :rtype: None


   .. py:method:: setup(stage: Optional[str] = None) -> None

      Set up train/validation/test splits and initialize data loaders.

      Performs a random split of the dataset (if not already split) and calls :meth:`set_phase` to create the initial data loaders.

      :param stage: Lightning stage identifier (``"fit"``, ``"test"``, etc.). Not used internally; present for API compatibility.
      :type stage: str, optional


   .. py:method:: train_dataloader() -> torch_geometric.loader.DataLoader

      Get the training DataLoader.

      Returns a DataLoader for the training dataset. When a separate regressor dataset has been added via :meth:`add_regressor_data`, a :class:`~lightning.pytorch.utilities.CombinedLoader` is returned instead, so that each training step receives both a regular batch (key ``"main"``) and a regressor-only batch (key ``"regressor"``).

      :rtype: DataLoader or CombinedLoader


   .. py:method:: val_dataloader() -> torch_geometric.loader.DataLoader

      Get the validation DataLoader.

      Returns a DataLoader for the validation dataset.

      :rtype: DataLoader


   .. py:method:: test_dataloader() -> torch_geometric.loader.DataLoader

      Get the test DataLoader.

      Returns a DataLoader for the test dataset.

      :rtype: DataLoader


   .. py:method:: set_phase(phase: int) -> None

      Switch the dataset to the given training phase.

      Applies the phase-specific transforms to the dataset splits and re-creates the data loaders with the augmented data.

      :param phase: Zero-based phase index. Phase 0 uses the original data; subsequent phases append transformed copies according to ``phase_transforms[phase]``.
      :type phase: int
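
      As a minimal sketch of the phase logic described above, the snippet builds the data for a given phase, with plain Python callables standing in for :class:`~torch_geometric.transforms.BaseTransform` instances. ``build_phase_data`` is hypothetical; it assumes each phase starts again from the original data (transforms are not accumulated across phases) and that every transform in ``phase_transforms[phase]`` contributes one transformed copy of each sample.

      .. code-block:: python

         import copy
         from typing import Callable, List, Optional, Sequence, TypeVar

         T = TypeVar("T")


         def build_phase_data(
             data: Sequence[T],
             phase_transforms: Optional[Sequence[Sequence[Callable[[T], T]]]],
             phase: int,
         ) -> List[T]:
             """Hypothetical: original samples plus transformed copies for ``phase``."""
             samples = list(data)
             if phase == 0 or phase_transforms is None:
                 # Phase 0 (or no transforms configured) uses the original data only.
                 return samples
             for transform in phase_transforms[phase]:
                 # Deep-copy first so an in-place transform cannot alter the originals.
                 samples.extend(transform(copy.deepcopy(x)) for x in data)
             return samples


         phases = [[], [lambda x: -x], [lambda x: -x, lambda x: x + 10]]
         print(build_phase_data([1, 2], phases, 2))  # [1, 2, -1, -2, 11, 12]

      Keeping phase 0 untouched means training always starts from the unaugmented structures, and later phases only enlarge the data rather than replacing it.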
   .. py:method:: _check_confinement(dataset: List[agedi.data.atoms_graph.AtomsGraph], confinement: Tuple[float, float]) -> None

      Check that all unmasked atoms in *dataset* lie within *confinement*.

      :param dataset: The list of graphs to validate.
      :type dataset: List[AtomsGraph]
      :param confinement: The ``(z_min, z_max)`` confinement bounds.
      :type confinement: Tuple[float, float]
      :raises ValueError: If any unmasked atom has a z-coordinate outside the confinement. The error message suggests a confinement that covers all unmasked atoms.


   .. py:method:: _has_energy_forces(atoms)

      Check whether energy and forces are available for an ASE Atoms object.

      Checks whether a calculator is attached to *atoms* and whether its results already contain the ``"energy"`` and ``"forces"`` properties. No new calculation is triggered: if a calculator is attached but has not been run yet, it is not invoked.

      :param atoms: The ASE Atoms object to check.
      :type atoms: Atoms
      :returns: A ``(has_energy, has_forces)`` tuple indicating whether energy and forces are available, respectively.
      :rtype: Tuple[bool, bool]
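
      A duck-typed sketch of this check, assuming only that the calculator exposes its cached output through a ``results`` dictionary, as ASE calculators do. ``has_energy_forces`` is a hypothetical stand-in for the private method, and :class:`~types.SimpleNamespace` objects replace real :class:`~ase.Atoms` and calculator instances so the example runs without ASE.

      .. code-block:: python

         from types import SimpleNamespace
         from typing import Any, Tuple


         def has_energy_forces(atoms: Any) -> Tuple[bool, bool]:
             """Hypothetical: report cached energy/forces without running the calculator."""
             calc = getattr(atoms, "calc", None)
             if calc is None:
                 return False, False
             # Inspect already-computed results only; nothing here triggers a calculation.
             results = getattr(calc, "results", None) or {}
             return "energy" in results, "forces" in results


         unused = SimpleNamespace(calc=SimpleNamespace(results={}))
         done = SimpleNamespace(
             calc=SimpleNamespace(results={"energy": -1.0, "forces": [[0.0, 0.0, 0.0]]})
         )
         print(has_energy_forces(SimpleNamespace(calc=None)))  # (False, False)
         print(has_energy_forces(unused))  # (False, False)
         print(has_energy_forces(done))  # (True, True)
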