agedi.diffusion.distributions.categorical
Classes
Categorical | Categorical Distribution
Module Contents
- class agedi.diffusion.distributions.categorical.Categorical
Bases: agedi.diffusion.distributions.Distribution
Categorical Distribution
Implements hard sampling using the Gumbel-Max trick.
- _sample(probs: torch.Tensor) → torch.Tensor
Sample from the categorical distribution, where the probabilities define the likelihood of each value of mu being set to the masked value, 0.
- Parameters:
probs (torch.Tensor) – The probabilities of each category
- Returns:
Sampled tensor
- Return type:
torch.Tensor
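The Gumbel-Max trick mentioned for this class can be illustrated with a minimal PyTorch sketch. It assumes that `probs` holds a probability vector along its last dimension and returns one integer category index per vector. `gumbel_max_sample` is a hypothetical helper written for illustration, not the library's `_sample`. It also does not reproduce how the library maps sampled categories onto the masked value 0.

```python
from typing import Optional

import torch


def gumbel_max_sample(probs: torch.Tensor,
                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical helper: hard-sample category indices via the Gumbel-Max trick.

    `probs` has shape (..., num_categories) and is assumed to sum to 1 along
    the last dimension. Returns int64 indices of shape (...).
    """
    finfo = torch.finfo(probs.dtype)
    # Uniform noise; clamp away from 0 and 1 so both logarithms stay finite.
    u = torch.rand(probs.shape, generator=generator,
                   dtype=probs.dtype, device=probs.device)
    u = u.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
    # Standard Gumbel noise: g = -log(-log(u)).
    gumbel = -torch.log(-torch.log(u))
    # argmax(log p + g) is an exact draw from Categorical(p); categories with
    # p = 0 get log p = -inf and are never selected.
    return torch.argmax(torch.log(probs) + gumbel, dim=-1)


torch.manual_seed(0)
probs = torch.tensor([[0.0, 1.0, 0.0],
                      [0.5, 0.0, 0.5]])
samples = gumbel_max_sample(probs)  # first entry is always 1; second is 0 or 2
```

Because the argmax over perturbed log-probabilities yields an exact categorical sample, the result follows the same distribution as `torch.distributions.Categorical(probs=probs).sample()`. The Gumbel-Max form, however, makes the hard sampling step explicit.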