agedi.diffusion.sdes

Submodules

Classes

SDE

SDE base class

VE

Implements variance-exploding (VE) SDE.

VP

Implements variance-preserving (VP) SDE.

Package Contents

class agedi.diffusion.sdes.SDE(noise_schedule: agedi.diffusion.sdes.noise_schedules.NoiseSchedule = Linear)

Bases: abc.ABC

SDE base class

noise_schedule_cls
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this SDE.

Returns a dictionary with a _target_ key (the fully-qualified class name). Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
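The following is a minimal sketch of the `get_hparams` pattern described above. It uses a hypothetical `BaseSDE`/`ToyVP` pair rather than agedi's own classes. The `_target_` key follows the convention used by Hydra's `instantiate`. The small `instantiate` helper here is a hypothetical stand-in: it imports the class named by `_target_` and passes the remaining keys as constructor arguments.

```python
import importlib
from typing import Any, Dict


class BaseSDE:
    """Hypothetical stand-in for the SDE base class."""

    def get_hparams(self) -> Dict[str, Any]:
        # Record the fully-qualified class name so the object can be rebuilt.
        cls = type(self)
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}"}


class ToyVP(BaseSDE):
    """Hypothetical subclass with two constructor parameters."""

    def __init__(self, beta_min: float = 0.01, beta_max: float = 3.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the parent's dict, then merge in our own constructor arguments.
        hparams = super().get_hparams()
        hparams.update(beta_min=self.beta_min, beta_max=self.beta_max)
        return hparams


def instantiate(hparams: Dict[str, Any]) -> Any:
    """Hypothetical helper: import the `_target_` class and call it with the other keys."""
    kwargs = dict(hparams)
    module_name, _, cls_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**kwargs)
```

For a class defined in an importable module, `instantiate(obj.get_hparams())` should rebuild an equivalent object. For example, `instantiate({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})` returns `Fraction(3, 4)`.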

abstractmethod drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor

Drift term of the SDE.

Must be implemented by subclass.

Defines the drift term \(f(\mathbf{x}, t)\) of the SDE.

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor

abstractmethod diffusion(t: torch.Tensor) → torch.Tensor

Diffusion term of the SDE.

Must be implemented by subclass.

Defines the diffusion term \(g(t)\) of the SDE.

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
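Together, `drift` and `diffusion` define the forward SDE \(d\mathbf{x} = f(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{W}\). The sketch below simulates that SDE with a plain Euler–Maruyama loop over Python floats. It is a generic illustration under that assumption, not agedi's sampler, and `euler_maruyama` is a hypothetical name.

```python
import math
import random
from typing import Callable, Optional


def euler_maruyama(
    x0: float,
    f: Callable[[float, float], float],  # drift f(x, t)
    g: Callable[[float], float],  # diffusion g(t)
    t_end: float = 1.0,
    n_steps: int = 1000,
    rng: Optional[random.Random] = None,
) -> float:
    """Simulate dx = f(x, t) dt + g(t) dW from t = 0 to t_end."""
    rng = rng if rng is not None else random.Random(0)
    dt = t_end / n_steps
    x, t = x0, 0.0
    for _ in range(n_steps):
        # Brownian increment dW ~ N(0, dt).
        dw = rng.gauss(0.0, math.sqrt(dt))
        x = x + f(x, t) * dt + g(t) * dw
        t += dt
    return x


# With g = 0 the SDE reduces to the ODE dx/dt = -x, so x(1) is close to exp(-1).
x1 = euler_maruyama(1.0, lambda x, t: -x, lambda t: 0.0)
```

Pass a shared `random.Random` when drawing many paths. Otherwise each call reseeds with 0 and returns the same path.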

abstractmethod mean(t: torch.Tensor) → torch.Tensor

Mean of the SDE.

Must be implemented by subclass.

Calculates the mean \(\mu_t\) of the transition kernel at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor

abstractmethod var(t: torch.Tensor) → torch.Tensor

Variance of the SDE.

Must be implemented by subclass.

Calculates the variance \(\sigma_t^2\) of the transition kernel at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor
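The four abstract methods form the contract that a concrete SDE must satisfy. As a hypothetical illustration, scaled Brownian motion \(d\mathbf{x} = s\,d\mathbf{W}\) implements them in closed form: zero drift, constant diffusion \(s\), mean factor \(\mu_t = 1\), and variance \(\sigma_t^2 = s^2 t\). The sketch uses plain floats instead of `torch.Tensor` and a local `SDEInterface` rather than the real base class.

```python
import abc


class SDEInterface(abc.ABC):
    """Hypothetical mirror of the four abstract methods documented above."""

    @abc.abstractmethod
    def drift(self, x: float, t: float) -> float: ...

    @abc.abstractmethod
    def diffusion(self, t: float) -> float: ...

    @abc.abstractmethod
    def mean(self, t: float) -> float: ...

    @abc.abstractmethod
    def var(self, t: float) -> float: ...


class ScaledBrownianMotion(SDEInterface):
    """dx = s dW: zero drift and constant diffusion s."""

    def __init__(self, s: float = 1.0) -> None:
        self.s = s

    def drift(self, x: float, t: float) -> float:
        return 0.0

    def diffusion(self, t: float) -> float:
        return self.s

    def mean(self, t: float) -> float:
        # x_t = x_0 + s * W_t, so the factor multiplying x_0 stays 1.
        return 1.0

    def var(self, t: float) -> float:
        # Integral of g(u)^2 = s^2 over [0, t].
        return self.s ** 2 * t
```

Instantiating `SDEInterface` directly raises `TypeError` because its abstract methods are unimplemented. A subclass becomes instantiable only once it overrides all four.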

transition_kernel(x: torch.Tensor, t: torch.Tensor, w: Callable) → torch.Tensor

Transition kernel of the SDE.

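Calculates a sample from the transition kernel \(p(\mathbf{x}_t \mid \mathbf{x}_0)\) of the diffusion process:

\(\mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w}\),

with \(\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\), where \(\mu_t\) is the mean and \(\sigma_t^2\) the variance of the transition kernel at time t.

Parameters:
  • x (torch.Tensor) – The positions of the atoms at time 0, \(\mathbf{x}_0\).

  • t (torch.Tensor) – The time at which to calculate the transition kernel.

  • w (torch.Tensor) – The noise term.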

Returns:

transition_kernel – The sample \(\mathbf{x}_t\) drawn from the transition kernel of the diffusion process.

Return type:

torch.Tensor
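The following is a minimal sketch of the reparameterized sample above, with atomic positions flattened to a list of floats. Here `mu_t` and `var_t` stand in for the values `mean(t)` and `var(t)` would return. The function name and the optional seeded generator are assumptions made for illustration.

```python
import math
import random
from typing import List, Optional


def transition_kernel(
    x0: List[float],
    mu_t: float,
    var_t: float,
    w: Optional[List[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Return x_t = mu_t * x_0 + sigma_t * w, drawing w ~ N(0, I) if not given."""
    sigma_t = math.sqrt(var_t)
    if w is None:
        if rng is None:
            rng = random.Random(0)
        w = [rng.gauss(0.0, 1.0) for _ in x0]
    return [mu_t * xi + sigma_t * wi for xi, wi in zip(x0, w)]
```

Passing `w` explicitly makes the result deterministic, which is convenient when the same noise is needed later as a training target.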

noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) → torch.Tensor

Noise term of the SDE.

Calculates the noise term of the SDE:

\(\mathbf{w} = \frac{\mathbf{x}_t - \mu_t \mathbf{x}_0}{\sigma_t}\).

Parameters:
  • x0 (torch.Tensor) – x at time 0.

  • xt (torch.Tensor) – x at time t.

  • t (torch.Tensor) – The time at which to calculate the noise term.

Returns:

noise – The noise term of the diffusion process.

Return type:

torch.Tensor
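The noise formula is the exact inverse of the transition-kernel formula, so it can be checked with a round trip. The sketch below uses plain lists and floats. As before, `mu_t` and `var_t` stand in for `mean(t)` and `var(t)`, and the helper is hypothetical.

```python
import math
from typing import List


def noise(x0: List[float], xt: List[float], mu_t: float, var_t: float) -> List[float]:
    """Recover w from x_t = mu_t * x_0 + sigma_t * w (hypothetical helper)."""
    sigma_t = math.sqrt(var_t)
    return [(b - mu_t * a) / sigma_t for a, b in zip(x0, xt)]


# Round trip: x_0 = [1, -2], mu_t = 0.5, sigma_t = 0.5 and w = [2, 0] give x_t = [1.5, -1].
w = noise([1.0, -2.0], [1.5, -1.0], mu_t=0.5, var_t=0.25)
```

Because the formula divides by \(\sigma_t\), evaluating it at very small t amplifies any error in \(\mathbf{x}_t\).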

class agedi.diffusion.sdes.VE(sigma_min: float = 0.01, sigma_max: float = 1.0, **kwargs)

Bases: agedi.diffusion.sdes.SDE

Implements variance-exploding (VE) SDE.

Parameters:
  • sigma_min (float) – The minimum value of the sigma parameter.

  • sigma_max (float) – The maximum value of the sigma parameter.

Return type:

VE

sigma_min = 0.01
sigma_max = 1.0
noise_schedule
get_hparams() → Dict

Return hyperparameters for this VE SDE.

drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor

Implement VE drift term.

Defines the drift term of the SDE: f(x, t).

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor

diffusion(t: torch.Tensor) → torch.Tensor

Implement VE diffusion term.

Defines the diffusion term of the SDE: g(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
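This page does not spell out how `diffusion` relates to `sigma`. Assume zero drift and a transition-kernel variance of \(\sigma(t)^2\). The variance then grows at rate \(g(t)^2\), which gives \(g(t) = \sqrt{d[\sigma(t)^2]/dt}\). The sketch below uses a hypothetical linear interpolation between `sigma_min` and `sigma_max`. agedi's noise schedule may differ.

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 1.0  # defaults from the VE signature


def sigma(t: float) -> float:
    # Hypothetical linear schedule from SIGMA_MIN at t = 0 to SIGMA_MAX at t = 1.
    return SIGMA_MIN + t * (SIGMA_MAX - SIGMA_MIN)


def ve_diffusion(t: float) -> float:
    # g(t)^2 = d[sigma(t)^2]/dt = 2 * sigma(t) * sigma'(t); sigma'(t) is constant here.
    return math.sqrt(2.0 * sigma(t) * (SIGMA_MAX - SIGMA_MIN))
```

A central finite difference of \(\sigma(t)^2\) should match `ve_diffusion(t) ** 2`, which is a quick way to check any alternative schedule.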

mean(t: torch.Tensor) → torch.Tensor

Implement VE mean term.

Calculates the mean of the transition kernel at time t: \(\mu_t\).

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor

var(t: torch.Tensor) → torch.Tensor

Implement VE variance term.

Calculates the variance of the transition kernel at time t: \(\sigma_t^2\).

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor

sigma(t: torch.Tensor) → torch.Tensor

VE sigma function

Calculates the value of sigma at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate sigma.

Returns:

sigma – The value of sigma at time t.

Return type:

torch.Tensor
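How `sigma` interpolates between `sigma_min` and `sigma_max` depends on the configured noise schedule, which this page does not specify. The sketch compares two common choices: a linear ramp, and the geometric schedule used for VE SDEs in score-based generative modeling. Both function names are hypothetical. Which one (if either) matches agedi's default `Linear` schedule is an assumption to verify against the source.

```python
SIGMA_MIN, SIGMA_MAX = 0.01, 1.0  # defaults from the VE signature


def sigma_linear(t: float) -> float:
    # Straight line from SIGMA_MIN (t = 0) to SIGMA_MAX (t = 1).
    return SIGMA_MIN + t * (SIGMA_MAX - SIGMA_MIN)


def sigma_geometric(t: float) -> float:
    # Constant ratio per unit time: log(sigma) is linear in t.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


for t in (0.0, 0.5, 1.0):
    print(f"t={t:.1f}  linear={sigma_linear(t):.3f}  geometric={sigma_geometric(t):.3f}")
```

Both schedules agree at the endpoints. At t = 0.5, however, the linear ramp gives 0.505 while the geometric one gives 0.1, so the geometric schedule spends more of the time interval at small noise levels.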

class agedi.diffusion.sdes.VP(beta_min: float = 0.01, beta_max: float = 3, **kwargs)

Bases: agedi.diffusion.sdes.SDE

Implements variance-preserving (VP) SDE.

Parameters:
  • beta_min (float) – The minimum value of the beta parameter.

  • beta_max (float) – The maximum value of the beta parameter.

Return type:

VP

beta_min = 0.01
beta_max = 3
noise_schedule
get_hparams() → Dict

Return hyperparameters for this VP SDE.

drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor

Implement VP drift term.

Defines the drift term of the SDE: f(x, t).

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor

diffusion(t: torch.Tensor) → torch.Tensor

Implement VP diffusion term.

Defines the diffusion term of the SDE: g(t).

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
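This page does not give the drift and diffusion explicitly. Assume the standard VP SDE \(d\mathbf{x} = -\tfrac{1}{2}\beta(t)\mathbf{x}\,dt + \sqrt{\beta(t)}\,d\mathbf{W}\), with a hypothetical linear \(\beta(t)\) between `beta_min` and `beta_max`. The two terms then reduce to the following (plain floats instead of `torch.Tensor`).

```python
import math

BETA_MIN, BETA_MAX = 0.01, 3.0  # defaults from the VP signature


def beta(t: float) -> float:
    # Hypothetical linear schedule from BETA_MIN at t = 0 to BETA_MAX at t = 1.
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_drift(x: float, t: float) -> float:
    # f(x, t) = -1/2 * beta(t) * x shrinks samples toward the origin.
    return -0.5 * beta(t) * x


def vp_diffusion(t: float) -> float:
    # g(t) = sqrt(beta(t)).
    return math.sqrt(beta(t))
```

The shrinking drift and the injected noise are balanced, so the marginal variance of unit-variance data stays at 1.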

mean(t: torch.Tensor) → torch.Tensor

Implement VP mean term.

Calculates the mean of the transition kernel at time t: \(\mu_t\).

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor

var(t: torch.Tensor) → torch.Tensor

Implement VP variance term.

Calculates the variance of the transition kernel at time t: \(\sigma_t^2\).

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor
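Under the same assumptions (standard VP SDE, hypothetical linear \(\beta\)), the transition kernel has mean factor \(\mu_t = e^{-\alpha(t)/2}\) and variance \(\sigma_t^2 = 1 - e^{-\alpha(t)}\), where \(\alpha(t) = \int_0^t \beta(s)\,ds\). The check at the end shows why the process is called variance-preserving: \(\mu_t^2 + \sigma_t^2 = 1\), so data with unit variance keeps unit variance at every t.

```python
import math

BETA_MIN, BETA_MAX = 0.01, 3.0  # defaults from the VP signature


def alpha(t: float) -> float:
    # Closed-form integral of the hypothetical linear beta(s) over [0, t].
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def vp_mean(t: float) -> float:
    # Factor multiplying x_0 in the transition kernel.
    return math.exp(-0.5 * alpha(t))


def vp_var(t: float) -> float:
    # Variance of the Gaussian noise added to mu_t * x_0.
    return 1.0 - math.exp(-alpha(t))


for t in (0.0, 0.5, 1.0):
    # mu_t^2 + sigma_t^2 stays at 1 (up to floating-point rounding).
    print(t, vp_mean(t) ** 2 + vp_var(t))
```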

beta(t: torch.Tensor) → torch.Tensor

VP Beta function

Calculates the value of beta at time t.

Parameters:

t (torch.Tensor) – The time at which to calculate beta.

Returns:

beta – The value of beta at time t.

Return type:

torch.Tensor

alpha(t: torch.Tensor) → torch.Tensor

VP Alpha function

Calculates the value of alpha at time t, \(\alpha(t) = \int_0^t \beta(s)\,ds\).

Parameters:

t (torch.Tensor) – The time at which to calculate alpha.

Returns:

alpha – The value of alpha at time t.

Return type:

torch.Tensor
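Assume the same hypothetical linear \(\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})\). The integral then has the closed form \(\alpha(t) = \beta_{\min} t + \tfrac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\). The sketch checks this against trapezoidal quadrature. For a schedule without a known antiderivative, only the numerical route applies.

```python
BETA_MIN, BETA_MAX = 0.01, 3.0  # defaults from the VP signature


def beta(t: float) -> float:
    # Hypothetical linear schedule.
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def alpha_closed_form(t: float) -> float:
    # Antiderivative of the linear beta, evaluated from 0 to t.
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def alpha_trapezoid(t: float, n: int = 1000) -> float:
    # Trapezoidal rule for the integral of beta(s) over [0, t].
    h = t / n
    inner = sum(beta(i * h) for i in range(1, n))
    return h * (0.5 * beta(0.0) + inner + 0.5 * beta(t))
```

The trapezoidal rule is exact for a linear integrand, so the two functions agree up to rounding. For example, both give 1.505 at t = 1 with the default parameters.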