agedi.diffusion.distributions.categorical

Classes

Categorical – Categorical Distribution

Module Contents

class agedi.diffusion.distributions.categorical.Categorical

Bases: agedi.diffusion.distributions.Distribution

Categorical Distribution

Implements hard sampling using the Gumbel-Max trick.
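The Gumbel-Max trick draws a category by adding independent standard Gumbel noise g_i = -log(-log(u_i)), with u_i ~ Uniform(0, 1), to each log-probability and taking the argmax. The result is distributed exactly as Categorical(p). The following is a minimal, framework-free sketch of the trick itself. It is not the agedi implementation, and the helper name `gumbel_max_sample` is hypothetical.

```python
import math
import random
from collections import Counter


def gumbel_max_sample(probs, rng=random):
    """Hypothetical helper: return one category index drawn via the Gumbel-Max trick.

    If g_i are i.i.d. standard Gumbel variables, argmax_i(log p_i + g_i)
    is distributed as Categorical(p), even if `probs` is not normalized.
    """
    best_index, best_score = None, -math.inf
    for i, p in enumerate(probs):
        if p <= 0.0:
            continue  # log(0) = -inf: a zero-probability category can never win
        u = max(rng.random(), 1e-300)  # keep u > 0 so the inner log is finite
        gumbel = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        score = math.log(p) + gumbel
        if score > best_score:
            best_index, best_score = i, score
    return best_index


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
n = 10_000
counts = Counter(gumbel_max_sample(probs, rng) for _ in range(n))
print({k: counts[k] / n for k in sorted(counts)})  # frequencies close to probs
```

Because the argmax does not change when every score is shifted by the same constant, the probabilities do not need to be normalized.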

_sample(probs: torch.Tensor) → torch.Tensor

Sample from the categorical distribution, where probs gives the likelihood of a value being set to the masked value, 0.

Parameters:

probs (torch.Tensor) – The probabilities of each category

Returns:

Sampled tensor

Return type:

torch.Tensor
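The following sketch shows what a batched, hard Gumbel-Max sample over a probability tensor might look like in PyTorch. It is an illustration under stated assumptions, not the source of `Categorical._sample`. The function name `gumbel_max_hard_sample`, the optional `generator` argument, the last-dimension layout of the categories, and the one-hot output format are all assumptions. "Hard" is read here as a discrete one-hot result rather than the relaxed Gumbel-Softmax vector. The sketch does not model the masking (value 0) semantics described above.

```python
import torch
import torch.nn.functional as F


def gumbel_max_hard_sample(probs: torch.Tensor, generator=None) -> torch.Tensor:
    """Hypothetical stand-in for Categorical._sample (not the agedi source).

    probs: (..., K) non-negative weights over K categories (assumed layout).
    Returns a one-hot tensor with the same shape and dtype as probs (assumed format).
    """
    finfo = torch.finfo(probs.dtype)
    # Uniform noise, clamped away from 0 and 1 so that -log(-log(u)) stays finite
    u = torch.rand(probs.shape, generator=generator, dtype=probs.dtype, device=probs.device)
    u = u.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
    gumbel = -torch.log(-torch.log(u))  # standard Gumbel(0, 1) noise
    # log(0) = -inf, so zero-probability categories never win the argmax
    scores = torch.log(probs) + gumbel
    index = scores.argmax(dim=-1)
    return F.one_hot(index, num_classes=probs.shape[-1]).to(probs.dtype)


gen = torch.Generator().manual_seed(0)
probs = torch.tensor([[0.7, 0.2, 0.1], [0.0, 0.5, 0.5]])
sample = gumbel_max_hard_sample(probs, generator=gen)
print(sample)  # one 1.0 per row; column 0 is never chosen in the second row
```

PyTorch also ships `torch.nn.functional.gumbel_softmax(logits, hard=True)`, which produces a one-hot sample with straight-through gradients. The plain argmax above has no gradient path.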