agedi.diffusion.noisers.types

Attributes

Classes

NoiseSchedule

Noise schedule for the discrete type diffusion model (Q matrix).

Transition

Placeholder class for transition matrix representations.

Types

Type noiser.

Module Contents

class agedi.diffusion.noisers.types.NoiseSchedule(beta_min: float, beta_max: float)

Noise schedule for the discrete type diffusion model (Q matrix).

Implements an exponential noise schedule parameterised by beta_min and beta_max, following the score-entropy discrete diffusion formulation.

beta_min
beta_max
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this noise schedule.

Returns:

Hyperparameter dictionary with _target_, beta_min, and beta_max.

Return type:

dict
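The `_target_` key follows the convention used by Hydra's `hydra.utils.instantiate`. Here is a minimal sketch of how such a dictionary can be turned back into an object using only the standard library. It assumes that `_target_` holds a fully qualified class path and that every other key is a constructor keyword argument. `instantiate_from_hparams` is a hypothetical helper, not part of agedi.

```python
import importlib
from typing import Any, Dict


def instantiate_from_hparams(hparams: Dict[str, Any]) -> Any:
    """Hypothetical helper: rebuild an object from a `_target_`-style dict."""
    # Split "package.module.ClassName" into module path and class name
    module_name, _, cls_name = hparams["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    # All remaining keys are assumed to be constructor keyword arguments
    kwargs = {k: v for k, v in hparams.items() if k != "_target_"}
    return cls(**kwargs)


# Usage with a standard-library class standing in for NoiseSchedule
delta = instantiate_from_hparams({"_target_": "datetime.timedelta", "days": 2, "hours": 3})
print(delta)  # 2 days, 3:00:00
```

The same call would apply to the dictionary returned by `NoiseSchedule.get_hparams()`, provided its `_target_` points at an importable class.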

_beta_t(time: torch.Tensor) → torch.Tensor

Beta function for the type noiser's rate matrix Q.

Parameters:

time (torch.Tensor) – Diffusion time

Returns:

The beta value for the given time

Return type:

torch.Tensor

rate_noise(time: torch.Tensor) → torch.Tensor

The rate of change of the noise, i.e. g(t).

Parameters:

time (torch.Tensor) – The diffusion time

Returns:

The rate of change of the noise

Return type:

torch.Tensor

total_noise(time: torch.Tensor) → torch.Tensor

Total noise at time t.

Given by the integral of the noise rate plus its value at t = 0, i.e. $\sigma(t) = \int_0^t g(s)\,\mathrm{d}s + g(0)$.

Parameters:

time (torch.Tensor) – The diffusion time

Returns:

The total noise at time t

Return type:

torch.Tensor
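The docstrings above do not give the exact functional form of the schedule. The following minimal sketch fills that gap with two assumptions. First, the noise rate is g(t) = β(t). Second, β follows the geometric interpolation β(t) = β_min^(1−t) · β_max^t, which is one common reading of an "exponential" schedule. The total noise then follows the documented relation σ(t) = ∫_0^t g(s) ds + g(0). `ExponentialScheduleSketch` is a hypothetical name, not the agedi class.

```python
import math

import torch


class ExponentialScheduleSketch:
    """Hypothetical stand-in for NoiseSchedule; the functional form is assumed."""

    def __init__(self, beta_min: float = 0.01, beta_max: float = 3.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        self._log_ratio = math.log(beta_max / beta_min)

    def _beta_t(self, time: torch.Tensor) -> torch.Tensor:
        # Assumed geometric interpolation: beta_min at t = 0, beta_max at t = 1
        return self.beta_min ** (1 - time) * self.beta_max ** time

    def rate_noise(self, time: torch.Tensor) -> torch.Tensor:
        # Assumption: the noise rate g(t) is beta(t) itself
        return self._beta_t(time)

    def total_noise(self, time: torch.Tensor) -> torch.Tensor:
        # Closed form of int_0^t beta(s) ds, since d/dt beta = beta * ln(beta_max / beta_min)
        integral = (self._beta_t(time) - self.beta_min) / self._log_ratio
        # Plus g(0) = beta_min, as in the total_noise docstring
        return integral + self.beta_min


schedule = ExponentialScheduleSketch()
t = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
print(schedule.total_noise(t))
```

Under these assumptions, σ(0) = β_min, and σ increases monotonically in t because its derivative g(t) is always positive.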

class agedi.diffusion.noisers.types.Transition

Placeholder class for transition matrix representations.

class agedi.diffusion.noisers.types.Types(prior=Constant(0), distribution=Categorical(), noise_schedule: NoiseSchedule = NoiseSchedule(0.01, 3.0), sampling_mask: torch.Tensor | None = None, n_classes: int = 100, type_map: List[int] | None = None, **kwargs)

Bases: agedi.diffusion.noisers.Noiser

Type noiser.

Based on the score entropy discrete diffusion (SEDD) model. See https://arxiv.org/abs/2310.16834 for more details.

Uses an absorbing state (index 0) as the first state in the transition matrix.
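To illustrate, assume the SEDD convention that probabilities evolve as dp/dt = Q p, so each column of Q sums to zero. For four classes with state 0 as the absorbing state, the rate matrix and its exponential at total noise σ would then be:

$$
Q = \begin{pmatrix}
0 & 1 & 1 & 1 \\
0 & -1 & 0 & 0 \\
0 & 0 & -1 & 0 \\
0 & 0 & 0 & -1
\end{pmatrix},
\qquad
e^{\sigma Q} = \begin{pmatrix}
1 & 1 - e^{-\sigma} & 1 - e^{-\sigma} & 1 - e^{-\sigma} \\
0 & e^{-\sigma} & 0 & 0 \\
0 & 0 & e^{-\sigma} & 0 \\
0 & 0 & 0 & e^{-\sigma}
\end{pmatrix}
$$

Under this convention, a non-absorbed type survives total noise σ unchanged with probability e^{-σ} and is otherwise absorbed into state 0. Once there, it never leaves.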

_key = 'x'
noise_schedule
sampling_mask = None
n_classes = 100
get_hparams() → Dict

Return hyperparameters for this types noiser.

_noise(batch: agedi.data.AtomsGraph) → agedi.data.AtomsGraph

Noises the attribute of the atomistic structure.

Performs the noising of the atomic types.

Parameters:

batch (AtomsGraph) – The atomistic structure (or batch hereof) to be noised.

Returns:

The noised atomistic structure (or batch hereof).

Return type:

AtomsGraph

_denoise(batch: agedi.data.AtomsGraph, delta_t: float, last: bool) → agedi.data.AtomsGraph

Denoises the attribute of the atomistic structure.

Denoises the atomic types.

Parameters:
  • batch (AtomsGraph) – The atomistic structure (or batch hereof) to be denoised.

  • delta_t (float) – The time step to be used for the denoising.

  • last (bool) – If the last denoising step is performed.

Returns:

The denoised atomistic structure (or batch hereof).

Return type:

AtomsGraph

_loss(batch: agedi.data.AtomsGraph) → torch.Tensor

Computes the training loss.

Because the model is trained with score entropy, its output is the log of the concrete score, score = log(s). For sampling, it must therefore be converted back to a concrete score, s = exp(score).

Parameters:

batch (AtomsGraph) – The atomistic structure (or batch hereof) to be noised and denoised.

Returns:

The loss of the noised and denoised atomistic structure.

Return type:

torch.Tensor

sample_transition(x: torch.Tensor, sigma: torch.Tensor) → torch.Tensor

Sample the transition vector for the types.

This corresponds to noising the types in the discrete diffusion model.

Parameters:
  • x (torch.Tensor) – The current types

  • sigma (torch.Tensor) – The total noise

Returns:

The noised types

Return type:

torch.Tensor
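Here is a minimal sketch of the forward (noising) step under the absorbing-state convention described for this class. Each type independently jumps to the absorbing state 0 with probability 1 − e^{−σ}, which is the off-diagonal entry of e^{σQ} for an absorbing rate matrix. `sample_transition_sketch` is a hypothetical function, and the actual method may differ in details such as masking or device handling.

```python
import torch


def sample_transition_sketch(x: torch.Tensor, sigma: torch.Tensor, generator=None) -> torch.Tensor:
    """Hypothetical sketch: absorb each type into state 0 with probability 1 - e^{-sigma}."""
    # -expm1(-sigma) = 1 - e^{-sigma}, computed accurately for small sigma
    absorb_prob = -torch.expm1(-sigma)
    # One uniform draw per type; each type is absorbed independently
    u = torch.rand(x.shape, generator=generator, device=x.device)
    return torch.where(u < absorb_prob, torch.zeros_like(x), x)


g = torch.Generator().manual_seed(0)
x0 = torch.randint(1, 10, (8,), generator=g)
xt = sample_transition_sketch(x0, torch.tensor(1.0), generator=g)
print(x0, xt)
```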

score_entropy(score: torch.Tensor, sigma: torch.Tensor, x: torch.Tensor, x0: torch.Tensor) → torch.Tensor

Computes the score entropy loss.

Parameters:
  • score (torch.Tensor) – The score in log space, i.e. log(s)

  • sigma (torch.Tensor) – The total noise

  • x (torch.Tensor) – The noised types

  • x0 (torch.Tensor) – The original types

Returns:

The score entropy loss

Return type:

torch.Tensor
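Here is a minimal sketch of the absorbing-state score entropy loss from the SEDD paper, adapted so that index 0 is the absorbing state. It makes three assumptions. First, `score` is the log-score, as stated for `_loss`. Second, `sigma` holds one total-noise value per type. Third, only positions that have been absorbed contribute to the loss. `score_entropy_sketch` is a hypothetical name.

```python
import torch


def score_entropy_sketch(log_score: torch.Tensor, sigma: torch.Tensor,
                         x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of absorbing-state score entropy (absorbing index 0).

    log_score: (N, n_classes) predicted log concrete scores
    sigma:     (N,) total noise per type
    x, x0:     (N,) noised and original types
    """
    loss = torch.zeros(x.shape, dtype=log_score.dtype, device=log_score.device)
    absorbed = x == 0  # only absorbed positions carry a non-zero loss
    # True ratio p_t(x0) / p_t(0) = e^{-sigma} / (1 - e^{-sigma}) = 1 / (e^{sigma} - 1)
    ratio = 1.0 / torch.expm1(sigma[absorbed])
    ls = log_score[absorbed]
    # Negative term: ratio * log s_{x0}
    neg = ratio * ls.gather(-1, x0[absorbed][:, None]).squeeze(-1)
    # Positive term: sum of concrete scores over all non-absorbing classes
    pos = ls[:, 1:].exp().sum(dim=-1)
    # Constant ratio * (log ratio - 1) makes the minimum exactly zero
    const = ratio * (ratio.log() - 1.0)
    loss[absorbed] = pos - neg + const
    return loss
```

With r = 1/(e^σ − 1), the loss is zero when exp(score) equals r at x0 and vanishes elsewhere. For any other score it is positive at absorbed positions.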

transp_rate(x: torch.Tensor) → torch.Tensor

Compute the i-th row of the rate transition matrix Q, where i is the current type x.

Can be used to compute the reverse rate.

Parameters:

x (torch.Tensor) – The types

Returns:

The i-th row of the rate transition matrix Q.

Return type:

torch.Tensor
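For an absorbing rate matrix with state 0 as the absorbing state, the row of Q selected by each type has a simple closed form. The minimal sketch below assumes the SEDD convention that the columns of Q sum to zero. `transp_rate_sketch` and its `n_classes` argument are hypothetical.

```python
import torch
import torch.nn.functional as F


def transp_rate_sketch(x: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical sketch: row x of the absorbing rate matrix Q (absorbing state 0)."""
    # Non-absorbed type j: row j of Q is -1 on the diagonal and 0 elsewhere
    edge = -F.one_hot(x, num_classes=n_classes).to(torch.float32)
    # Absorbed type 0: row 0 of Q receives +1 from every other state
    edge[x == 0] += 1.0
    return edge


# Stacking the rows for every state recovers the full matrix Q
Q = transp_rate_sketch(torch.arange(4), n_classes=4)
print(Q)
print(Q.sum(dim=0))  # each column sums to zero
```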

reverse_rate(x: torch.Tensor, score: torch.Tensor) → torch.Tensor

Constructs the reverse rate.

The reverse rate is given by score * transp_rate(x).

Parameters:
  • x (torch.Tensor) – The types

  • score (torch.Tensor) – The score

Returns:

The reverse rate

Return type:

torch.Tensor
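Here is a minimal sketch of the reverse rate for an absorbing state at index 0. It assumes `score` is the concrete score (already exponentiated). It also assumes, as is standard for rate matrices, that the diagonal entry is reset so that each row sums to zero. That normalisation step is an assumption, not something stated in the docstring. `reverse_rate_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def reverse_rate_sketch(x: torch.Tensor, score: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the reverse rate for an absorbing state at index 0."""
    n_classes = score.shape[-1]
    # Row x of the absorbing rate matrix Q (see transp_rate)
    transp = -F.one_hot(x, num_classes=n_classes).to(score.dtype)
    transp[x == 0] += 1.0
    # Off-diagonal reverse rates: concrete score times the forward rate row
    rate = transp * score
    # Assumed normalisation: reset the diagonal so that every row sums to zero
    rate.scatter_(-1, x[:, None], 0.0)
    rate.scatter_(-1, x[:, None], -rate.sum(dim=-1, keepdim=True))
    return rate
```

In this sketch, only absorbed positions (x = 0) get non-zero rates, so already-unmasked types stay fixed during the reverse process.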

sample_rate(callable: Callable, x: torch.Tensor, rate: torch.Tensor) → torch.Tensor

Sample the rate.

Parameters:
  • callable (Callable) – Callable function defining the categorical distribution

  • x (torch.Tensor) – The types

  • rate (torch.Tensor) – The rate

Returns:

The sampled rate

Return type:

torch.Tensor
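The docstring does not specify how the categorical distribution is formed. The minimal sketch below makes three assumptions. First, `rate` has already been multiplied by the time step. Second, `callable` maps unnormalised probabilities to sampled indices. Third, the one-step transition probabilities are one_hot(x) + rate, as in a first-order (Euler) step. `gumbel_categorical` and `sample_rate_sketch` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def gumbel_categorical(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical sampler: index k drawn with probability probs[k] / probs.sum()."""
    # argmax_k probs_k / E_k with E_k ~ Exp(1) yields a categorical sample
    exp_noise = -torch.log(torch.rand_like(probs) + 1e-10) + 1e-10
    return (probs / exp_noise).argmax(dim=-1)


def sample_rate_sketch(sampler, x: torch.Tensor, rate: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: one Euler step x -> sampler(one_hot(x) + rate)."""
    # rate is assumed to be pre-multiplied by the time step delta_t
    probs = F.one_hot(x, num_classes=rate.shape[-1]).to(rate) + rate
    # Clamp small negative values that a too-large time step can produce
    return sampler(probs.clamp_min(0.0))


print(sample_rate_sketch(gumbel_categorical, torch.tensor([0, 3]), torch.zeros(2, 5)))
```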

staggered_score(score: torch.Tensor, dsigma: torch.Tensor) → torch.Tensor

Computes the staggered score.

Computes $p_{\sigma - d\sigma}(z) / p_{\sigma}(x)$, which is approximated by $e^{-d\sigma\, E}\,\mathrm{score}$.

Parameters:
  • score (torch.Tensor) – The score

  • dsigma (torch.Tensor) – The rate noise

Returns:

The staggered score

Return type:

torch.Tensor
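Assume E is the absorbing rate matrix Q with state 0 absorbing (this reading is an assumption). Then e^{−dσ Q} applied to a score vector has a closed form: the non-absorbing entries are scaled by e^{dσ}, and the absorbing entry takes up the difference, so the total is preserved. Here is a minimal sketch under that reading, checked against a dense `torch.linalg.matrix_exp`. `staggered_score_sketch` is a hypothetical name.

```python
import torch


def staggered_score_sketch(score: torch.Tensor, dsigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: exp(-dsigma * Q) applied to each score row (absorbing state 0)."""
    score = score.clone()  # avoid modifying the caller's tensor
    growth = dsigma.exp()[:, None]  # e^{dsigma}, one value per row
    # Absorbing entry changes by (1 - e^{dsigma}) * sum of non-absorbing entries
    extra = (1.0 - growth.squeeze(-1)) * score[:, 1:].sum(dim=-1)
    score[:, 1:] *= growth
    score[:, 0] += extra
    return score


# Cross-check against the dense matrix exponential for a 4-state absorbing Q
Q = torch.tensor([[0., 1., 1., 1.], [0., -1., 0., 0.], [0., 0., -1., 0.], [0., 0., 0., -1.]],
                 dtype=torch.float64)
score = torch.rand(2, 4, dtype=torch.float64)
dsigma = torch.tensor([0.3, 0.3], dtype=torch.float64)
dense = score @ torch.linalg.matrix_exp(-0.3 * Q).T
print(torch.allclose(staggered_score_sketch(score, dsigma), dense))
```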

transp_transition(x: torch.Tensor, sigma: torch.Tensor) → torch.Tensor

Compute the transition matrix for the types.

Parameters:
  • x (torch.Tensor) – The types

  • sigma (torch.Tensor) – The total noise

Returns:

The transition matrix

Return type:

torch.Tensor
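By analogy with transp_rate, this method is assumed to return the row of the transition matrix e^{σQ} selected by each type, where Q is the absorbing rate matrix with state 0 absorbing. Here is a minimal sketch under that assumption, cross-checked against `torch.linalg.matrix_exp`. `transp_transition_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def transp_transition_sketch(x: torch.Tensor, sigma: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical sketch: row x of exp(sigma * Q) for an absorbing state at index 0."""
    sigma = sigma.reshape(-1, 1)
    keep = torch.exp(-sigma)  # probability that a non-absorbed type is unchanged
    edge = keep * F.one_hot(x, num_classes=n_classes).to(sigma.dtype)
    # Row 0 additionally collects 1 - e^{-sigma} from every state (mass flowing into 0)
    absorbed = (x == 0).to(sigma.dtype)[:, None]
    edge = edge + absorbed * (1.0 - keep)
    return edge


Q = torch.tensor([[0., 1., 1., 1.], [0., -1., 0., 0.], [0., 0., -1., 0.], [0., 0., 0., -1.]],
                 dtype=torch.float64)
rows = transp_transition_sketch(torch.arange(4), torch.full((4,), 0.4, dtype=torch.float64), n_classes=4)
print(torch.allclose(rows, torch.linalg.matrix_exp(0.4 * Q)))
```

Each column of the resulting matrix sums to one, as expected for transition probabilities under the column convention.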

agedi.diffusion.noisers.types.TypesNoiser