agedi.data.dataset

Classes

Dataset

Defines a custom dataset for AtomsGraph data

Module Contents

class agedi.data.dataset.Dataset(batch_size: int = 32, n_train: float | int = 0.9, n_val: float | int = 0.1, n_test: float | int = 0.0, shuffle: bool = True, properties: List[str] = ['energy', 'forces'], cutoff: float = 6.0, phase_transforms: List[List[torch_geometric.transforms.BaseTransform]] | None = None, num_workers: int = 0, **kwargs)

Bases: lightning.LightningDataModule

Defines a custom dataset for AtomsGraph data

Parameters:
  • batch_size (int) – The batch size for the DataLoader

  • n_train (Union[float, int]) – The number of training samples. If float, it is interpreted as a fraction of the dataset size

  • n_val (Union[float, int]) – The number of validation samples. If float, it is interpreted as a fraction of the dataset size

  • n_test (Union[float, int]) – The number of test samples. If float, it is interpreted as a fraction of the dataset size

  • shuffle (bool) – Whether to shuffle the dataset

  • properties (List[str]) – The properties to include in the dataset. Can be “energy”, “forces”, or both

  • cutoff (float) – The cutoff radius for the neighbor list

  • phase_transforms (Optional[List[List[BaseTransform]]]) – The data augmentation transforms to apply to each training phase

Return type:

Dataset
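The float-or-int convention for n_train, n_val and n_test can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name resolve_split_sizes, the rounding rule and the overflow check are assumptions made for the example.

```python
from typing import Tuple, Union

Number = Union[float, int]


def resolve_split_sizes(
    n_total: int,
    n_train: Number = 0.9,
    n_val: Number = 0.1,
    n_test: Number = 0.0,
) -> Tuple[int, int, int]:
    """Turn float fractions or integer counts into absolute sample counts.

    Hypothetical helper: the rounding and validation rules are assumptions.
    """

    def resolve(n: Number) -> int:
        # Floats are read as a fraction of the dataset size, ints as counts.
        if isinstance(n, float):
            return int(round(n * n_total))
        return int(n)

    sizes = (resolve(n_train), resolve(n_val), resolve(n_test))
    if sum(sizes) > n_total:
        raise ValueError(
            f"Requested {sum(sizes)} samples, but only {n_total} are available"
        )
    return sizes
```

With the default arguments, a dataset of 100 structures would resolve to 90 training, 10 validation and 0 test samples under these assumptions.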

batch_size = 32
n_train = 0.9
n_val = 0.1
n_test = 0.0
properties = ['energy', 'forces']
cutoff = 6.0
dataset = None
train_idx = None
val_idx = None
test_idx = None
phase_transforms = None
num_workers = 0
regressor_dataset = None
regressor_train_loader = None
add_atoms_data(data: List[ase.Atoms], mask_method: str | None = None, confinement: Tuple[float, float] | None = None, properties: List[Dict] | None = None, canonical_cell: bool = False) → None

Add ASE data to the dataset

Converts a list of ASE Atoms objects to AtomsGraph objects and adds them to the dataset

Parameters:
  • data (List[Atoms]) – A list of ASE Atoms objects

  • mask_method (str, optional) – Method for computing the atom mask (e.g. "MaskFixed").

  • confinement (Tuple[float, float], optional) – Z-axis confinement bounds (z_min, z_max) applied to every structure.

  • properties (List[Dict], optional) – Per-structure property dictionaries; each entry is mapped to the corresponding graph via setattr().

  • canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. Defaults to False, in which case cells are stored exactly as provided by ASE.

Return type:

None
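The per-structure properties argument maps each dictionary onto the corresponding graph via setattr(). A minimal sketch of that mapping is shown below; attach_properties is a hypothetical helper, SimpleNamespace stands in for AtomsGraph, and the length check is an assumption of the example.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def attach_properties(
    graphs: List[Any], properties: Optional[List[Dict[str, Any]]]
) -> None:
    """Copy each property dictionary onto the graph at the same position."""
    if properties is None:
        return
    # Assumption: one dictionary per structure is required.
    if len(properties) != len(graphs):
        raise ValueError("properties must contain one dictionary per structure")
    for graph, props in zip(graphs, properties):
        for name, value in props.items():
            setattr(graph, name, value)


# Stand-in objects; the real graphs would be AtomsGraph instances.
graphs = [SimpleNamespace(), SimpleNamespace()]
attach_properties(graphs, [{"temperature": 300.0}, {"temperature": 500.0}])
```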

add_graph_data(data: List[agedi.data.atoms_graph.AtomsGraph]) → None

Add AtomsGraph data to the dataset

Adds a list of AtomsGraph objects to the dataset

Parameters:

data (List[AtomsGraph]) – A list of AtomsGraph objects

Return type:

None

add_regressor_data(data: List[ase.Atoms], canonical_cell: bool = False) → None

Add atoms data that will be used exclusively for regressor training.

Structures in this dataset are only used to train the regressor model (e.g. force-field heads) and are never passed through the diffusion loss. This allows the regressor to learn from non-equilibrium structures that would be unsuitable as diffusion training targets.

Energy and forces are read from the ASE calculator attached to each Atoms object when available.

Parameters:
  • data (List[Atoms]) – A list of ASE Atoms objects, each with an attached calculator that provides energy and forces.

  • canonical_cell (bool, optional) – When True, cells are stored in canonical lower-triangular form. Defaults to False.

Return type:

None

setup(stage: str | None = None) → None

Set up train/validation/test splits and initialise data loaders.

Performs a random split of the dataset (if not already split) and calls set_phase() to create the initial data loaders.

Parameters:

stage (str, optional) – Lightning stage identifier ("fit", "test", etc.). Not used internally; present for API compatibility.
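A random split into train, validation and test indices could look like the sketch below. The function name, the seed argument and the slicing order are assumptions for illustration; the actual splitting logic in setup() may differ.

```python
import random
from typing import List, Tuple


def random_split_indices(
    n_total: int, n_train: int, n_val: int, n_test: int, seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle the sample indices once and cut them into three disjoint lists."""
    indices = list(range(n_total))
    # A dedicated Random instance keeps the split reproducible (assumed seed).
    random.Random(seed).shuffle(indices)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:n_train + n_val + n_test]
    return train_idx, val_idx, test_idx
```

The three lists play the role of the train_idx, val_idx and test_idx attributes listed above.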

train_dataloader() → torch_geometric.loader.DataLoader

Get the training DataLoader

Returns a DataLoader for the training dataset. When a separate regressor dataset has been added via add_regressor_data(), a CombinedLoader is returned so that each training step receives both a regular batch (key "main") and a regressor-only batch (key "regressor").

Return type:

DataLoader or CombinedLoader
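Because the training batch is either a plain batch or a dictionary with the keys "main" and "regressor", a training step has to handle both shapes. The sketch below shows one way to do that; split_training_batch is a hypothetical helper and not part of the documented API.

```python
from typing import Any, Optional, Tuple


def split_training_batch(batch: Any) -> Tuple[Any, Optional[Any]]:
    """Return (main_batch, regressor_batch) for either loader type.

    regressor_batch is None when no regressor dataset has been added.
    """
    if isinstance(batch, dict) and "main" in batch:
        # Combined case: one regular batch plus one regressor-only batch.
        return batch["main"], batch.get("regressor")
    # Plain case: the batch comes straight from the regular DataLoader.
    return batch, None
```

Inside a training step, the main batch would feed the diffusion loss and the regressor batch, when present, would feed only the regressor loss.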

val_dataloader() → torch_geometric.loader.DataLoader

Get the validation DataLoader

Returns a DataLoader for the validation dataset

Return type:

DataLoader

test_dataloader() → torch_geometric.loader.DataLoader

Get the test DataLoader

Returns a DataLoader for the test dataset

Return type:

DataLoader

set_phase(phase: int) → None

Switch the dataset to the given training phase.

Applies the phase-specific transforms to the dataset splits and re-creates the data loaders with the augmented data.

Parameters:

phase (int) – Zero-based phase index. Phase 0 uses the original data; subsequent phases append transformed copies according to phase_transforms[phase].
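The phase mechanism can be pictured with the following sketch, in which phase 0 returns the original data and later phases append one transformed copy per transform. build_phase_data is a hypothetical helper, and whether copies from earlier phases accumulate is not stated here; the sketch assumes only the transforms of the requested phase are applied.

```python
from typing import Any, Callable, List, Optional, Sequence

Transform = Callable[[Any], Any]


def build_phase_data(
    data: Sequence[Any],
    phase_transforms: Optional[List[List[Transform]]],
    phase: int,
) -> List[Any]:
    """Return the original items plus transformed copies for the given phase."""
    augmented = list(data)
    if phase == 0 or not phase_transforms:
        # Phase 0 uses the original data unchanged.
        return augmented
    for transform in phase_transforms[phase]:
        # Assumption: each transform contributes one copy of every item.
        augmented.extend(transform(item) for item in data)
    return augmented
```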

_check_confinement(dataset: List[agedi.data.atoms_graph.AtomsGraph], confinement: Tuple[float, float]) → None

Check that all unmasked atoms in dataset lie within confinement.

Parameters:
  • dataset (List[AtomsGraph]) – The list of graphs to validate.

  • confinement (Tuple[float, float]) – The (z_min, z_max) confinement bounds.

Raises:

ValueError – If any unmasked atom has a Z position outside the confinement. The error message includes a suggested confinement that covers all unmasked atoms.
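The check described above can be sketched on plain lists of Z coordinates. The function name, the boolean unmasked flag per atom and the wording of the error message are assumptions of this example; the real method operates on AtomsGraph objects.

```python
from typing import Sequence, Tuple


def check_confinement(
    z_positions: Sequence[float],
    unmasked: Sequence[bool],
    confinement: Tuple[float, float],
) -> None:
    """Raise ValueError if an unmasked atom lies outside (z_min, z_max)."""
    z_min, z_max = confinement
    # Only unmasked atoms are subject to the confinement.
    free_z = [z for z, is_free in zip(z_positions, unmasked) if is_free]
    if not free_z:
        return
    lowest, highest = min(free_z), max(free_z)
    if lowest < z_min or highest > z_max:
        # Suggest bounds that cover every unmasked atom.
        suggested = (lowest, highest)
        raise ValueError(
            f"Unmasked atoms lie outside confinement {confinement}. "
            f"Suggested confinement: {suggested}"
        )
```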

_has_energy_forces(atoms)

Check whether the given ASE Atoms object has energy and forces information available. The method checks whether a calculator is attached to the Atoms object and whether its results contain the ‘energy’ and ‘forces’ properties. It only inspects existing results, so it does not trigger a new calculation when a calculator is attached but has not yet been run.

Parameters:

atoms (Atoms) – The ASE Atoms object to check for energy and forces information.

Returns:

A tuple indicating whether energy and forces information is available, respectively.

Return type:

Tuple[bool, bool]
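A minimal sketch of such a check is given below. It relies only on the calc attribute of the Atoms object and the results dictionary of its calculator, and it is written with duck typing so that it runs without ASE; the function name and the exact lookups are assumptions, not the method's actual body.

```python
from typing import Any, Tuple


def has_energy_forces(atoms: Any) -> Tuple[bool, bool]:
    """Report whether cached energy and forces are present, without computing."""
    calc = getattr(atoms, "calc", None)
    # Reading the results dictionary never starts a calculation; an unused
    # calculator simply has no cached entries.
    results = getattr(calc, "results", None) or {}
    return "energy" in results, "forces" in results
```

An Atoms object without a calculator, or with a calculator that has not been run, yields (False, False) under this sketch.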