agedi.diffusion

Classes

  • Agedi – Full diffusion model: training + sampling.

  • ForcefieldGuidanceConfig – Configuration for force-field guided sampling.

  • Diffusion – Pure-Python sampling core for diffusion models.

Package Contents

class agedi.diffusion.Agedi(score_model: agedi.models.ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: torch.nn.Module | None = None, regressor_heads: List | None = None, regressor_loss_weight: float = 1.0, optim_config: Dict | None = None, scheduler_config: Dict | None = None, eps: float = 1e-05)

Bases: lightning.LightningModule, agedi.diffusion.diffusion.Diffusion

Full diffusion model: training + sampling.

Combines the Diffusion sampling pipeline with LightningModule training hooks.

Parameters:
  • score_model (ScoreModel) – The score model.

  • noisers (List[Noiser]) – A list of noisers.

  • regressor_model (torch.nn.Module, optional) – An optional regressor model used for force-field guidance during sampling. When present, its loss is added to the diffusion loss during training.

  • regressor_heads (List, optional) – When provided, a RegressorModel is built internally using these heads while sharing the translator and representation from score_model. Use this parameter (instead of regressor_model) when the backbone should be shared.

  • regressor_loss_weight (float, optional) – Weight applied to the regressor loss. Defaults to 1.0.

  • optim_config (dict, optional) – Keyword arguments forwarded to torch.optim.AdamW.

  • scheduler_config (dict, optional) – Keyword arguments forwarded to torch.optim.lr_scheduler.ReduceLROnPlateau.

  • eps (float, optional) – Minimum diffusion time value.

regressor_loss_weight = 1.0
optim_config = None
scheduler_config = None
_regressor_training = False
on_fit_start() -> None

Write hparams.yaml to the trainer log directory at training start.

get_hparams() -> Dict

Return hyperparameters sufficient to reconstruct this diffusion model.

Returns:

Hyperparameter dictionary with _target_, score_model, noisers, optim_config, scheduler_config, eps, and optionally regressor_heads or regressor_model.

Return type:

dict
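The _target_ key follows the Hydra-style convention (as consumed by tools such as hydra.utils.instantiate): _target_ holds the import path of the class to build and the remaining keys are its constructor arguments. Assuming that convention, the sketch below shows how such a dictionary can be turned back into an object with the standard library alone. instantiate is a hypothetical helper, not part of agedi, and it resolves only the top level; nested configs such as score_model would need recursive handling.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a ``_target_``-style dict.

    Only the top level is resolved; nested dicts are passed through unchanged.
    """
    kwargs = dict(config)  # copy so the caller's dict is not mutated
    target = kwargs.pop("_target_")
    module_name, _, attr_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr_name)
    return cls(**kwargs)


# A stdlib class stands in for agedi.diffusion.Agedi here.
half = instantiate({"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2})
print(half)
```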

setup(stage: str = None) -> None

Set up the model (put score model in training mode).

forward(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

Forward pass through the score model.

Parameters:

batch (AtomsGraph) – A batch of AtomsGraph data.

Returns:

The output of the score model forward pass.

Return type:

AtomsGraph

loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

Compute the combined diffusion + regressor loss.

Always computes the diffusion (denoising) loss on a noised copy of the batch. When a regressor model is present and the batch contains force labels, the regressor loss is added with weight regressor_loss_weight.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • batch_idx (torch.Tensor) – The index of the batch.

Returns:

A dictionary of losses.

Return type:

dict
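The weighting rule described for loss() can be sketched in plain Python. The sketch assumes, hypothetically, that each loss dictionary stores its scalar total under a "loss" key and that the caller passes None for the regressor losses when no regressor is attached or the batch has no force labels; the actual key names used by agedi are not documented here.

```python
from typing import Dict, Optional


def combine_losses(
    diffusion_losses: Dict[str, float],
    regressor_losses: Optional[Dict[str, float]],
    regressor_loss_weight: float = 1.0,
) -> Dict[str, float]:
    """Hypothetical sketch: merge the diffusion loss with an optional regressor loss."""
    combined = dict(diffusion_losses)
    total = diffusion_losses["loss"]
    # None means: no regressor attached, or the batch carries no force labels.
    if regressor_losses is not None:
        for key, value in regressor_losses.items():
            combined[f"regressor_{key}"] = value  # keep the parts for logging
        total = total + regressor_loss_weight * regressor_losses["loss"]
    combined["loss"] = total
    return combined
```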

diffusion_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

Compute the diffusion (denoising score-matching) loss.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • batch_idx (torch.Tensor) – The index of the batch.

Returns:

A dictionary of losses.

Return type:

dict

regressor_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict

Compute the regressor loss on the un-noised batch.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • batch_idx (torch.Tensor) – The index of the batch.

Returns:

A dictionary of losses.

Return type:

dict

Raises:

ValueError – If no regressor model is attached.

training_step(batch, batch_idx: torch.Tensor) -> torch.Tensor

Perform a training step.

Computes the combined diffusion + regressor loss (see loss()).

When the Dataset was set up with a dedicated regressor dataset (via add_regressor_data()), batch is a dict with two keys:

  • "main" – a regular training batch used for both the diffusion and regressor loss.

  • "regressor" – a regressor-only batch whose structures contribute only to the regressor loss, not to the diffusion loss.

When no regressor dataset is present, batch is a plain AtomsGraph batch and the combined loss from loss() is computed on it directly.

Parameters:
  • batch (AtomsGraph or dict) – A batch of AtomsGraph data, or a dict with "main" and "regressor" keys when a dedicated regressor dataset is used.

  • batch_idx (torch.Tensor) – The index of the batch.

Returns:

The combined loss.

Return type:

torch.Tensor
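The two batch layouts accepted by training_step() can be sketched as a small dispatcher. loss_fn and regressor_loss_fn are hypothetical stand-ins for loss() and regressor_loss(), each assumed to return a scalar. Whether agedi applies regressor_loss_weight to the regressor-only batch is not stated above, so the sketch exposes that weight as an explicit parameter.

```python
from collections.abc import Mapping
from typing import Any, Callable


def training_loss(
    batch: Any,
    loss_fn: Callable[[Any], float],
    regressor_loss_fn: Callable[[Any], float],
    regressor_weight: float = 1.0,
) -> float:
    """Hypothetical sketch of the batch dispatch described for training_step()."""
    if isinstance(batch, Mapping):
        # "main" feeds the combined diffusion + regressor loss ...
        total = loss_fn(batch["main"])
        # ... while "regressor" contributes only to the regressor loss.
        total += regressor_weight * regressor_loss_fn(batch["regressor"])
        return total
    # Plain AtomsGraph batch: only the combined loss.
    return loss_fn(batch)
```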

validation_step(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> torch.Tensor

Perform a validation step.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • batch_idx (torch.Tensor) – The index of the batch.

Returns:

The combined loss.

Return type:

torch.Tensor

configure_optimizers() -> Dict

Configure optimizers and learning-rate schedulers.

When a regressor model is present, a single optimizer is built over the deduplicated union of score_model and regressor_model parameters (shared parameters appear only once).

Returns:

A dictionary with "optimizer", "lr_scheduler", and "monitor" keys.

Return type:

dict
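Deduplication matters because, when regressor_heads is used, the regressor shares the translator and representation with score_model, so the same parameter object would otherwise reach the optimizer twice; PyTorch optimizers warn about or reject duplicated parameters. A minimal sketch of identity-based deduplication, with plain objects standing in for torch.nn.Parameter (unique_parameters is a hypothetical name):

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def unique_parameters(*groups: Iterable[T]) -> List[T]:
    """Union of several parameter iterables, keeping each object once.

    Objects are compared by identity (id), because a shared torch parameter is
    the same object in both models; order of first appearance is preserved.
    """
    seen = set()
    unique: List[T] = []
    for group in groups:
        for param in group:
            if id(param) not in seen:
                seen.add(id(param))
                unique.append(param)
    return unique


# "shared" stands in for a backbone parameter owned by both models.
shared, score_head, regressor_head = object(), object(), object()
params = unique_parameters([shared, score_head], [shared, regressor_head])
print(len(params))
```

With real modules the call would be unique_parameters(score_model.parameters(), regressor_model.parameters()), and the resulting list would be passed to torch.optim.AdamW.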

_scheduler_monitor() -> str

Return the metric used by ReduceLROnPlateau.

property regressor_training: bool

Whether the regressor model is in training mode.

class agedi.diffusion.ForcefieldGuidanceConfig

Configuration for force-field guided sampling.

Parameters:
  • guidance (float) – Scale of the force-field guidance applied at each reverse step. Set to 0.0 (the default) to disable guidance entirely.

  • zeta (float) – Exponent for the time-dependent weight factor (1 - t)**zeta. Higher values concentrate guidance near the end of the trajectory.

  • force_threshold (float) – Convergence criterion for the optional post-diffusion relaxation: the maximum per-atom force magnitude (eV/Å) below which relaxation stops.

  • max_extra_steps (int) – Maximum number of additional relaxation steps performed after the main diffusion trajectory when guidance > 0.

guidance: float = 0.0
zeta: float = 3.0
force_threshold: float = 0.05
max_extra_steps: int = 0
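The zeta parameter controls how the guidance strength varies along the trajectory. Assuming the per-step scale is guidance multiplied by (1 - t)**zeta (the description above names the factor but not exactly how it is combined), the sketch below shows why a larger zeta pushes guidance toward the end of sampling, where t approaches 0. guidance_scale is a hypothetical helper.

```python
def guidance_scale(t: float, guidance: float = 0.0, zeta: float = 3.0) -> float:
    """Hypothetical effective guidance scale at diffusion time t in [0, 1]."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    return guidance * (1.0 - t) ** zeta


# Reverse sampling runs from t = 1 (noise) towards t = 0 (data).
for zeta in (1.0, 3.0):
    print(zeta, [round(guidance_scale(t, guidance=1.0, zeta=zeta), 3) for t in (0.9, 0.5, 0.1)])
```

With the default guidance = 0.0 the scale is zero at every t, which matches guidance being disabled by default.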
class agedi.diffusion.Diffusion(score_model: ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: torch.nn.Module | None = None, eps: float = 1e-05)

Pure-Python sampling core for diffusion models.

Holds the score model, the noisers, and an optional regressor, and provides the full forward / reverse / sampling pipeline. This class does not inherit from torch.nn.Module or lightning.LightningModule and therefore has no training hooks.

When used through Agedi (which inherits from both this class and lightning.LightningModule), the Lightning infrastructure manages device placement and module registration. When used standalone, device information is derived from the score model’s parameters via the device property.

Parameters:
  • score_model (ScoreModel) – The score model.

  • noisers (List[Noiser]) – A list of noisers.

  • regressor_model (torch.nn.Module, optional) – An optional regressor model used for force-field guidance during sampling.

  • eps (float, optional) – Minimum value for the diffusion time step (used in sample_time()).

score_model
noisers
regressor_model = None
eps = 1e-05
lbfgs_step_sizer: agedi.diffusion.guidance.BatchedLBFGSStepSizer | None = None
zeta: float = 3.0
noiser_keys
score_keys
_compiled_reverse_step = None
property device: torch.device

Infer the computation device from the score model’s parameters.

When used through Agedi (which also inherits lightning.LightningModule), Lightning’s own device property takes precedence.

sample_time(batch: agedi.data.AtomsGraph) -> None

Sample a random diffusion time for each graph in batch.

Draws times uniformly from [eps, 1] and assigns them to batch.time at atom resolution.

Parameters:

batch (AtomsGraph) – A batch of AtomsGraph data; modified in-place.
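Drawing one time per graph and broadcasting it to atom resolution can be sketched with a torch_geometric-style batch vector that maps each atom to its graph index. This is a pure-Python illustration; the real method works on tensors, and the assumption that AtomsGraph batches carry such a batch vector is not confirmed by this section. sample_atom_times is a hypothetical name.

```python
import random
from typing import List, Optional


def sample_atom_times(
    batch_index: List[int], eps: float = 1e-5, rng: Optional[random.Random] = None
) -> List[float]:
    """Draw t ~ U[eps, 1] once per graph and repeat it for every atom of that graph."""
    rng = rng or random.Random()
    n_graphs = max(batch_index) + 1
    graph_times = [rng.uniform(eps, 1.0) for _ in range(n_graphs)]
    return [graph_times[g] for g in batch_index]


# Two graphs: atoms 0-2 belong to graph 0, atoms 3-4 to graph 1.
times = sample_atom_times([0, 0, 0, 1, 1], rng=random.Random(0))
```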

forward_step(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph

Forward diffusion step (corruption).

Applies each noiser in order to corrupt the batch.

Parameters:

batch (AtomsGraph) – A batch of AtomsGraph data.

Returns:

The corrupted batch.

Return type:

AtomsGraph
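Applying the noisers in order amounts to a left fold over the noiser list. In the sketch below each noiser is modelled, hypothetically, as a callable that takes and returns a dict-based batch; the real Noiser interface is not documented in this section.

```python
from functools import reduce
from typing import Callable, Dict, List

Batch = Dict[str, list]
NoiserFn = Callable[[Batch], Batch]


def forward_step(batch: Batch, noisers: List[NoiserFn]) -> Batch:
    """Corrupt the batch by applying each noiser in list order."""
    return reduce(lambda b, noiser: noiser(b), noisers, batch)


# Toy "noisers" that each modify one field and leave the others untouched.
def shift_positions(b: Batch) -> Batch:
    return {**b, "pos": [p + 1.0 for p in b["pos"]]}


def mask_types(b: Batch) -> Batch:
    return {**b, "x": [0 for _ in b["x"]]}


noised = forward_step({"pos": [0.0, 2.0], "x": [1, 8]}, [shift_positions, mask_types])
```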

reverse_step(batch: agedi.data.AtomsGraph, delta_t: float, force_field_guidance: float, last: bool = False, timings: SamplingTimings | None = None) -> agedi.data.AtomsGraph

Reverse diffusion step (denoising).

Evaluates the score model and applies one reverse-SDE step through all noisers. Optionally applies force-field guidance afterwards.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • delta_t (float) – The time step.

  • force_field_guidance (float) – Scale of the force-field guidance (0.0 disables it).

  • last (bool, optional) – Whether this is the final denoising step.

  • timings (SamplingTimings, optional) – If provided, timing measurements are accumulated here.

Returns:

The denoised batch.

Return type:

AtomsGraph

corrector_step(batch: agedi.data.AtomsGraph, corrector_dt: float) -> agedi.data.AtomsGraph

Langevin corrector step at constant time.

Evaluates the score model and applies one Langevin corrector step through all noisers (in reverse order).

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • corrector_dt (float) – Step size for the Langevin corrector.

Returns:

The corrected batch.

Return type:

AtomsGraph
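A Langevin corrector step at fixed time moves the sample along the score and adds fresh Gaussian noise: x ← x + dt·s(x) + sqrt(2·dt)·z with z ~ N(0, 1). The sketch applies this textbook update to a 1-D standard normal, whose score is s(x) = -x. How each agedi noiser scales or parameterises the step is not specified here, so treat this only as an illustration of the technique; langevin_step and run_chains are hypothetical names.

```python
import math
import random
from typing import Callable, List


def langevin_step(
    x: float, score: Callable[[float], float], dt: float, rng: random.Random
) -> float:
    """One unadjusted Langevin step: drift along the score plus sqrt(2*dt) noise."""
    return x + dt * score(x) + math.sqrt(2.0 * dt) * rng.gauss(0.0, 1.0)


def run_chains(n_chains: int = 1000, n_steps: int = 500, dt: float = 0.02, seed: int = 0) -> List[float]:
    rng = random.Random(seed)
    score = lambda x: -x  # score of N(0, 1), i.e. d/dx log p(x)
    samples = []
    for _ in range(n_chains):
        x = 3.0  # start far from the target mode
        for _ in range(n_steps):
            x = langevin_step(x, score, dt, rng)
        samples.append(x)
    return samples


samples = run_chains()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
# Expect mean near 0 and var near 1 (the discretised chain's stationary variance is about 1.01).
```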

force_field_guidance_step(batch: agedi.data.AtomsGraph, scale: float, max_step_size: float = 0.1) -> agedi.data.AtomsGraph

Apply one force-field guidance step.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • scale (float) – Base scale of the force-field guidance.

  • max_step_size (float, optional) – Maximum allowed step size magnitude.

Returns:

Updated batch.

Return type:

AtomsGraph
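One plausible reading of scale and max_step_size is a force-proportional displacement whose per-atom length is capped. The sketch assumes exactly that (forces in, displacements out, capping by Euclidean norm per atom); it is not taken from the agedi source, and the real step may instead rely on the LBFGS step sizer listed among the Diffusion attributes. guided_displacements is a hypothetical name.

```python
import math
from typing import List, Sequence

Vec3 = Sequence[float]


def guided_displacements(
    forces: Sequence[Vec3], scale: float, max_step_size: float = 0.1
) -> List[List[float]]:
    """Displace each atom by scale * force, shrinking any step longer than max_step_size."""
    steps = []
    for f in forces:
        step = [scale * c for c in f]
        norm = math.sqrt(sum(c * c for c in step))
        if norm > max_step_size:
            # Keep the direction, cap the length at max_step_size.
            step = [c * max_step_size / norm for c in step]
        steps.append(step)
    return steps


steps = guided_displacements([[0.0, 0.0, 0.5], [3.0, 4.0, 0.0]], scale=0.1)
```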

post_diffusion_relaxation_step(batch: agedi.data.AtomsGraph, scale: float = 0.1) -> agedi.data.AtomsGraph

Perform a pure force-based relaxation step.

Parameters:
  • batch (AtomsGraph) – A batch of AtomsGraph data.

  • scale (float, optional) – Step size scaling factor.

Returns:

Updated batch.

Return type:

AtomsGraph
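Together with force_threshold and max_extra_steps from ForcefieldGuidanceConfig, this step suggests a loop that stops once the largest per-atom force falls below the threshold or the step budget is used up. The following is a minimal sketch under that assumption, using steepest descent on a toy harmonic potential (force = -k·x) in place of the regressor's predicted forces; relax and harmonic_forces are hypothetical names.

```python
import math
from typing import Callable, List, Tuple

Positions = List[List[float]]


def relax(
    positions: Positions,
    forces_fn: Callable[[Positions], Positions],
    scale: float = 0.1,
    force_threshold: float = 0.05,
    max_extra_steps: int = 100,
) -> Tuple[Positions, int]:
    """Steepest-descent relaxation; returns the final positions and the steps taken."""
    for step in range(max_extra_steps):
        forces = forces_fn(positions)
        fmax = max(math.sqrt(sum(c * c for c in f)) for f in forces)
        if fmax < force_threshold:
            return positions, step  # converged before exhausting the budget
        positions = [[p + scale * c for p, c in zip(pos, f)] for pos, f in zip(positions, forces)]
    return positions, max_extra_steps


def harmonic_forces(positions: Positions, k: float = 1.0) -> Positions:
    # Toy force field: every atom is pulled towards the origin.
    return [[-k * c for c in pos] for pos in positions]


final, n_steps = relax([[1.0, 0.0, 0.0], [0.0, -0.5, 0.0]], harmonic_forces)
```

With max_extra_steps = 0 (the ForcefieldGuidanceConfig default) the loop body never runs, so no relaxation takes place.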

_initialize_graph(cutoff: float, **kwargs) -> agedi.data.AtomsGraph

Initialise a single graph from noiser priors.

Parameters:
  • cutoff (float) – Cutoff radius for the neighbour list.

  • **kwargs – Additional keyword arguments passed to the graph (e.g. cell, template, pbc).

Returns:

The initialised graph.

Return type:

AtomsGraph

static _sync_for_timing(device: torch.device | None) -> None
_time_sampling_call(device: torch.device | None, timings: SamplingTimings, key: str, fn, *args, **kwargs)
static _format_timing_line(label: str, value: float, count: int | None = None) -> str
_print_sampling_timings(timings: SamplingTimings) -> None
property compiled_reverse_step

Lazily compile reverse_step() with torch.compile.

The compiled kernel is cached as self._compiled_reverse_step so that compilation happens at most once per model instance. Using a per-instance cache (rather than a class-level @torch.compile decorator) means that two Diffusion objects with different architectures will each compile their own kernel and never interfere.

Note

timings must not be passed to the compiled function — time.perf_counter is not traceable by Dynamo. Time the compiled call from outside in _sample_batch() using the is_compiled flag.
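The per-instance caching pattern described for compiled_reverse_step can be sketched without PyTorch. Here fake_compile is a hypothetical stand-in for torch.compile (which likewise takes a callable and returns a wrapped callable), and Sampler is a toy class. The wrapper is created on first access and stored on the instance, so each object compiles at most once and no two objects share a kernel.

```python
from typing import Callable, List

compile_calls: List[Callable] = []  # records each "compilation" for demonstration


def fake_compile(fn: Callable) -> Callable:
    """Stand-in for torch.compile: logs the call and returns a wrapper."""
    compile_calls.append(fn)

    def compiled(*args, **kwargs):
        return fn(*args, **kwargs)

    return compiled


class Sampler:
    def __init__(self, name: str) -> None:
        self.name = name
        self._compiled_reverse_step = None  # per-instance cache, empty until first use

    def reverse_step(self, t: float) -> str:
        return f"{self.name} step at t={t}"

    @property
    def compiled_reverse_step(self) -> Callable:
        if self._compiled_reverse_step is None:
            self._compiled_reverse_step = fake_compile(self.reverse_step)
        return self._compiled_reverse_step


a, b = Sampler("a"), Sampler("b")
a.compiled_reverse_step(0.5)
a.compiled_reverse_step(0.4)  # reuses a's cached wrapper
b.compiled_reverse_step(0.5)  # b compiles its own wrapper
```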

_sample_batch(batch: torch_geometric.data.Batch, steps: int, eps: float, force_field_guidance: float, save_trajectory: bool, progress_bar: bool, force_threshold: float, max_extra_steps: int, corrector_steps: int = 0, corrector_step_size: float = 0.001, timings: SamplingTimings | None = None, reverse_step_fn=None, is_compiled: bool = False) -> List[agedi.data.AtomsGraph]

Run the reverse-diffusion loop for a pre-built batch.

Parameters:
  • batch (Batch) – A batch of AtomsGraph data at t=1.

  • steps (int) – Number of reverse-diffusion steps.

  • eps (float) – Minimum time value (end of trajectory).

  • force_field_guidance (float) – Scale of the force-field guidance (0.0 disables it).

  • save_trajectory (bool) – Whether to collect and return all intermediate states.

  • progress_bar (bool) – Whether to display a tqdm progress bar.

  • force_threshold (float) – Maximum per-atom force for terminating post-diffusion relaxation.

  • max_extra_steps (int) – Maximum extra relaxation steps after the main trajectory.

  • corrector_steps (int, optional) – Number of Langevin corrector passes after each predictor step. 0 (default) disables the corrector (standard DDPM/EM sampling).

  • corrector_step_size (float, optional) – Step size used for each Langevin corrector step. Defaults to 1e-3.

  • timings (SamplingTimings, optional) – If provided, timing measurements are accumulated here.

  • reverse_step_fn (callable, optional) – The reverse step function to use. Defaults to self.reverse_step. Pass a torch.compile-wrapped version to enable compiled sampling.

  • is_compiled (bool, optional) – Whether reverse_step_fn is a compiled function.

Returns:

Final structures, or (when save_trajectory is True) a list of trajectories (one per graph).

Return type:

List[AtomsGraph]
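The overall control flow of _sample_batch() (a predictor step from t = 1 down to eps, each optionally followed by corrector_steps Langevin corrector passes) can be outlined as follows. The uniform time grid with delta_t = (1 - eps) / steps is an assumption, and predictor and corrector are hypothetical callables standing in for reverse_step() and corrector_step().

```python
from typing import Callable, List, Tuple


def time_grid(steps: int, eps: float) -> List[float]:
    """Uniformly spaced times from 1.0 down to eps (steps + 1 points)."""
    delta_t = (1.0 - eps) / steps
    return [1.0 - i * delta_t for i in range(steps + 1)]


def sample_loop(
    state: float,
    steps: int,
    eps: float,
    predictor: Callable[[float, float, float, bool], float],
    corrector: Callable[[float, float], float],
    corrector_steps: int = 0,
    corrector_step_size: float = 1e-3,
) -> Tuple[float, List[float]]:
    """Predictor-corrector outline; returns the final state and the trajectory."""
    times = time_grid(steps, eps)
    trajectory = [state]
    for i in range(steps):
        delta_t = times[i] - times[i + 1]
        last = i == steps - 1  # flag passed to the final denoising step
        state = predictor(state, times[i], delta_t, last)
        for _ in range(corrector_steps):  # 0 passes = predictor-only sampling
            state = corrector(state, corrector_step_size)
        trajectory.append(state)
    return state, trajectory
```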

_sample(N: int, steps: int, cutoff: float, eps: float, force_field_guidance: float, force_threshold: float, max_extra_steps: int, progress_bar: bool, save_trajectory: bool, corrector_steps: int = 0, corrector_step_size: float = 0.001, print_timings: bool = False, compile: bool = False, **kwargs) -> List[agedi.data.AtomsGraph]

Build N graphs from priors and run the sampling loop.

Parameters:
  • N (int) – Number of structures to generate.

  • steps (int) – Number of reverse-diffusion steps.

  • cutoff (float) – Cutoff radius for the neighbour list.

  • eps (float) – Minimum time value (end of trajectory).

  • force_field_guidance (float) – Scale of the force-field guidance.

  • force_threshold (float) – Maximum per-atom force for post-diffusion relaxation.

  • max_extra_steps (int) – Maximum extra relaxation steps.

  • progress_bar (bool) – Show tqdm progress bar.

  • save_trajectory (bool) – Collect all intermediate states.

  • corrector_steps (int, optional) – Langevin corrector passes per predictor step.

  • corrector_step_size (float, optional) – Step size for each corrector pass.

  • print_timings (bool, optional) – Print a timing breakdown after sampling completes.

  • compile (bool, optional) – Use torch.compile on the reverse diffusion step.

  • **kwargs – Keyword arguments forwarded to _initialize_graph().

Returns:

Sampled structures (or trajectories when save_trajectory is True).

Return type:

List[AtomsGraph]

sample(N: int, template=None, batch_size: int | None = 64, steps: int | None = 500, cutoff: float | None = 6.0, eps: float | None = 0.001, n_atoms: int | None = None, atomic_numbers: List[int] | None = None, formula: str | None = None, positions: numpy.ndarray | None = None, cell: numpy.ndarray | None = None, pbc: numpy.ndarray | None = None, confinement: Tuple[float, float] | None = None, compile: bool = False, ff_guidance: agedi.diffusion.guidance.ForcefieldGuidanceConfig | None = None, property: Dict | None = None, progress_bar: bool | None = False, save_trajectory: bool | None = False, print_timings: bool | None = False, corrector_steps: int = 0, corrector_step_size: float = 0.001) -> List[agedi.data.AtomsGraph]

Sample structures from the diffusion model.

The minimum required arguments depend on the configured noisers and whether a template is provided:

  • n_atoms – always required unless derivable from atomic_numbers or formula.

  • atomic_numbers – required unless a types-noiser is configured (key "x"), or derivable from formula.

  • positions – required when no positions-noiser is configured (type-only diffusion).

  • cell – required when no template is given.

  • pbc – optional; defaults to [True, True, True].
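How n_atoms can be derived from atomic_numbers or formula is illustrated by the small resolver below. It is a hypothetical helper, not agedi's implementation: it counts atoms in simple formulas such as "H2O" with a regular expression (no parentheses or hydrates), and its consistency check is an illustrative choice.

```python
import re
from typing import List, Optional


def resolve_n_atoms(
    n_atoms: Optional[int] = None,
    atomic_numbers: Optional[List[int]] = None,
    formula: Optional[str] = None,
) -> int:
    """Return n_atoms, deriving it from atomic_numbers or a simple formula if needed."""
    candidates = []
    if n_atoms is not None:
        candidates.append(n_atoms)
    if atomic_numbers is not None:
        candidates.append(len(atomic_numbers))
    if formula is not None:
        # "H2O" -> [("H", "2"), ("O", "")]; a missing count means 1.
        counts = re.findall(r"([A-Z][a-z]?)(\d*)", formula)
        candidates.append(sum(int(n) if n else 1 for _, n in counts))
    if not candidates:
        raise ValueError("n_atoms, atomic_numbers or formula is required")
    if len(set(candidates)) > 1:
        raise ValueError(f"inconsistent atom counts: {candidates}")
    return candidates[0]
```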

Parameters:
  • N (int) – Number of structures to generate.

  • template (AtomsGraph or ase.Atoms, optional) – Template structure. cell and pbc are taken from the template when not explicitly provided.

  • batch_size (int, optional) – Internal batch size for splitting large N.

  • steps (int, optional) – Number of reverse-diffusion steps.

  • cutoff (float, optional) – Cutoff radius for the neighbour list.

  • eps (float, optional) – Minimum time value at the end of the trajectory.

  • n_atoms (int, optional) – Number of atoms per structure.

  • atomic_numbers (List[int], optional) – Atomic numbers of the atoms to generate.

  • formula (str, optional) – Chemical formula (e.g. "H2O").

  • positions (np.ndarray, optional) – Fixed atom positions (shape (n_atoms, 3)).

  • cell (np.ndarray, optional) – Unit-cell matrix (3x3).

  • pbc (np.ndarray, optional) – Periodic boundary conditions.

  • confinement (Tuple[float, float], optional) – Confinement along the z axis, given as (z_min, z_max).

  • compile (bool, optional) – When True, use torch.compile on the reverse diffusion step for improved throughput on CUDA hardware.

  • ff_guidance (ForcefieldGuidanceConfig, optional) – Force-field guidance configuration.

  • property (dict, optional) – Conditioning properties (key -> scalar tensor).

  • progress_bar (bool, optional) – Show a tqdm progress bar.

  • save_trajectory (bool, optional) – Return full trajectories instead of final structures.

  • print_timings (bool, optional) – Print a timing breakdown after sampling completes.

  • corrector_steps (int, optional) – Number of Langevin corrector passes after each predictor step. 0 (default) gives standard (predictor-only) sampling.

  • corrector_step_size (float, optional) – Step size for each corrector pass. Defaults to 1e-3.

Returns:

Sampled structures, or trajectories when save_trajectory is True.

Return type:

List[AtomsGraph]
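batch_size bounds how many structures are denoised together, so a large N is processed in consecutive chunks. Assuming simple sequential chunking with a smaller final chunk (the exact splitting strategy is not described above), the batch sizes work out as in this sketch; chunk_sizes is a hypothetical helper.

```python
from typing import List


def chunk_sizes(n: int, batch_size: int) -> List[int]:
    """Split n structures into consecutive batches of at most batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, rest = divmod(n, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(chunk_sizes(150, 64))  # [64, 64, 22]
```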