agedi.diffusion.distributions.uniform

Classes

Uniform

Uniform Distribution

UniformCell

Uniform Prior Distribution for cell parameters

UniformCellConfined

Uniform Prior Distribution for cell parameters with Z-directional confinement

Module Contents

class agedi.diffusion.distributions.uniform.Uniform(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: agedi.diffusion.distributions.Distribution

Uniform Distribution

Parameters:
  • low (float) – The lower bound of the distribution.

  • high (float) – The upper bound of the distribution.

  • key (str) – Name of the target attribute in the batch (default 'x').

low = 0.0
high = 1.0
get_hparams() → Dict

Return hyperparameters for this distribution.
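A minimal sketch of what such a hyperparameter dictionary might contain. It assumes, without confirmation from this reference, that `get_hparams()` returns exactly the constructor arguments `low`, `high` and `key`. `UniformSketch` is a hypothetical stand-in, not the agedi class. Under that assumption the dictionary can be fed back to the constructor to rebuild an equivalent distribution.

```python
from typing import Dict


class UniformSketch:
    """Hypothetical stand-in for Uniform; mirrors only the documented constructor."""

    def __init__(self, low: float = 0.0, high: float = 1.0, key: str = "x", **kwargs):
        self.low = low
        self.high = high
        self.key = key

    def get_hparams(self) -> Dict:
        # Assumption: the hyperparameters are exactly the constructor arguments.
        return {"low": self.low, "high": self.high, "key": self.key}


dist = UniformSketch(low=-1.0, high=1.0, key="pos")
clone = UniformSketch(**dist.get_hparams())  # rebuild with identical settings
print(clone.get_hparams())  # {'low': -1.0, 'high': 1.0, 'key': 'pos'}
```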

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution for sampling from batch.

Sets self.shape to the shape of the target attribute in the batch.

Parameters:

batch (AtomsGraph) – Batch of atomistic data.
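The shape bookkeeping described above can be sketched as follows. This assumes that `key` names the batch attribute whose shape the samples must match, and it uses a `types.SimpleNamespace` in place of a real `AtomsGraph`. `UniformSetupSketch` is hypothetical and is not agedi's implementation.

```python
from types import SimpleNamespace

import torch


class UniformSetupSketch:
    """Hypothetical sketch of the documented _setup behaviour."""

    def __init__(self, key: str = "x"):
        self.key = key
        self.shape = None

    def _setup(self, batch) -> None:
        # Assumption: `key` selects the target attribute; record its shape for sampling.
        self.shape = getattr(batch, self.key).shape


# Hypothetical batch: 5 atoms with 3D positions stored under "x".
batch = SimpleNamespace(x=torch.zeros(5, 3))
dist = UniformSetupSketch(key="x")
dist._setup(batch)
print(dist.shape)  # torch.Size([5, 3])
```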

_sample(shape: torch.Size | None = None, **kwargs) → torch.Tensor

Sample from the uniform distribution.

Parameters:

shape (torch.Size, optional) – The shape of the sample

Returns:

Sampled tensor

Return type:

torch.Tensor
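A common way to draw from a uniform distribution on [low, high) in PyTorch is to rescale `torch.rand`, which samples from [0, 1). The sketch below does this and, as an assumption not stated in this reference, falls back to the shape recorded by `_setup` when `shape` is `None`. `UniformSampleSketch` and `sample_uniform` are hypothetical names.

```python
from typing import Optional

import torch


def sample_uniform(low: float, high: float, shape: torch.Size) -> torch.Tensor:
    # torch.rand draws from [0, 1); the affine map rescales it to [low, high).
    return low + (high - low) * torch.rand(shape)


class UniformSampleSketch:
    """Hypothetical sketch mirroring the documented _sample signature."""

    def __init__(self, low: float = 0.0, high: float = 1.0):
        self.low = low
        self.high = high
        self.shape = torch.Size([4, 3])  # stand-in for the shape set by _setup

    def _sample(self, shape: Optional[torch.Size] = None, **kwargs) -> torch.Tensor:
        # Assumption: shape=None means "use the shape recorded by _setup".
        target = shape if shape is not None else self.shape
        return sample_uniform(self.low, self.high, target)


dist = UniformSampleSketch(low=-2.0, high=2.0)
x = dist._sample()                  # uses the stored shape (4, 3)
y = dist._sample(torch.Size([10]))  # explicit shape overrides it
```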

class agedi.diffusion.distributions.uniform.UniformCell(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: Uniform

Uniform Prior Distribution for cell parameters

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution for sampling from the batch.

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None
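Batches of atomistic graphs usually store one cell per structure but many atoms per structure. This reference does not say how `_setup` handles cells; one common pattern, assuming a PyG-style index vector that maps each atom to its structure, is to expand the per-structure cells to per-atom cells by plain tensor indexing. `per_atom_cells` is a hypothetical helper.

```python
import torch


def per_atom_cells(cell: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
    """Expand per-structure cells (n_graphs, 3, 3) to per-atom cells (n_atoms, 3, 3).

    Hypothetical helper: assumes batch_index[i] is the structure index of atom i.
    """
    return cell[batch_index]


cells = torch.stack([torch.eye(3), 2.0 * torch.eye(3)])  # two structures
batch_index = torch.tensor([0, 0, 1])  # atoms 0, 1 -> structure 0; atom 2 -> structure 1
atom_cells = per_atom_cells(cells, batch_index)
print(atom_cells.shape)  # torch.Size([3, 3, 3])
```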

_sample(**kwargs) → torch.Tensor

Sample from the uniform distribution.

Parameters:

**kwargs – Additional keyword arguments.

Returns:

Sampled tensor

Return type:

torch.Tensor
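If, as an assumption not confirmed by this reference, `UniformCell` draws atomic positions uniformly inside each structure's unit cell, one way to do it is to sample fractional coordinates from [0, 1)^3 and map them through the cell matrix (rows are lattice vectors, as in ASE). Because the map is linear, its Jacobian is constant, so the resulting positions are uniform over the cell volume. `sample_uniform_in_cells` is a hypothetical name.

```python
import torch


def sample_uniform_in_cells(cell: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
    """Draw one position per atom, uniformly inside that atom's cell (hypothetical helper).

    cell:        (n_graphs, 3, 3), rows are lattice vectors
    batch_index: (n_atoms,), structure index of each atom
    """
    frac = torch.rand(batch_index.shape[0], 3, dtype=cell.dtype)  # fractional coords in [0, 1)
    # Cartesian position r = f1*a1 + f2*a2 + f3*a3, using each atom's own cell.
    return torch.einsum("ni,nij->nj", frac, cell[batch_index])


cells = torch.stack([torch.eye(3), 2.0 * torch.eye(3)])  # unit cube and 2x cube
batch_index = torch.tensor([0, 0, 1])
pos = sample_uniform_in_cells(cells, batch_index)
```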

class agedi.diffusion.distributions.uniform.UniformCellConfined(low: float = 0.0, high: float = 1.0, key: str = 'x', **kwargs)

Bases: UniformCell

Uniform Prior Distribution for cell parameters with Z-directional confinement

_setup(batch: agedi.data.AtomsGraph) → None

Prepare the distribution for sampling from the batch.

Parameters:

batch (AtomsGraph) – Batch of data

Return type:

None
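One plausible reading of "Z-directional confinement", stated here as an assumption, is that only the third fractional coordinate is restricted to a sub-interval [z_low, z_high) of the cell while the other two stay uniform on [0, 1). For a slab-like cell whose third lattice vector points along z, this confines Cartesian z to a band. The names `sample_confined`, `z_low` and `z_high` are hypothetical and are not documented parameters of `UniformCellConfined`.

```python
import torch


def sample_confined(
    cell: torch.Tensor,
    batch_index: torch.Tensor,
    z_low: float = 0.3,
    z_high: float = 0.7,
) -> torch.Tensor:
    """Uniform positions in each cell with the third fractional coordinate confined."""
    frac = torch.rand(batch_index.shape[0], 3, dtype=cell.dtype)
    # Rescale only the third fractional coordinate from [0, 1) to [z_low, z_high).
    frac[:, 2] = z_low + (z_high - z_low) * frac[:, 2]
    # Map fractional to Cartesian coordinates using each atom's cell (rows = lattice vectors).
    return torch.einsum("ni,nij->nj", frac, cell[batch_index])


cell = torch.diag(torch.tensor([10.0, 10.0, 20.0])).unsqueeze(0)  # one slab-like cell
batch_index = torch.zeros(100, dtype=torch.long)  # 100 atoms, all in structure 0
pos = sample_confined(cell, batch_index)
```

With the defaults above and a cell height of 20, the sampled z values fall in [6, 14).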