agedi.diffusion.noisers.types
Attributes
Classes
NoiseSchedule | Noise schedule for the discrete type diffusion model (Q matrix).
Transition | Placeholder class for transition matrix representations.
Types | Type noiser.
Module Contents
- class agedi.diffusion.noisers.types.NoiseSchedule(beta_min: float, beta_max: float)
Noise schedule for the discrete type diffusion model (Q matrix).
Implements an exponential noise schedule parameterised by beta_min and beta_max, following the score-entropy discrete diffusion formulation.
- beta_min
- beta_max
- get_hparams() → Dict
Return hyperparameters sufficient to reconstruct this noise schedule.
- Returns:
Hyperparameter dictionary with keys _target_, beta_min, and beta_max.
- Return type:
dict
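Because `_target_` follows the Hydra convention of a dotted import path, the returned dictionary can in principle be turned back into an object. A minimal sketch, assuming `_target_` holds a fully qualified class path and every other key is a constructor keyword argument; `instantiate_from_hparams` is a hypothetical helper, not part of agedi. With Hydra installed, `hydra.utils.instantiate` does the same job.

```python
import importlib
from typing import Any, Dict


def instantiate_from_hparams(hparams: Dict[str, Any]) -> Any:
    """Rebuild an object from a Hydra-style hyperparameter dictionary.

    Hypothetical helper. Assumes hparams["_target_"] is a dotted path
    "package.module.ClassName" and all other keys are keyword arguments.
    """
    kwargs = dict(hparams)  # copy so the caller's dict is left untouched
    module_name, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**kwargs)
```

Applied to the output of `get_hparams()`, this would call the schedule class with `beta_min` and `beta_max`, provided `_target_` resolves to that class.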
- _beta_t(time: torch.Tensor) → torch.Tensor
Beta function for the type noiser's Q matrix.
- Parameters:
time (torch.Tensor) – Diffusion time
- Returns:
The beta value for the given time
- Return type:
torch.Tensor
- rate_noise(time: torch.Tensor) → torch.Tensor
The rate of change of the noise, i.e. g(t).
- Parameters:
time (torch.Tensor) – The diffusion time
- Returns:
The rate of change of the noise
- Return type:
torch.Tensor
- total_noise(time: torch.Tensor) → torch.Tensor
Total noise at time t.
Given as the integral of the rate of change of the noise, i.e. $\int_0^t g(s)\,\mathrm{d}s + g(0)$.
- Parameters:
time (torch.Tensor) – The diffusion time
- Returns:
The total noise at time t
- Return type:
torch.Tensor
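The docstrings above do not spell out the exponential form. A minimal sketch, assuming the geometric schedule used in SEDD, $\sigma(t) = \beta_{\min}^{1-t}\beta_{\max}^{t}$ with rate $g(t) = \sigma(t)\log(\beta_{\max}/\beta_{\min})$; the class name `GeometricSchedule` is hypothetical, and agedi's `_beta_t` may use a different formula.

```python
import math

import torch


class GeometricSchedule:
    """Hypothetical exponential (geometric) noise schedule.

    Assumption: total noise sigma(t) = beta_min**(1 - t) * beta_max**t,
    as in SEDD's geometric noise; agedi's exact formula may differ.
    """

    def __init__(self, beta_min: float = 0.01, beta_max: float = 3.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def total_noise(self, time: torch.Tensor) -> torch.Tensor:
        # Log-linear interpolation: beta_min at t = 0, beta_max at t = 1.
        return self.beta_min ** (1 - time) * self.beta_max ** time

    def rate_noise(self, time: torch.Tensor) -> torch.Tensor:
        # g(t) = d sigma / dt = sigma(t) * log(beta_max / beta_min)
        return self.total_noise(time) * math.log(self.beta_max / self.beta_min)
```

With the defaults of `NoiseSchedule(0.01, 3.0)`, this sketch gives a total noise of 0.01 at t = 0 and 3.0 at t = 1, and `rate_noise` is the time derivative of `total_noise`.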
- class agedi.diffusion.noisers.types.Transition
Placeholder class for transition matrix representations.
- class agedi.diffusion.noisers.types.Types(prior=Constant(0), distribution=Categorical(), noise_schedule: NoiseSchedule = NoiseSchedule(0.01, 3.0), sampling_mask: torch.Tensor | None = None, n_classes: int = 100, type_map: List[int] | None = None, **kwargs)
Bases: agedi.diffusion.noisers.Noiser
Type noiser.
Based on the score entropy discrete diffusion (SEDD) model. See https://arxiv.org/abs/2310.16834 for more details.
Uses an absorbing state (index 0) as the first state in the transition matrix.
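To make the absorbing state concrete, the sketch below builds a dense rate matrix $Q$ with the mask at index 0. It assumes SEDD's column convention ($Q_{ij}$ is the rate of jumping from state $j$ to state $i$, so every column sums to zero); `absorbing_rate_matrix` and `MASK` are illustrative names, not agedi API.

```python
import torch

MASK = 0  # absorbing (mask) state index, as stated in the class description


def absorbing_rate_matrix(n_classes: int, dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """Dense absorbing-state rate matrix Q (illustrative sketch).

    Column convention (assumed): Q[i, j] is the rate from j to i.
    Every non-mask state leaks into MASK at unit rate; MASK never leaves.
    """
    q = -torch.eye(n_classes, dtype=dtype)  # unit outflow from every state
    q[MASK, :] += 1.0  # inflow into MASK; cancels MASK's own diagonal
    return q
```

For this matrix $Q^2 = -Q$, so $e^{\sigma Q} = I + (1 - e^{-\sigma})Q$: each non-mask type survives with probability $e^{-\sigma}$ and is otherwise absorbed.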
- _key = 'x'
- noise_schedule
- sampling_mask = None
- n_classes = 100
- get_hparams() → Dict
Return hyperparameters for this types noiser.
- _noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph
Noises the attribute of the atomistic structure.
Performs the noising of the atomic types.
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be noised.
- Returns:
The noised atomistic structure (or batch thereof).
- Return type:
AtomsGraph
- _denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph
Denoises the attribute of the atomistic structure.
Performs the denoising of the atomic types.
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be denoised.
delta_t (float) – The time step to be used for the denoising.
last (bool) – Whether this is the last denoising step.
- Returns:
The denoised atomistic structure (or batch thereof).
- Return type:
AtomsGraph
- _loss(batch: agedi.data.AtomsGraph) → torch.Tensor
Computes the training loss.
Because training uses score entropy, the score is given as score = log(s). For sampling it must be converted back to the concrete score, i.e. exp(score).
- Parameters:
batch (AtomsGraph) – The atomistic structure (or batch thereof) to be noised and denoised.
- Returns:
The loss of the noised and denoised atomistic structure.
- Return type:
torch.Tensor
- sample_transition(x: torch.Tensor, sigma: torch.Tensor) → torch.Tensor
Samples the transition vector for the types.
This corresponds to noising the types in the discrete diffusion model.
- Parameters:
x (torch.Tensor) – The current types
sigma (torch.Tensor) – The total noise
- Returns:
The noised types
- Return type:
torch.Tensor
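Under an absorbing process, noising with total noise σ moves each type to the mask state independently with probability $1 - e^{-\sigma}$. A minimal sketch mirroring SEDD's absorbing `sample_transition`, assuming the mask index 0 described above and a `sigma` that broadcasts against `x`; it is a standalone function, not the agedi method.

```python
import torch

MASK = 0  # absorbing state index (assumed from the class description)


def sample_transition(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Forward-noise types under the absorbing process (sketch).

    Each entry of x independently jumps to MASK with probability
    1 - exp(-sigma); otherwise it keeps its current type.
    """
    move_chance = 1.0 - torch.exp(-sigma)
    move = torch.rand(x.shape, device=x.device) < move_chance
    return torch.where(move, torch.full_like(x, MASK), x)
```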
- score_entropy(score: torch.Tensor, sigma: torch.Tensor, x: torch.Tensor, x0: torch.Tensor) → torch.Tensor
Computes the score entropy loss.
- Parameters:
score (torch.Tensor) – The score
sigma (torch.Tensor) – The total noise
x (torch.Tensor) – The noised types
x0 (torch.Tensor) – The original types
- Returns:
The score entropy loss
- Return type:
torch.Tensor
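A sketch of the absorbing-state score entropy, adapted from SEDD with the mask moved to index 0. It assumes `score` is the log-score described under `_loss`, `sigma` holds one total-noise value per entry, and the clean types in `x0` are never the mask. Only masked entries contribute; an entry's loss vanishes when the concrete score equals the true ratio $1/(e^{\sigma} - 1)$ at `x0` and zero elsewhere.

```python
import torch

MASK = 0  # absorbing (mask) state index, as in the class description


def score_entropy(score: torch.Tensor, sigma: torch.Tensor,
                  x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    """Absorbing-state score entropy (sketch adapted from SEDD).

    score: (N, C) log concrete scores; sigma: (N,) total noise;
    x: (N,) noised types; x0: (N,) clean types (never MASK).
    Returns an (N,) loss that is zero for unmasked entries.
    """
    masked = x == MASK
    # True ratio p(x0) / p(MASK) for a masked entry is 1 / (e^sigma - 1).
    ratio = 1.0 / torch.expm1(sigma[masked])
    log_s = score[masked]
    # Negative term: ratio * log s_{x0}
    neg_term = ratio * log_s.gather(-1, x0[masked].unsqueeze(-1)).squeeze(-1)
    # Positive term: sum of concrete scores over every non-mask state
    keep = torch.arange(score.shape[-1], device=score.device) != MASK
    pos_term = log_s[:, keep].exp().sum(dim=-1)
    # Constant term that makes the loss exactly zero at the optimum
    const = ratio * (ratio.log() - 1.0)
    loss = torch.zeros(x.shape, dtype=score.dtype, device=score.device)
    loss[masked] = pos_term - neg_term + const
    return loss
```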
- transp_rate(x: torch.Tensor) → torch.Tensor
Computes the i’th row of the rate transition matrix Q for each type i in x.
Can be used to compute the reverse rate.
- Parameters:
x (torch.Tensor) – The types
- Returns:
The i’th row of the rate transition matrix Q
- Return type:
torch.Tensor
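For an absorbing $Q$ with the mask at index 0, the rows are simple enough to build without the dense matrix. A sketch adapted from SEDD's absorbing graph, assuming the column convention ($Q_{ij}$ is the rate from $j$ to $i$); passing `n_classes` explicitly is a choice of this standalone function.

```python
import torch
import torch.nn.functional as F

MASK = 0  # absorbing state index (assumed from the class description)


def transp_rate(x: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Row x of the absorbing rate matrix Q for every entry of x (sketch).

    A non-mask type only has the diagonal entry -1 in its row; the mask
    row is 1 for every other state and 0 on its own diagonal.
    """
    rate = -F.one_hot(x, num_classes=n_classes).to(torch.float32)
    rate[x == MASK] += 1.0  # mask rows: inflow from all other states
    return rate
```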
- reverse_rate(x: torch.Tensor, score: torch.Tensor) → torch.Tensor
Constructs the reverse rate.
The reverse rate is given as score * transp_rate(x).
- Parameters:
x (torch.Tensor) – The types
score (torch.Tensor) – The score
- Returns:
The reverse rate
- Return type:
torch.Tensor
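A sketch of the reverse rate following SEDD's `reverse_rate`, assuming `score` is the concrete score (the exponentiated network output, see `_loss`) and that the diagonal is reset so each row sums to zero, as a rate matrix requires. The row helper is repeated so the block runs on its own; both functions are illustrative, not agedi API.

```python
import torch
import torch.nn.functional as F

MASK = 0  # absorbing state index (assumed from the class description)


def transp_rate(x: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Row x of the absorbing rate matrix Q (column convention).
    rate = -F.one_hot(x, num_classes=n_classes).to(torch.float32)
    rate[x == MASK] += 1.0
    return rate


def reverse_rate(x: torch.Tensor, score: torch.Tensor) -> torch.Tensor:
    """Reverse-time rates out of x given a concrete score (sketch).

    Off-diagonal entries are score * transp_rate(x); the diagonal is
    overwritten with minus the row sum so each row sums to zero.
    """
    rate = transp_rate(x, score.shape[-1]) * score
    idx = x.unsqueeze(-1)
    rate = rate.scatter(-1, idx, torch.zeros_like(rate))  # clear diagonal
    return rate.scatter(-1, idx, -rate.sum(dim=-1, keepdim=True))
```

For an unmasked type the sketch returns an all-zero row, so in the reverse process a type that has already been revealed stays fixed.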
- sample_rate(callable: Callable, x: torch.Tensor, rate: torch.Tensor) → torch.Tensor
Samples the rate.
- Parameters:
callable (Callable) – Callable function defining the categorical distribution
x (torch.Tensor) – The types
rate (torch.Tensor) – The rate
- Returns:
The sampled rate
- Return type:
torch.Tensor
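One way to read `sample_rate` is as a single Euler step of the reverse process: with a step-scaled reverse rate $R$, the next type is drawn from the categorical distribution with weights $\mathrm{onehot}(x) + R$. A sketch under that assumption, following SEDD; the `sampler` argument plays the role of `callable` (renamed to avoid shadowing the builtin), and `sample_categorical` uses the exponential-race form of the Gumbel-max trick.

```python
from typing import Callable

import torch
import torch.nn.functional as F


def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    # argmax_i p_i / E_i with E_i ~ Exp(1) picks i with probability
    # p_i / sum(p), so probs need not be normalised.
    exp_noise = 1e-10 - (torch.rand_like(probs) + 1e-10).log()
    return (probs / exp_noise).argmax(dim=-1)


def sample_rate(sampler: Callable, x: torch.Tensor, rate: torch.Tensor) -> torch.Tensor:
    """One Euler step from x using weights one_hot(x) + rate (sketch).

    rate is assumed to be a reverse rate already multiplied by the step
    size, so each row sums to zero and the weights sum to one.
    """
    probs = F.one_hot(x, num_classes=rate.shape[-1]).to(rate) + rate
    return sampler(probs)
```

The step must be small enough that the diagonal weight $1 + R_{xx}$ stays non-negative; otherwise the weights no longer form a valid distribution.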
- staggered_score(score: torch.Tensor, dsigma: torch.Tensor) → torch.Tensor
Computes the staggered score.
Computes $p_{\sigma - \mathrm{d}\sigma}(z) / p_{\sigma}(x)$, which is approximated by $e^{-\mathrm{d}\sigma\, E}$ applied to the score.
- Parameters:
score (torch.Tensor) – The score
dsigma (torch.Tensor) – The rate noise
- Returns:
The staggered score
- Return type:
torch.Tensor
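For an absorbing process $Q^2 = -Q$, so the matrix exponential has the closed form $e^{-\mathrm{d}\sigma Q} = e^{\mathrm{d}\sigma} I + (1 - e^{\mathrm{d}\sigma})\, e_0 \mathbf{1}^\top$. A sketch of the staggered score, assuming the operator $E$ above is this absorbing $Q$ (as in SEDD's absorbing graph), the mask sits at index 0, and `score` is the concrete score; the result matches applying `torch.linalg.matrix_exp(-dsigma * Q)` to the score vector.

```python
import torch

MASK = 0  # absorbing state index (assumed from the class description)


def staggered_score(score: torch.Tensor, dsigma: torch.Tensor) -> torch.Tensor:
    """exp(-dsigma * Q) applied to a concrete score, absorbing Q (sketch).

    score: (N, C) concrete scores; dsigma: (N,) noise decrement.
    Every entry is rescaled by e^{dsigma}; the mask entry also receives
    (1 - e^{dsigma}) times the sum of the score.
    """
    extra = (1.0 - dsigma.exp()) * score.sum(dim=-1)
    out = score * dsigma.exp().unsqueeze(-1)  # new tensor; input untouched
    out[:, MASK] += extra
    return out
```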
- transp_transition(x: torch.Tensor, sigma: torch.Tensor) → torch.Tensor
Computes the transition matrix for the types.
- Parameters:
x (torch.Tensor) – The types
sigma (torch.Tensor) – The total noise
- Returns:
The transition matrix
- Return type:
torch.Tensor
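For an absorbing $Q$ (with $Q^2 = -Q$), the transition matrix has the closed form $e^{\sigma Q} = e^{-\sigma} I + (1 - e^{-\sigma})\, e_0 \mathbf{1}^\top$, so each row can be written down directly. A sketch adapted from SEDD's absorbing `transp_transition`, assuming the mask at index 0 and returning the row of $e^{\sigma Q}$ indexed by each type in `x`; `n_classes` is an explicit argument of this standalone version.

```python
import torch
import torch.nn.functional as F

MASK = 0  # absorbing state index (assumed from the class description)


def transp_transition(x: torch.Tensor, sigma: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Row x of exp(sigma * Q) for each entry (sketch).

    A non-mask type is reached only from itself, with probability e^{-sigma};
    the mask row holds 1 - e^{-sigma} for other states and 1 for itself.
    sigma: (N,) total noise; returns an (N, n_classes) tensor.
    """
    keep = torch.exp(-sigma).unsqueeze(-1)
    rows = keep * F.one_hot(x, num_classes=n_classes).to(keep.dtype)
    to_mask = torch.where(x == MASK, 1.0 - keep.squeeze(-1), torch.zeros_like(sigma))
    return rows + to_mask.unsqueeze(-1)
```

In SEDD, the final denoising step multiplies this row elementwise by the staggered score and drops the mask column before sampling, which is one plausible role for the `last` flag of `_denoise`.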
- agedi.diffusion.noisers.types.TypesNoiser