agedi.data.dataset¶
Classes¶
Dataset | Defines a custom dataset for AtomsGraph data
Module Contents¶
- class agedi.data.dataset.Dataset(batch_size: int = 32, n_train: float | int = 0.9, n_val: float | int = 0.1, n_test: float | int = 0.0, shuffle: bool = True, properties: List[str] = ['energy', 'forces'], cutoff: float = 6.0, phase_transforms: List[List[torch_geometric.transforms.BaseTransform]] | None = None, num_workers: int = 0, **kwargs)¶
Bases: lightning.LightningDataModule
Defines a custom dataset for AtomsGraph data
- Parameters:
batch_size (int) – The batch size for the DataLoader
n_train (Union[float, int]) – The number of training samples. If float, it is interpreted as a fraction of the dataset size
n_val (Union[float, int]) – The number of validation samples. If float, it is interpreted as a fraction of the dataset size
n_test (Union[float, int]) – The number of test samples. If float, it is interpreted as a fraction of the dataset size
shuffle (bool) – Whether to shuffle the dataset
properties (List[str]) – The properties to include in the dataset. Can be “energy”, “forces”, or both
cutoff (float) – The cutoff radius for the neighbor list
phase_transforms (Optional[List[List[BaseTransform]]]) – The data augmentation transforms to apply to each training phase
- batch_size = 32¶
- n_train = 0.9¶
- n_val = 0.1¶
- n_test = 0.0¶
- properties = ['energy', 'forces']¶
- cutoff = 6.0¶
- dataset = None¶
- train_idx = None¶
- val_idx = None¶
- test_idx = None¶
- phase_transforms = None¶
- num_workers = 0¶
- regressor_dataset = None¶
- regressor_train_loader = None¶
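The n_train, n_val and n_test parameters accept either an absolute sample count or a fraction of the dataset size. The sketch below shows one way such values could be resolved into counts. It is illustrative only: resolve_split_sizes is a hypothetical helper that is not part of agedi, and truncating fractional counts is an assumption.

```python
from typing import Tuple, Union

Number = Union[float, int]


def resolve_split_sizes(
    n_total: int, n_train: Number = 0.9, n_val: Number = 0.1, n_test: Number = 0.0
) -> Tuple[int, int, int]:
    """Turn count-or-fraction split specifications into absolute sample counts."""

    def resolve(n: Number) -> int:
        # A float is read as a fraction of the dataset size, an int as a count.
        # Assumption: fractional counts are truncated towards zero.
        if isinstance(n, float):
            return int(n * n_total)
        return n

    sizes = (resolve(n_train), resolve(n_val), resolve(n_test))
    if sum(sizes) > n_total:
        raise ValueError(
            f"Requested {sum(sizes)} samples but only {n_total} are available"
        )
    return sizes


# Default fractions on a dataset of 1000 structures.
default_sizes = resolve_split_sizes(1000)
# Explicit counts are passed through unchanged.
explicit_sizes = resolve_split_sizes(100, 80, 10, 10)
```

Because of the truncation, fractions may leave a few samples unassigned for some dataset sizes; passing integer counts avoids that ambiguity.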
- add_atoms_data(data: List[ase.Atoms], mask_method: str | None = None, confinement: Tuple[float, float] | None = None, properties: List[Dict] | None = None, canonical_cell: bool = False) None¶
Add ASE data to the dataset
Converts a list of ASE Atoms objects to AtomsGraph objects and adds them to the dataset
- Parameters:
data (List[Atoms]) – A list of ASE Atoms objects
mask_method (str, optional) – Method for computing the atom mask (e.g. "MaskFixed").
confinement (Tuple[float, float], optional) – Z-axis confinement bounds (z_min, z_max) applied to every structure.
properties (List[Dict], optional) – Per-structure property dictionaries; each entry is mapped to the corresponding graph via setattr().
canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. When False (the default), cells are stored exactly as provided by ASE.
- Return type:
None
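The canonical_cell option refers to storing the cell in lower-triangular form. The sketch below shows one common way to obtain such a form: keep the lengths of the lattice vectors and the angles between them, place the first vector along x and the second in the xy-plane. It is a sketch under assumptions, not agedi's implementation: to_lower_triangular is a hypothetical helper, and the convention agedi uses may differ.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _dot(u: Vector, v: Vector) -> float:
    return sum(x * y for x, y in zip(u, v))


def to_lower_triangular(cell: Sequence[Vector]) -> List[List[float]]:
    """Return a lower-triangular cell with the same lengths and angles.

    Assumes three linearly independent lattice vectors given as rows.
    The result is always right-handed, so a left-handed input changes
    handedness in this sketch.
    """
    a, b, c = cell
    ax = math.sqrt(_dot(a, a))             # a is placed along x
    bx = _dot(a, b) / ax                   # component of b along a
    by = math.sqrt(_dot(b, b) - bx ** 2)   # b is placed in the xy-plane
    cx = _dot(a, c) / ax
    cy = (_dot(b, c) - bx * cx) / by
    cz = math.sqrt(_dot(c, c) - cx ** 2 - cy ** 2)
    return [[ax, 0.0, 0.0], [bx, by, 0.0], [cx, cy, cz]]


# A cubic cell of side 2 whose axes are rotated by 90 degrees about z.
rotated_cubic = [[0.0, 2.0, 0.0], [-2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]
canonical = to_lower_triangular(rotated_cubic)
```

Atomic positions would have to be rotated consistently with the cell; that step is omitted here.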
- add_graph_data(data: List[agedi.data.atoms_graph.AtomsGraph]) None¶
Add AtomsGraph data to the dataset
Adds a list of AtomsGraph objects to the dataset
- Parameters:
data (List[AtomsGraph]) – A list of AtomsGraph objects
- Return type:
None
- add_regressor_data(data: List[ase.Atoms], canonical_cell: bool = False) None¶
Add atoms data that will be used exclusively for regressor training.
Structures in this dataset are only used to train the regressor model (e.g. force-field heads) and are never passed through the diffusion loss. This allows the regressor to learn from non-equilibrium structures that would be unsuitable as diffusion training targets.
Energy and forces are read from the ASE calculator attached to each Atoms object when available.
- Parameters:
data (List[Atoms]) – A list of ASE Atoms objects, each with an attached calculator that provides energy and forces.
canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. Defaults to False.
- Return type:
None
- setup(stage: str | None = None) None¶
Set up train/validation/test splits and initialise data loaders.
Performs a random split of the dataset (if not already split) and calls set_phase() to create the initial data loaders.
- Parameters:
stage (str, optional) – Lightning stage identifier ("fit", "test", etc.). Not used internally; present for API compatibility.
- train_dataloader() torch_geometric.loader.DataLoader¶
Get the training DataLoader
Returns a DataLoader for the training dataset. When a separate regressor dataset has been added via add_regressor_data(), a CombinedLoader is returned so that each training step receives both a regular batch (key "main") and a regressor-only batch (key "regressor").
- Return type:
DataLoader or CombinedLoader
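Since train_dataloader() can return either kind of loader, code consuming the batches has to handle both shapes. The sketch below shows one way to do so, assuming the combined batch reaches the training step as a dictionary with the keys "main" and "regressor"; unpack_training_batch is a hypothetical helper, not part of agedi.

```python
from typing import Any, Optional, Tuple


def unpack_training_batch(batch: Any) -> Tuple[Any, Optional[Any]]:
    """Return (main_batch, regressor_batch); the latter is None without regressor data."""
    if isinstance(batch, dict):
        # Combined case: a regular batch under "main" and a
        # regressor-only batch under "regressor".
        return batch["main"], batch.get("regressor")
    # Plain DataLoader case: the batch itself is the main batch.
    return batch, None


# Strings stand in for real graph batches.
main, regressor = unpack_training_batch({"main": "graphs", "regressor": "extra"})
plain_main, plain_regressor = unpack_training_batch("graphs")
```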
- val_dataloader() torch_geometric.loader.DataLoader¶
Get the validation DataLoader
Returns a DataLoader for the validation dataset
- Return type:
DataLoader
- test_dataloader() torch_geometric.loader.DataLoader¶
Get the test DataLoader
Returns a DataLoader for the test dataset
- Return type:
DataLoader
- set_phase(phase: int) None¶
Switch the dataset to the given training phase.
Applies the phase-specific transforms to the dataset splits and re-creates the data loaders with the augmented data.
- Parameters:
phase (int) – Zero-based phase index. Phase 0 uses the original data; subsequent phases append transformed copies according to phase_transforms[phase].
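The phase behaviour can be illustrated with plain Python objects. The sketch below assumes that every transform listed for a phase contributes one transformed copy of each original sample; whether agedi instead composes the transforms is not stated here. build_phase_data is a hypothetical helper, and integers with lambdas stand in for graphs and BaseTransform instances.

```python
import copy
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def build_phase_data(
    data: Sequence[T],
    phase_transforms: Sequence[Sequence[Callable[[T], T]]],
    phase: int,
) -> List[T]:
    """Return the samples for a phase: originals plus transformed copies."""
    samples = list(data)
    if phase == 0:
        # Phase 0 uses the original data unchanged.
        return samples
    # Assumption: each transform of this phase adds one transformed copy
    # of every original sample.
    for transform in phase_transforms[phase]:
        samples.extend(transform(copy.deepcopy(item)) for item in data)
    return samples


originals = [1, 2]
transforms = [[], [lambda x: x + 10, lambda x: x * 2]]
phase0 = build_phase_data(originals, transforms, 0)
phase1 = build_phase_data(originals, transforms, 1)
```

Copying each sample before transforming it keeps the originals intact, so switching phases never alters the underlying data.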
- _check_confinement(dataset: List[agedi.data.atoms_graph.AtomsGraph], confinement: Tuple[float, float]) None¶
Check that all unmasked atoms in dataset lie within confinement.
- Parameters:
dataset (List[AtomsGraph]) – The list of graphs to validate.
confinement (Tuple[float, float]) – The (z_min, z_max) confinement bounds.
- Raises:
ValueError – If any unmasked atom has a Z position outside the confinement. The error message includes a suggested confinement that covers all unmasked atoms.
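A minimal sketch of such a check is given below. It works on a plain list of Z positions of the unmasked atoms rather than on AtomsGraph objects, treats the bounds as inclusive, and suggests the smallest interval that contains both the requested confinement and all positions. These choices and the name check_confinement are assumptions made for illustration.

```python
from typing import Iterable, Tuple


def check_confinement(
    z_positions: Iterable[float], confinement: Tuple[float, float]
) -> None:
    """Raise ValueError if any Z position falls outside (z_min, z_max)."""
    z_values = list(z_positions)
    if not z_values:
        return
    z_min, z_max = confinement
    lowest, highest = min(z_values), max(z_values)
    if lowest < z_min or highest > z_max:
        # Assumption: suggest the requested bounds widened just enough
        # to contain every position.
        suggestion = (min(lowest, z_min), max(highest, z_max))
        raise ValueError(
            f"Atoms found outside confinement {confinement}; "
            f"a confinement of {suggestion} would cover all unmasked atoms."
        )


check_confinement([1.0, 2.5], (0.0, 3.0))  # passes silently

try:
    check_confinement([1.0, 4.0], (0.0, 3.0))
    message = ""
except ValueError as err:
    message = str(err)
```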
- _has_energy_forces(atoms)¶
Check whether the given ASE Atoms object has energy and forces available. The method checks whether a calculator is attached to the Atoms object and whether its results contain the ‘energy’ and ‘forces’ properties. It does not trigger a calculation when a calculator is attached but has not yet been used.
- Parameters:
atoms (Atoms) – The ASE Atoms object to check for energy and forces information.
- Returns:
A tuple indicating whether energy and forces information is available, respectively.
- Return type:
Tuple[bool, bool]
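The check described above can be sketched without running any calculation by looking only at values the calculator has already stored. The example assumes the calculator exposes its cached values in a results dictionary, as ASE calculators do; has_energy_forces is a hypothetical stand-alone version, and SimpleNamespace objects stand in for Atoms and calculators.

```python
from types import SimpleNamespace
from typing import Any, Tuple


def has_energy_forces(atoms: Any) -> Tuple[bool, bool]:
    """Report whether cached energy and forces are available on `atoms`."""
    calc = getattr(atoms, "calc", None)
    if calc is None:
        return False, False
    # Only inspect stored results; calling get_potential_energy() or
    # get_forces() could start a new calculation.
    results = getattr(calc, "results", None) or {}
    return "energy" in results, "forces" in results


no_calculator = SimpleNamespace(calc=None)
unused_calculator = SimpleNamespace(calc=SimpleNamespace(results={}))
energy_only = SimpleNamespace(calc=SimpleNamespace(results={"energy": -1.0}))
complete = SimpleNamespace(
    calc=SimpleNamespace(results={"energy": -1.0, "forces": [[0.0, 0.0, 0.0]]})
)
```

Returning the two flags separately lets a caller keep structures that carry only one of the two properties.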