agedi.diffusion
Submodules
Classes
| Class | Summary |
|---|---|
| Agedi | Full diffusion model: training + sampling. |
| ForcefieldGuidanceConfig | Configuration for force-field guided sampling. |
| Diffusion | Pure-Python sampling core for diffusion models. |
Package Contents
- class agedi.diffusion.Agedi(score_model: agedi.models.ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: torch.nn.Module | None = None, regressor_heads: List | None = None, regressor_loss_weight: float = 1.0, optim_config: Dict | None = None, scheduler_config: Dict | None = None, eps: float = 1e-05)
Bases: lightning.LightningModule, agedi.diffusion.diffusion.Diffusion
Full diffusion model: training + sampling.
Combines the Diffusion sampling pipeline with LightningModule training hooks.
- Parameters:
score_model (ScoreModel) – The score model.
noisers (List[Noiser]) – A list of noisers.
regressor_model (torch.nn.Module, optional) – An optional regressor model used for force-field guidance during sampling. When present, its loss is added to the diffusion loss during training.
regressor_heads (List, optional) – When provided, a RegressorModel is built internally from these heads, sharing the translator and representation of score_model. Use this parameter (instead of regressor_model) when the backbone should be shared.
regressor_loss_weight (float, optional) – Weight applied to the regressor loss. Defaults to 1.0.
optim_config (dict, optional) – Keyword arguments forwarded to torch.optim.AdamW.
scheduler_config (dict, optional) – Keyword arguments forwarded to torch.optim.lr_scheduler.ReduceLROnPlateau.
eps (float, optional) – Minimum diffusion time value. Defaults to 1e-05.
- regressor_loss_weight = 1.0
- optim_config = None
- scheduler_config = None
- _regressor_training = False
- on_fit_start() -> None
Write hparams.yaml to the trainer log directory at training start.
- get_hparams() -> Dict
Return hyperparameters sufficient to reconstruct this diffusion model.
- Returns:
Hyperparameter dictionary with the keys _target_, score_model, noisers, optim_config, scheduler_config, eps, and optionally regressor_heads or regressor_model.
- Return type:
dict
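The _target_ key returned by get_hparams() suggests a Hydra-style convention, where a dotted import path sits next to the constructor's keyword arguments. The sketch below shows how a flat dictionary of that kind could be turned back into an object using only the standard library. instantiate_from_config is a hypothetical helper. It does not resolve nested entries such as a score_model sub-config, and the project's real reconstruction code may work differently.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a ``_target_``-style dict.

    Assumes ``_target_`` is a dotted import path and every other key is a
    keyword argument for that callable. Nested configs are not resolved.
    """
    kwargs = dict(config)  # copy so the caller's dict is left untouched
    module_path, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Usage with a standard-library class standing in for a model class.
frac = instantiate_from_config(
    {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
)
print(frac)  # 3/4
```

If Hydra is installed, hydra.utils.instantiate does the same job and also resolves nested _target_ entries recursively.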
- setup(stage: str = None) -> None
Set up the model by putting the score model in training mode.
- forward(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph
Forward pass through the score model.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
- Returns:
The output of the score model forward pass.
- Return type:
AtomsGraph
- loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict
Compute the combined diffusion + regressor loss.
Always computes the diffusion (denoising) loss on a noised copy of the batch. When a regressor model is present and the batch contains force labels, the regressor loss is added with weight regressor_loss_weight.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
batch_idx (torch.Tensor) – The index of the batch.
- Returns:
A dictionary of losses.
- Return type:
dict
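The combination rule described for loss() fits in a few lines of plain Python. This is an illustration only. combine_losses, the has_force_labels flag and the "loss" key are hypothetical names. The real method works on AtomsGraph batches and tensors rather than floats.

```python
from typing import Dict, Optional


def combine_losses(
    diffusion: Dict[str, float],
    regressor: Optional[Dict[str, float]],
    regressor_loss_weight: float = 1.0,
    has_force_labels: bool = True,
) -> Dict[str, float]:
    """Sketch of the diffusion + weighted regressor loss combination.

    Assumption: each input dict stores its scalar total under the key "loss".
    """
    out = {f"diffusion_{k}": v for k, v in diffusion.items()}
    total = diffusion["loss"]
    # The regressor term contributes only when a regressor exists and the
    # batch carries force labels.
    if regressor is not None and has_force_labels:
        out.update({f"regressor_{k}": v for k, v in regressor.items()})
        total += regressor_loss_weight * regressor["loss"]
    out["loss"] = total
    return out


losses = combine_losses({"loss": 1.0}, {"loss": 2.0}, regressor_loss_weight=0.5)
print(losses["loss"])  # 2.0
```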
- diffusion_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict
Compute the diffusion (denoising score-matching) loss.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
batch_idx (torch.Tensor) – The index of the batch.
- Returns:
A dictionary of losses.
- Return type:
dict
- regressor_loss(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> Dict
Compute the regressor loss on the un-noised batch.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
batch_idx (torch.Tensor) – The index of the batch.
- Returns:
A dictionary of losses.
- Return type:
dict
- Raises:
ValueError – If no regressor model is attached.
- training_step(batch, batch_idx: torch.Tensor) -> torch.Tensor
Perform a training step.
Computes the combined diffusion + regressor loss (see loss()).
When the Dataset was set up with a dedicated regressor dataset (via add_regressor_data()), batch is a dict with two keys:
"main" – a regular training batch used for both the diffusion and the regressor loss.
"regressor" – a regressor-only batch whose structures are forwarded only through the regressor loss (not the diffusion loss).
When no regressor dataset is present, batch is a plain AtomsGraph batch and the step simply computes loss() on it.
- Parameters:
batch (AtomsGraph or dict) – A batch of AtomsGraph data, or a dict with "main" and "regressor" keys when a dedicated regressor dataset is used.
batch_idx (torch.Tensor) – The index of the batch.
- Returns:
The combined loss.
- Return type:
torch.Tensor
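The sketch below summarises how training_step() treats a plain batch versus a dict batch. loss_fn and regressor_loss_fn are hypothetical stand-ins for loss() and regressor_loss(). Multiplying the regressor-only term by regressor_loss_weight is an assumption, because the text above does not say how that term is weighted.

```python
from typing import Any, Callable, Dict, Union


def training_step_sketch(
    batch: Union[Any, Dict[str, Any]],
    loss_fn: Callable[[Any], float],
    regressor_loss_fn: Callable[[Any], float],
    regressor_loss_weight: float = 1.0,
) -> float:
    """Dispatch between a plain batch and a {"main", "regressor"} dict batch.

    loss_fn stands in for the combined diffusion + regressor loss; the
    regressor-only batch goes through regressor_loss_fn alone. Weighting the
    extra term by regressor_loss_weight is an assumption.
    """
    if isinstance(batch, dict):
        total = loss_fn(batch["main"])
        total += regressor_loss_weight * regressor_loss_fn(batch["regressor"])
        return total
    # Plain batch: only the combined loss.
    return loss_fn(batch)


# Toy losses: the "combined" loss is the batch length, the regressor loss is constant.
print(training_step_sketch({"main": [1, 2], "regressor": [3]},
                           lambda b: float(len(b)), lambda b: 10.0, 0.5))  # 7.0
```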
- validation_step(batch: agedi.data.AtomsGraph, batch_idx: torch.Tensor) -> torch.Tensor
Perform a validation step.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
batch_idx (torch.Tensor) – The index of the batch.
- Returns:
The combined loss.
- Return type:
torch.Tensor
- configure_optimizers() -> Dict
Configure optimizers and learning-rate schedulers.
When a regressor model is present, a single optimizer is built over the deduplicated union of score_model and regressor_model parameters (shared parameters appear only once).
- Returns:
A dictionary with "optimizer", "lr_scheduler", and "monitor" keys.
- Return type:
dict
- _scheduler_monitor() -> str
Return the name of the metric monitored by ReduceLROnPlateau.
- property regressor_training: bool
Whether the regressor model is in training mode.
- class agedi.diffusion.ForcefieldGuidanceConfig
Configuration for force-field guided sampling.
- Parameters:
guidance (float) – Scale of the force-field guidance applied at each reverse step. Set to 0.0 (the default) to disable guidance entirely.
zeta (float) – Exponent of the time-dependent weight factor (1 - t)**zeta. Higher values concentrate guidance near the end of the trajectory (t → 0).
force_threshold (float) – Convergence criterion for the optional post-diffusion relaxation: the maximum per-atom force magnitude (eV/Å) below which relaxation stops.
max_extra_steps (int) – Maximum number of additional relaxation steps performed after the main diffusion trajectory when guidance > 0.
- guidance: float = 0.0
- zeta: float = 3.0
- force_threshold: float = 0.05
- max_extra_steps: int = 0
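A numerical example shows the effect of guidance and zeta. The sketch assumes that the effective guidance scale at time t is guidance * (1 - t)**zeta. The text above names both factors but does not spell out how they are combined. GuidanceSketch is a hypothetical stand-in that mirrors the documented defaults. It is not the real ForcefieldGuidanceConfig.

```python
from dataclasses import dataclass


@dataclass
class GuidanceSketch:
    """Stand-in mirroring the documented ForcefieldGuidanceConfig defaults."""

    guidance: float = 0.0
    zeta: float = 3.0
    force_threshold: float = 0.05
    max_extra_steps: int = 0

    def weight(self, t: float) -> float:
        """Assumed effective scale at diffusion time t: guidance * (1 - t)**zeta."""
        return self.guidance * (1.0 - t) ** self.zeta


cfg = GuidanceSketch(guidance=0.5)
for t in (1.0, 0.5, 0.0):
    print(t, cfg.weight(t))
```

During sampling t runs from 1 down to eps, so the weight starts at zero and grows towards guidance. Raising zeta delays that growth.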
- class agedi.diffusion.Diffusion(score_model: ScoreModel, noisers: List[agedi.diffusion.noisers.Noiser], regressor_model: torch.nn.Module | None = None, eps: float = 1e-05)
Pure-Python sampling core for diffusion models.
Holds the score model, noisers, and an optional regressor, and provides the full forward / reverse / sampling pipeline. This class does not inherit from torch.nn.Module or lightning.LightningModule and therefore has no training hooks.
When used through Agedi (which inherits from both this class and lightning.LightningModule), the Lightning infrastructure manages device placement and module registration. When used standalone, device information is derived from the score model's parameters via the device property.
- Parameters:
score_model (ScoreModel) – The score model.
noisers (List[Noiser]) – A list of noisers.
regressor_model (torch.nn.Module, optional) – An optional regressor model used for force-field guidance during sampling.
eps (float, optional) – Minimum value of the diffusion time (used in sample_time()).
- score_model
- noisers
- regressor_model = None
- eps = 1e-05
- lbfgs_step_sizer: agedi.diffusion.guidance.BatchedLBFGSStepSizer | None = None
- zeta: float = 3.0
- noiser_keys
- score_keys
- _compiled_reverse_step = None
- property device: torch.device
Infer the computation device from the score model's parameters.
When used through Agedi (which also inherits lightning.LightningModule), Lightning's own device property takes precedence.
- sample_time(batch: agedi.data.AtomsGraph) -> None
Sample a random diffusion time for each graph in batch.
Draws one time per graph uniformly from [eps, 1] and assigns it to batch.time at atom resolution, so every atom carries the time of its graph.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data; modified in place.
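The sketch below draws one time per graph and then broadcasts it to the atoms, using only the standard library. batch_index plays the role of the PyTorch Geometric batch vector: atom i belongs to graph batch_index[i]. The function name and the list-based representation are assumptions that keep the example free of dependencies.

```python
import random
from typing import List, Optional


def sample_time_sketch(
    batch_index: List[int],
    num_graphs: int,
    eps: float = 1e-5,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw one time per graph from U[eps, 1] and broadcast it to atoms."""
    rng = rng or random.Random()
    graph_times = [rng.uniform(eps, 1.0) for _ in range(num_graphs)]
    # Atom-resolution times: each atom gathers the time of its graph.
    return [graph_times[g] for g in batch_index]


# Two graphs: atoms 0-2 belong to graph 0, atoms 3-4 to graph 1.
times = sample_time_sketch([0, 0, 0, 1, 1], num_graphs=2, rng=random.Random(0))
print(times)
```

In PyTorch the same idea would read roughly t = eps + (1 - eps) * torch.rand(num_graphs), followed by indexing with t[batch.batch].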
- forward_step(batch: agedi.data.AtomsGraph) -> agedi.data.AtomsGraph
Forward diffusion step (corruption).
Applies each noiser in order to corrupt the batch.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
- Returns:
The corrupted batch.
- Return type:
AtomsGraph
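forward_step() applies the noisers one after another, and each noiser receives the previous noiser's output. The sketch below expresses that composition with functools.reduce. The dictionary batch and make_tagging_noiser are hypothetical stand-ins that only record which noiser ran. They do not model real corruption processes.

```python
from functools import reduce
from typing import Callable, Dict, List

Batch = Dict[str, list]  # toy stand-in for AtomsGraph
Noiser = Callable[[Batch], Batch]


def forward_step_sketch(batch: Batch, noisers: List[Noiser]) -> Batch:
    """Apply each noiser in list order; each sees the previous one's output."""
    return reduce(lambda b, noiser: noiser(b), noisers, batch)


def make_tagging_noiser(name: str) -> Noiser:
    """Hypothetical noiser that only records that it ran."""
    def noiser(batch: Batch) -> Batch:
        return {**batch, "applied": batch.get("applied", []) + [name]}
    return noiser


out = forward_step_sketch({}, [make_tagging_noiser("positions"), make_tagging_noiser("types")])
print(out["applied"])  # ['positions', 'types']
```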
- reverse_step(batch: agedi.data.AtomsGraph, delta_t: float, force_field_guidance: float, last: bool = False, timings: SamplingTimings | None = None) -> agedi.data.AtomsGraph
Reverse diffusion step (denoising).
Evaluates the score model and applies one reverse-SDE step through all noisers, then optionally applies force-field guidance.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
delta_t (float) – The time step.
force_field_guidance (float) – Scale of the force-field guidance (0.0 disables it).
last (bool, optional) – Whether this is the final denoising step.
timings (SamplingTimings, optional) – If provided, timing measurements are accumulated here.
- Returns:
The denoised batch.
- Return type:
AtomsGraph
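This page does not specify the SDE behind each noiser, so the following is only a generic illustration of one reverse-SDE step. It assumes a variance-exploding process (zero drift, diffusion coefficient g(t)) and applies the Euler-Maruyama update x <- x + g^2 * score * delta_t + g * sqrt(delta_t) * z to a list of scalar coordinates. reverse_em_step is a hypothetical name. Dropping the noise on the last step is a common convention assumed here, not something confirmed for agedi.

```python
import math
import random
from typing import List, Optional, Sequence


def reverse_em_step(
    x: Sequence[float],
    score: Sequence[float],
    g: float,
    delta_t: float,
    rng: Optional[random.Random] = None,
    last: bool = False,
) -> List[float]:
    """Euler-Maruyama step from t to t - delta_t of a reverse-time VE SDE.

    x_{t - delta_t} = x_t + g**2 * score * delta_t + g * sqrt(delta_t) * z,  z ~ N(0, 1)
    Assumption: the noise term is skipped on the final step (last=True).
    """
    rng = rng or random.Random()
    out = []
    for xi, si in zip(x, score):
        noise = 0.0 if last else g * math.sqrt(delta_t) * rng.gauss(0.0, 1.0)
        out.append(xi + g * g * si * delta_t + noise)
    return out


# Deterministic final step: 1.0 + 1.0**2 * 2.0 * 0.5 = 2.0
print(reverse_em_step([1.0], [2.0], g=1.0, delta_t=0.5, last=True))  # [2.0]
```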
- corrector_step(batch: agedi.data.AtomsGraph, corrector_dt: float) -> agedi.data.AtomsGraph
Langevin corrector step at constant time.
Evaluates the score model and applies one Langevin corrector step through all noisers (in reverse order).
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
corrector_dt (float) – Step size for the Langevin corrector.
- Returns:
The corrected batch.
- Return type:
AtomsGraph
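A Langevin corrector step at fixed time usually takes the form x <- x + h * score + sqrt(2h) * z, where h is the step size (corrector_dt here). The sketch below implements that standard update on plain lists and checks it on a standard normal, whose score is -x. langevin_corrector_step is a hypothetical name. This page does not document how agedi's noisers scale the step internally.

```python
import math
import random
from typing import List, Sequence


def langevin_corrector_step(
    x: Sequence[float], score: Sequence[float], step_size: float, z: Sequence[float]
) -> List[float]:
    """One unadjusted Langevin step at constant time.

    x <- x + step_size * score + sqrt(2 * step_size) * z, with z ~ N(0, 1).
    The noise z is passed in explicitly to keep the function deterministic.
    """
    noise_scale = math.sqrt(2.0 * step_size)
    return [xi + step_size * si + noise_scale * zi for xi, si, zi in zip(x, score, z)]


# Demo: 1000 independent chains started at x = 5 relax towards N(0, 1).
rng = random.Random(0)
chains = [5.0] * 1000
for _ in range(300):
    score = [-xi for xi in chains]  # score of a standard normal density
    z = [rng.gauss(0.0, 1.0) for _ in chains]
    chains = langevin_corrector_step(chains, score, step_size=0.05, z=z)

mean = sum(chains) / len(chains)
var = sum((c - mean) ** 2 for c in chains) / len(chains)
print(f"mean={mean:.3f} var={var:.3f}")  # close to 0 and 1
```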
- force_field_guidance_step(batch: agedi.data.AtomsGraph, scale: float, max_step_size: float = 0.1) -> agedi.data.AtomsGraph
Apply one force-field guidance step.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
scale (float) – Base scale of the force field guidance.
max_step_size (float, optional) – Maximum allowed step size magnitude.
- Returns:
Updated batch.
- Return type:
AtomsGraph
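The role of max_step_size can be shown as a capped displacement along the forces. The sketch assumes that each atom moves by scale * force and that the cap applies per atom to the length of that displacement. This page does not say whether agedi caps per atom or per structure, or how lbfgs_step_sizer interacts with the cap. guidance_step_sketch is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def guidance_step_sketch(
    positions: Sequence[Vec3],
    forces: Sequence[Vec3],
    scale: float,
    max_step_size: float = 0.1,
) -> List[Vec3]:
    """Move each atom by scale * force, capping the length of each displacement."""
    new_positions = []
    for pos, frc in zip(positions, forces):
        step = [scale * f for f in frc]
        length = math.sqrt(sum(s * s for s in step))
        if length > max_step_size:
            # Shrink the displacement to max_step_size while keeping its direction.
            step = [s * max_step_size / length for s in step]
        new_positions.append(tuple(p + s for p, s in zip(pos, step)))
    return new_positions


# A large force on atom 0 is capped; the small force on atom 1 is applied as-is.
print(guidance_step_sketch([(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)],
                           [(10.0, 0.0, 0.0), (0.01, 0.0, 0.0)], scale=1.0))
```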
- post_diffusion_relaxation_step(batch: agedi.data.AtomsGraph, scale: float = 0.1) -> agedi.data.AtomsGraph
Perform a pure force-based relaxation step.
- Parameters:
batch (AtomsGraph) – A batch of AtomsGraph data.
scale (float, optional) – Step size scaling factor.
- Returns:
Updated batch.
- Return type:
AtomsGraph
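Post-diffusion relaxation repeats force-based steps until the largest per-atom force falls below force_threshold or max_extra_steps is used up; both settings are documented in ForcefieldGuidanceConfig. The loop below sketches that stopping rule with a steepest-descent update, positions += scale * forces. Treating this update as what "pure force-based" means is an assumption. relax_sketch and the harmonic toy force are hypothetical, and the defaults mirror the documented ones (force_threshold=0.05, max_extra_steps=0).

```python
import math
from typing import Callable, List, Tuple

Vec3 = Tuple[float, float, float]


def max_force(forces: List[Vec3]) -> float:
    """Largest per-atom force magnitude."""
    return max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in forces)


def relax_sketch(
    positions: List[Vec3],
    force_fn: Callable[[List[Vec3]], List[Vec3]],
    scale: float = 0.1,
    force_threshold: float = 0.05,
    max_extra_steps: int = 0,
) -> Tuple[List[Vec3], int]:
    """Steepest-descent relaxation; returns final positions and steps taken."""
    for step in range(max_extra_steps):
        forces = force_fn(positions)
        if max_force(forces) < force_threshold:
            return positions, step  # converged before using all extra steps
        positions = [
            tuple(p + scale * f for p, f in zip(pos, frc))
            for pos, frc in zip(positions, forces)
        ]
    return positions, max_extra_steps


def harmonic_forces(positions: List[Vec3]) -> List[Vec3]:
    """Toy force field pulling every atom towards the origin (F = -x)."""
    return [tuple(-c for c in pos) for pos in positions]


final, steps_taken = relax_sketch(
    [(1.0, 0.0, 0.0), (0.0, -0.5, 0.0)], harmonic_forces, max_extra_steps=100
)
print(steps_taken, max_force(harmonic_forces(final)))
```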
- _initialize_graph(cutoff: float, **kwargs) -> agedi.data.AtomsGraph
Initialise a single graph from noiser priors.
- Parameters:
cutoff (float) – Cutoff radius for the neighbour list.
**kwargs – Additional keyword arguments passed to the graph (e.g. cell, template, pbc).
- Returns:
The initialised graph.
- Return type:
AtomsGraph
- static _sync_for_timing(device: torch.device | None) -> None
- _time_sampling_call(device: torch.device | None, timings: SamplingTimings, key: str, fn, *args, **kwargs)
- static _format_timing_line(label: str, value: float, count: int | None = None) -> str
- _print_sampling_timings(timings: SamplingTimings) -> None
- property compiled_reverse_step
Lazily compile reverse_step() with torch.compile.
The compiled kernel is cached as self._compiled_reverse_step, so compilation happens at most once per model instance. Using a per-instance cache (rather than a class-level @torch.compile decorator) means that two Diffusion objects with different architectures each compile their own kernel and never interfere.
Note
timings must not be passed to the compiled function, because time.perf_counter is not traceable by Dynamo. Time the compiled call from outside, in _sample_batch(), using the is_compiled flag.
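The per-instance caching pattern described for compiled_reverse_step can be shown without PyTorch. In the sketch, compile_fn stands in for torch.compile; with PyTorch installed one would presumably pass torch.compile itself. LazyCompiled and counting_compile are hypothetical names. They show that compilation happens once per instance and separately for each instance.

```python
from typing import Callable, List, Optional


class LazyCompiled:
    """Per-instance lazy compilation cache (sketch of the pattern above)."""

    def __init__(self, compile_fn: Callable[[Callable], Callable]) -> None:
        self._compile_fn = compile_fn
        self._compiled_reverse_step: Optional[Callable] = None

    def reverse_step(self, x: float) -> float:
        return x * 0.5  # placeholder for the real denoising update

    @property
    def compiled_reverse_step(self) -> Callable:
        if self._compiled_reverse_step is None:
            # First access: compile the bound method and cache it on this instance.
            self._compiled_reverse_step = self._compile_fn(self.reverse_step)
        return self._compiled_reverse_step


calls: List[Callable] = []


def counting_compile(fn: Callable) -> Callable:
    calls.append(fn)
    return fn  # identity "compiler" used only for illustration


model_a = LazyCompiled(counting_compile)
model_a.compiled_reverse_step(4.0)
model_a.compiled_reverse_step(2.0)  # cached: no second compilation
model_b = LazyCompiled(counting_compile)
model_b.compiled_reverse_step(1.0)  # a new instance compiles its own copy
print(len(calls))  # 2
```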
- _sample_batch(batch: torch_geometric.data.Batch, steps: int, eps: float, force_field_guidance: float, save_trajectory: bool, progress_bar: bool, force_threshold: float, max_extra_steps: int, corrector_steps: int = 0, corrector_step_size: float = 0.001, timings: SamplingTimings | None = None, reverse_step_fn=None, is_compiled: bool = False) -> List[agedi.data.AtomsGraph]
Run the reverse-diffusion loop for a pre-built batch.
- Parameters:
batch (Batch) – A batch of AtomsGraph data at t = 1.
steps (int) – Number of reverse-diffusion steps.
eps (float) – Minimum time value (end of trajectory).
force_field_guidance (float) – Scale of the force-field guidance (0.0 disables it).
save_trajectory (bool) – Whether to collect and return all intermediate states.
progress_bar (bool) – Whether to display a tqdm progress bar.
force_threshold (float) – Per-atom force magnitude below which post-diffusion relaxation terminates.
max_extra_steps (int) – Maximum number of extra relaxation steps after the main trajectory.
corrector_steps (int, optional) – Number of Langevin corrector passes after each predictor step. 0 (default) disables the corrector (standard DDPM/EM sampling).
corrector_step_size (float, optional) – Step size used for each Langevin corrector step. Defaults to 1e-3.
timings (SamplingTimings, optional) – If provided, timing measurements are accumulated here.
reverse_step_fn (callable, optional) – The reverse step function to use. Defaults to self.reverse_step. Pass a torch.compile-wrapped version to enable compiled sampling.
is_compiled (bool, optional) – Whether reverse_step_fn is a compiled function.
- Returns:
Final structures, or (when save_trajectory is True) a list of trajectories, one per graph.
- Return type:
List[AtomsGraph]
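Listing the operations the loop would perform shows how steps, eps and corrector_steps shape it. The sketch assumes a uniform time grid from t = 1 down to t = eps. It runs the corrector passes at the time reached by the preceding predictor step. reverse_schedule is a hypothetical helper, and it leaves out post-diffusion relaxation.

```python
from typing import List, Tuple


def reverse_schedule(steps: int, eps: float, corrector_steps: int = 0) -> List[Tuple[str, float]]:
    """List the predictor/corrector operations of a reverse loop.

    Assumes a uniform time grid from t = 1 down to t = eps.
    """
    delta_t = (1.0 - eps) / steps
    ops: List[Tuple[str, float]] = []
    for k in range(steps):
        t = 1.0 - k * delta_t
        ops.append(("predict", t))
        # After the predictor moves to t - delta_t, run corrector passes at constant time.
        for _ in range(corrector_steps):
            ops.append(("correct", t - delta_t))
    return ops


for op, t in reverse_schedule(steps=4, eps=0.2, corrector_steps=1):
    print(op, round(t, 3))
```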
- _sample(N: int, steps: int, cutoff: float, eps: float, force_field_guidance: float, force_threshold: float, max_extra_steps: int, progress_bar: bool, save_trajectory: bool, corrector_steps: int = 0, corrector_step_size: float = 0.001, print_timings: bool = False, compile: bool = False, **kwargs) -> List[agedi.data.AtomsGraph]
Build N graphs from the noiser priors and run the sampling loop.
- Parameters:
N (int) – Number of structures to generate.
steps (int) – Number of reverse-diffusion steps.
cutoff (float) – Cutoff radius for the neighbour list.
eps (float) – Minimum time value (end of trajectory).
force_field_guidance (float) – Scale of the force-field guidance.
force_threshold (float) – Per-atom force magnitude below which post-diffusion relaxation stops.
max_extra_steps (int) – Maximum extra relaxation steps.
progress_bar (bool) – Show tqdm progress bar.
save_trajectory (bool) – Collect all intermediate states.
corrector_steps (int, optional) – Langevin corrector passes per predictor step.
corrector_step_size (float, optional) – Step size for each corrector pass.
print_timings (bool, optional) – Print a timing breakdown after sampling completes.
compile (bool, optional) – Use torch.compile on the reverse diffusion step.
**kwargs – Keyword arguments forwarded to _initialize_graph().
- Returns:
Sampled structures (or trajectories when save_trajectory is True).
- Return type:
List[AtomsGraph]
- sample(N: int, template=None, batch_size: int | None = 64, steps: int | None = 500, cutoff: float | None = 6.0, eps: float | None = 0.001, n_atoms: int | None = None, atomic_numbers: List[int] | None = None, formula: str | None = None, positions: numpy.ndarray | None = None, cell: numpy.ndarray | None = None, pbc: numpy.ndarray | None = None, confinement: Tuple[float, float] | None = None, compile: bool = False, ff_guidance: agedi.diffusion.guidance.ForcefieldGuidanceConfig | None = None, property: Dict | None = None, progress_bar: bool | None = False, save_trajectory: bool | None = False, print_timings: bool | None = False, corrector_steps: int = 0, corrector_step_size: float = 0.001) -> List[agedi.data.AtomsGraph]
Sample structures from the diffusion model.
The minimum required arguments depend on the configured noisers and on whether a template is provided:
n_atoms – required unless it can be derived from atomic_numbers or formula.
atomic_numbers – required unless a types-noiser is configured (key "x") or the atomic numbers can be derived from formula.
positions – required when no positions-noiser is configured (type-only diffusion).
cell – required when no template is given.
pbc – optional; defaults to [True, True, True].
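Under these rules, atomic_numbers and n_atoms can be derived from formula. The sketch below shows one way to do that derivation for simple formulas. The regex parser and the small symbol table are assumptions, and sample() may use a different parser. ASE, for example, provides a complete symbol table as ase.data.atomic_numbers.

```python
import re
from typing import List

# Minimal symbol table for the example; a real implementation would use a
# complete periodic table.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "Cu": 29, "Pt": 78}


def atomic_numbers_from_formula(formula: str) -> List[int]:
    """Expand a simple formula such as "H2O" into [1, 1, 8].

    Handles element symbols followed by optional integer counts; parentheses
    and hydrates are not supported.
    """
    numbers: List[int] = []
    for symbol, count in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        numbers.extend([ATOMIC_NUMBERS[symbol]] * int(count or 1))
    return numbers


numbers = atomic_numbers_from_formula("H2O")
n_atoms = len(numbers)  # n_atoms is then derivable from the formula
print(numbers, n_atoms)  # [1, 1, 8] 3
```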
- Parameters:
N (int) – Number of structures to generate.
template (AtomsGraph or ase.Atoms, optional) – Template structure. cell and pbc are taken from the template when not explicitly provided.
batch_size (int, optional) – Internal batch size used to split large N into chunks.
steps (int, optional) – Number of reverse-diffusion steps.
cutoff (float, optional) – Cutoff radius for the neighbour list.
eps (float, optional) – Minimum time value at the end of the trajectory.
n_atoms (int, optional) – Number of atoms per structure.
atomic_numbers (List[int], optional) – Atomic numbers of the atoms to generate.
formula (str, optional) – Chemical formula (e.g. "H2O").
positions (np.ndarray, optional) – Fixed atom positions (shape (n_atoms, 3)).
cell (np.ndarray, optional) – Unit-cell matrix (3x3).
pbc (np.ndarray, optional) – Periodic boundary conditions.
confinement (Tuple[float, float], optional) – Confinement along the z direction, given as (z_min, z_max).
compile (bool, optional) – When True, use torch.compile on the reverse diffusion step for improved throughput on CUDA hardware.
ff_guidance (ForcefieldGuidanceConfig, optional) – Force-field guidance configuration.
property (dict, optional) – Conditioning properties (key -> scalar tensor).
progress_bar (bool, optional) – Show a tqdm progress bar.
save_trajectory (bool, optional) – Return full trajectories instead of final structures.
print_timings (bool, optional) – Print a timing breakdown after sampling completes.
corrector_steps (int, optional) – Number of Langevin corrector passes after each predictor step. 0 (default) gives standard (predictor-only) sampling.
corrector_step_size (float, optional) – Step size for each corrector pass. Defaults to 1e-3.
- Returns:
Sampled structures, or trajectories when save_trajectory is True.
- Return type:
List[AtomsGraph]
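The batch_size argument splits a large N into internal chunks. The helper below sketches the resulting chunk sizes. batch_sizes is a hypothetical name. The text does not say what batch_size=None means, so treating it as "one batch for everything" is an assumption.

```python
from typing import List, Optional


def batch_sizes(N: int, batch_size: Optional[int]) -> List[int]:
    """Sizes of the internal batches used to generate N structures.

    Assumption: batch_size=None means everything goes into one batch.
    """
    if batch_size is None or batch_size >= N:
        return [N]
    full, rest = divmod(N, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(batch_sizes(150, 64))  # [64, 64, 22]
```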