agedi.diffusion.sdes.base

Classes

SDE

SDE base class

Module Contents

class agedi.diffusion.sdes.base.SDE(noise_schedule: agedi.diffusion.sdes.noise_schedules.NoiseSchedule = Linear)

Bases: abc.ABC

SDE base class

noise_schedule_cls
get_hparams() → Dict

Return hyperparameters sufficient to reconstruct this SDE.

Returns a dictionary with a _target_ key (the fully-qualified class name). Subclasses should call super().get_hparams() and merge in their own constructor parameters.

Returns:

Hyperparameter dictionary.

Return type:

dict
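Here is a minimal sketch of the get_hparams pattern described above. It uses hypothetical stand-in classes (`BaseSketch`, `LinearSketch`) and hypothetical constructor parameters (`beta_min`, `beta_max`), not the real agedi classes. The base class builds `_target_` from the class's module and qualified name. The subclass calls `super().get_hparams()` and merges in its own arguments, which is enough to rebuild an equivalent object.

```python
from typing import Any, Dict


class BaseSketch:
    """Hypothetical stand-in for SDE, showing only the get_hparams pattern."""

    def get_hparams(self) -> Dict[str, Any]:
        # Fully-qualified class name, e.g. "__main__.LinearSketch".
        cls = type(self)
        return {"_target_": f"{cls.__module__}.{cls.__qualname__}"}


class LinearSketch(BaseSketch):
    """Hypothetical subclass with its own constructor parameters."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_hparams(self) -> Dict[str, Any]:
        # Start from the parent's dict (which carries _target_) and
        # merge in this class's constructor arguments.
        hparams = super().get_hparams()
        hparams.update(beta_min=self.beta_min, beta_max=self.beta_max)
        return hparams


hp = LinearSketch(beta_min=0.5).get_hparams()

# Rebuild: everything except _target_ is a constructor keyword argument.
kwargs = {k: v for k, v in hp.items() if k != "_target_"}
rebuilt = LinearSketch(**kwargs)
```

Storing `_target_` next to the constructor arguments follows the convention used by Hydra's `hydra.utils.instantiate`. This section does not say whether agedi reconstructs SDEs through Hydra.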

abstractmethod drift(x: torch.Tensor, t: torch.Tensor) → torch.Tensor

Drift term of the SDE.

Must be implemented by subclasses.

Defines the drift term of the SDE:

\[ f(\mathbf{x}, t) = \dots \]

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the drift term.

Returns:

drift – The drift term of the SDE.

Return type:

torch.Tensor
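The signatures drift(x, t) and diffusion(t) match the usual forward SDE of score-based diffusion, \(d\mathbf{x} = f(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w}\). As an illustration only, the sketch below implements the drift of a variance-preserving (VP) SDE, \(f(\mathbf{x}, t) = -\tfrac{1}{2}\beta(t)\mathbf{x}\). It assumes a hypothetical linear schedule \(\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})\) with made-up defaults, so agedi's actual subclasses and its `Linear` noise schedule may differ. For brevity, t is a Python float. The same arithmetic works when x is a torch.Tensor.

```python
def beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Hypothetical linear noise schedule beta(t) for t in [0, 1]."""
    return beta_min + t * (beta_max - beta_min)


def vp_drift(x, t: float):
    """Drift of a VP SDE: f(x, t) = -0.5 * beta(t) * x.

    x may be a float or a torch.Tensor; t is a float here for simplicity.
    """
    return -0.5 * beta(t) * x


f0 = vp_drift(2.0, 0.0)  # -0.5 * 0.1 * 2.0
f1 = vp_drift(1.0, 1.0)  # -0.5 * 20.0 * 1.0
```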

abstractmethod diffusion(t: torch.Tensor) → torch.Tensor

Diffusion term of the SDE.

Must be implemented by subclasses.

Defines the diffusion term of the SDE:

\[ g(t) = \dots \]

Parameters:

t (torch.Tensor) – The time at which to calculate the diffusion term.

Returns:

diffusion – The diffusion term of the SDE.

Return type:

torch.Tensor
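A numerical integrator consumes the drift and diffusion terms together. The following is a minimal sketch of forward Euler-Maruyama integration, \(\mathbf{x} \leftarrow \mathbf{x} + f(\mathbf{x}, t)\,\Delta t + g(t)\sqrt{\Delta t}\,\mathbf{z}\) with \(\mathbf{z} \sim \mathcal{N}(0, 1)\). It assumes a hypothetical VP SDE with a linear \(\beta(t)\), for which \(g(t) = \sqrt{\beta(t)}\). The schedule, its defaults, and the scalar state are illustrative assumptions.

```python
import math
import random


def beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Hypothetical linear noise schedule beta(t) for t in [0, 1]."""
    return beta_min + t * (beta_max - beta_min)


def vp_diffusion(t: float) -> float:
    """Diffusion of a VP SDE: g(t) = sqrt(beta(t))."""
    return math.sqrt(beta(t))


def euler_maruyama_step(x: float, t: float, dt: float, rng: random.Random) -> float:
    """One forward Euler-Maruyama step of dx = f(x, t) dt + g(t) dw."""
    drift = -0.5 * beta(t) * x  # f(x, t) for the VP SDE
    z = rng.gauss(0.0, 1.0)  # standard normal sample
    return x + drift * dt + vp_diffusion(t) * math.sqrt(dt) * z


rng = random.Random(0)
x, t, n_steps = 1.0, 0.0, 1000
dt = 1.0 / n_steps
for _ in range(n_steps):
    x = euler_maruyama_step(x, t, dt, rng)
    t += dt
```

Euler-Maruyama is the simplest such scheme. It appears here only to show how the two terms combine, not as the sampler agedi uses.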

abstractmethod mean(t: torch.Tensor) → torch.Tensor

Mean of the SDE.

Must be implemented by subclasses.

Calculates the mean of the transition kernel at time t:

\[ \mu_t = \dots \]

Parameters:

t (torch.Tensor) – The time at which to calculate the mean.

Returns:

mean – The mean of the diffusion process.

Return type:

torch.Tensor
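The transition kernel multiplies \(\mu_t\) by \(\mathbf{x}_0\), and mean takes only t. So mean(t) appears to return a scalar coefficient rather than a mean vector. Take, as an illustration only, a hypothetical VP SDE with a linear \(\beta(t)\). Its coefficient has the closed form \(\mu_t = \exp\left(-\tfrac{1}{2}\int_0^t \beta(s)\,ds\right)\), where \(\int_0^t \beta(s)\,ds = \beta_{\min} t + \tfrac{1}{2} t^2 (\beta_{\max} - \beta_{\min})\). The defaults below are made up.

```python
import math


def integrated_beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Closed-form integral of a hypothetical linear beta(s) from 0 to t."""
    return beta_min * t + 0.5 * t**2 * (beta_max - beta_min)


def vp_mean(t: float) -> float:
    """Mean coefficient mu_t = exp(-0.5 * int_0^t beta(s) ds) of a VP SDE."""
    return math.exp(-0.5 * integrated_beta(t))


mu_start = vp_mean(0.0)  # no signal lost yet
mu_end = vp_mean(1.0)    # almost all signal gone
```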

abstractmethod var(t: torch.Tensor) → torch.Tensor

Variance of the SDE.

Must be implemented by subclasses.

Calculates the variance of the transition kernel at time t:

\[ \sigma_t^2 = \dots \]

Parameters:

t (torch.Tensor) – The time at which to calculate the variance.

Returns:

var – The variance of the diffusion process.

Return type:

torch.Tensor
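Consider a hypothetical VP SDE with a linear \(\beta(t)\), used here only as an illustration. Its variance is \(\sigma_t^2 = 1 - \exp\left(-\int_0^t \beta(s)\,ds\right) = 1 - \mu_t^2\). This identity is what keeps unit-variance data at unit variance for all t. The sketch uses `math.expm1` to stay accurate for small t. The names and defaults are assumptions, not agedi's.

```python
import math


def integrated_beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Closed-form integral of a hypothetical linear beta(s) from 0 to t."""
    return beta_min * t + 0.5 * t**2 * (beta_max - beta_min)


def vp_var(t: float) -> float:
    """sigma_t^2 = 1 - exp(-int_0^t beta(s) ds) for a hypothetical VP SDE."""
    # -expm1(-a) equals 1 - exp(-a) but keeps precision when a is tiny (t near 0).
    return -math.expm1(-integrated_beta(t))


def vp_std(t: float) -> float:
    """sigma_t, the factor that scales the noise in the transition kernel."""
    return math.sqrt(vp_var(t))


# For a VP SDE, mu_t^2 + sigma_t^2 == 1 at every t.
checks = [math.exp(-integrated_beta(t)) + vp_var(t) for t in (0.0, 0.01, 0.5, 1.0)]
```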

transition_kernel(x: torch.Tensor, t: torch.Tensor, w: Callable) → torch.Tensor

Transition kernel of the SDE.

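Calculates the transition kernel of the diffusion process, i.e. draws \(\mathbf{x}_t\) from \(p(\mathbf{x}_t \mid \mathbf{x}_0)\), where \(\mathbf{x}_0\) is the input x:

\[ \mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w}, \]

with \(\mathbf{w} \sim \mathcal{N}(0, 1)\).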

Parameters:
  • x (torch.Tensor) – The positions of the atoms.

  • t (torch.Tensor) – The time at which to calculate the transition kernel.

  • w (torch.Tensor) – The noise term.

Returns:

transition_kernel – The transition kernel of the diffusion process.

Return type:

torch.Tensor
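The following is a minimal sketch of the reparameterized sample \(\mathbf{x}_t = \mu_t \mathbf{x}_0 + \sigma_t \mathbf{w}\). It assumes a hypothetical VP SDE with a linear \(\beta(t)\), so that \(\sigma_t^2 = 1 - \mu_t^2\). Following the parameter description, w is passed in as a precomputed standard-normal sample. The `Callable` annotation in the signature suggests the real method may accept a noise sampler instead. Positions are a plain list of floats for brevity.

```python
import math
import random


def vp_mean(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """mu_t for a hypothetical VP SDE with linear beta(t)."""
    return math.exp(-0.5 * (beta_min * t + 0.5 * t**2 * (beta_max - beta_min)))


def transition_kernel(x0, t, w):
    """Sample x_t = mu_t * x0 + sigma_t * w, with sigma_t^2 = 1 - mu_t^2 (VP case)."""
    mu = vp_mean(t)
    sigma = math.sqrt(1.0 - mu**2)
    return [mu * xi + sigma * wi for xi, wi in zip(x0, w)]


rng = random.Random(0)
x0 = [0.5, -1.2, 2.0]  # hypothetical coordinates of one atom
w = [rng.gauss(0.0, 1.0) for _ in x0]  # w ~ N(0, 1), one draw per coordinate
xt = transition_kernel(x0, 0.3, w)
```

At t = 0 the sample equals x0. Near t = 1 it is almost pure noise.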

noise(x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) → torch.Tensor

Noise term of the SDE.

Calculates the noise term of the SDE:

\[ \mathbf{w} = \frac{\mathbf{x}_t - \mu_t \mathbf{x}_0}{\sigma_t} \]

Parameters:
  • x0 (torch.Tensor) – x at time 0.

  • xt (torch.Tensor) – x at time t.

  • t (torch.Tensor) – The time at which to calculate the noise term.

Returns:

noise – The noise term of the diffusion process.

Return type:

torch.Tensor
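The noise formula inverts the transition kernel: given \(\mathbf{x}_0\), \(\mathbf{x}_t\) and t, it recovers \(\mathbf{w}\). The round-trip sketch below checks this under hypothetical VP coefficients with a linear \(\beta(t)\), which are illustrative assumptions. The division is undefined at t = 0, where \(\sigma_t = 0\).

```python
import math


def vp_mean_std(t: float, beta_min: float = 0.1, beta_max: float = 20.0):
    """(mu_t, sigma_t) for a hypothetical VP SDE with linear beta(t)."""
    integral = beta_min * t + 0.5 * t**2 * (beta_max - beta_min)
    mu = math.exp(-0.5 * integral)
    sigma = math.sqrt(-math.expm1(-integral))  # sqrt(1 - exp(-integral))
    return mu, sigma


def noise(x0: float, xt: float, t: float) -> float:
    """Recover w = (x_t - mu_t * x_0) / sigma_t; undefined at t = 0."""
    mu, sigma = vp_mean_std(t)
    return (xt - mu * x0) / sigma


x0, w, t = 1.5, -0.7, 0.4
mu, sigma = vp_mean_std(t)
xt = mu * x0 + sigma * w  # forward: transition kernel
w_rec = noise(x0, xt, t)  # inverse: should give back w
```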