agedi.utils.truncated_normal
BSD 3-Clause License
Copyright (c) 2020, Anton Obukhov. All rights reserved.
https://github.com/toshas/torch_truncnorm
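Attributes
- CONST_SQRT_2
- CONST_INV_SQRT_2PI
- CONST_INV_SQRT_2
- CONST_LOG_INV_SQRT_2PI
- CONST_LOG_SQRT_2PI_E
Classes
| Class | Description |
| TruncatedStandardNormal | Truncated Standard Normal distribution |
| TruncatedNormal | Truncated Normal distribution |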
Module Contents
- agedi.utils.truncated_normal.CONST_SQRT_2
- agedi.utils.truncated_normal.CONST_INV_SQRT_2PI
- agedi.utils.truncated_normal.CONST_INV_SQRT_2
- agedi.utils.truncated_normal.CONST_LOG_INV_SQRT_2PI
- agedi.utils.truncated_normal.CONST_LOG_SQRT_2PI_E
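The values of these constants are not shown on this page. Judging by their names alone, they are most likely the following; treat this as a sketch under that assumption and check the module source before relying on the exact definitions.

```python
import math

# Values inferred from the constant names only; the module source is authoritative.
CONST_SQRT_2 = math.sqrt(2)                                  # √2
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)              # 1/√(2π), the peak value φ(0)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)                          # 1/√2, e.g. for Φ(x) = ½(1 + erf(x/√2))
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)        # log(1/√(2π)), constant term of log φ(x)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)  # log √(2πe), the entropy of N(0, 1)
```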
- class agedi.utils.truncated_normal.TruncatedStandardNormal(a: numbers.Number | torch.Tensor, b: numbers.Number | torch.Tensor, validate_args: bool | None = None)
Bases: torch.distributions.Distribution
Truncated Standard Normal distribution, i.e. the standard normal N(0, 1) restricted to the interval [a, b]. Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- arg_constraints
- has_rsample = True
- _dtype_min_gt_0
- _dtype_max_lt_1
- _little_phi_a
- _little_phi_b
- _big_phi_a
- _big_phi_b
- _Z
- _log_Z
- _lpbb_m_lpaa_d_Z
- _mean
- _variance
- _entropy
- support()
Return the support interval [a, b].
- property mean: torch.Tensor
Mean of the truncated standard normal distribution.
- property variance: torch.Tensor
Variance of the truncated standard normal distribution.
- property entropy: torch.Tensor
Differential entropy of the truncated standard normal distribution.
- property auc: torch.Tensor
Normalisation constant Z = Φ(b) − Φ(a).
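For reference, these properties correspond to the textbook closed forms for N(0, 1) truncated to [a, b]: Z = Φ(b) − Φ(a), mean = (φ(a) − φ(b))/Z, variance = 1 − (bφ(b) − aφ(a))/Z − mean², and entropy = log √(2πe) + log Z − (bφ(b) − aφ(a))/(2Z). The cached attributes _Z, _log_Z, _lpbb_m_lpaa_d_Z, _mean, _variance and _entropy listed above presumably hold exactly these quantities, but that mapping is inferred from the names. The scalar sketch below uses only the Python standard library (statistics.NormalDist), not the module's torch code, and assumes finite bounds a < b; the function name is hypothetical.

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # standard normal N(0, 1)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)  # entropy of N(0, 1)


def truncated_std_normal_stats(a: float, b: float) -> dict:
    """Closed-form statistics of N(0, 1) truncated to finite bounds a < b."""
    phi_a, phi_b = _STD.pdf(a), _STD.pdf(b)  # φ(a), φ(b)
    Z = _STD.cdf(b) - _STD.cdf(a)            # auc: Φ(b) − Φ(a)
    # (b·φ(b) − a·φ(a)) / Z, which the name _lpbb_m_lpaa_d_Z appears to abbreviate
    lpbb_m_lpaa_d_Z = (b * phi_b - a * phi_a) / Z
    mean = (phi_a - phi_b) / Z
    variance = 1.0 - lpbb_m_lpaa_d_Z - mean ** 2
    entropy = CONST_LOG_SQRT_2PI_E + math.log(Z) - 0.5 * lpbb_m_lpaa_d_Z
    return {
        "Z": Z,
        "log_Z": math.log(Z),
        "lpbb_m_lpaa_d_Z": lpbb_m_lpaa_d_Z,
        "mean": mean,
        "variance": variance,
        "entropy": entropy,
    }
```

Passing infinite bounds would produce inf · 0 = nan in the b·φ(b) and a·φ(a) terms, so a torch version has to special-case them.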
- static _little_phi(x: torch.Tensor) → torch.Tensor
Standard normal probability density function φ(x).
- static _big_phi(x: torch.Tensor) → torch.Tensor
Standard normal cumulative distribution function Φ(x).
- static _inv_big_phi(x: torch.Tensor) → torch.Tensor
Inverse of the standard normal CDF Φ⁻¹(x).
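These three helpers are the standard closed forms φ(x) = exp(−x²/2)/√(2π), Φ(x) = ½(1 + erf(x/√2)) and Φ⁻¹(p) = √2 · erfinv(2p − 1); a torch implementation would typically build them from torch.exp, torch.erf and torch.erfinv. The pure-Python sketch below mirrors them with the math module; the function names are hypothetical stand-ins for the private static methods, and because the standard library has no erfinv, the inverse delegates to statistics.NormalDist.

```python
import math
from statistics import NormalDist


def little_phi(x: float) -> float:
    """Standard normal pdf φ(x) = exp(−x²/2) / √(2π)."""
    return math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)


def big_phi(x: float) -> float:
    """Standard normal cdf Φ(x) = ½(1 + erf(x/√2))."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2)))


def inv_big_phi(p: float) -> float:
    """Φ⁻¹(p) for 0 < p < 1, mathematically √2 · erfinv(2p − 1)."""
    return NormalDist().inv_cdf(p)
```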
- cdf(value: torch.Tensor) → torch.Tensor
Cumulative distribution function evaluated at value.
- icdf(value: torch.Tensor) → torch.Tensor
Inverse CDF (quantile function) evaluated at value.
- log_prob(value: torch.Tensor) → torch.Tensor
Log probability density evaluated at value.
- rsample(sample_shape: torch.Size = torch.Size()) → torch.Tensor
Draw a re-parameterised sample of the given shape.
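The four methods follow from the truncated density p(x) = φ(x)/Z on [a, b]: cdf(x) = (Φ(x) − Φ(a))/Z, icdf(p) = Φ⁻¹(Φ(a) + p·Z), and log_prob(x) = log φ(x) − log Z. A natural way to obtain a re-parameterised sample is inverse-transform sampling, i.e. pushing a uniform draw through icdf. The attributes _dtype_min_gt_0 and _dtype_max_lt_1 suggest the uniform draw is kept strictly inside (0, 1), but that reading is an assumption. The scalar, pure-Python class below is a hypothetical illustration of the maths, not the torch class.

```python
import math
import random
import sys
from statistics import NormalDist

_STD = NormalDist()                 # N(0, 1)
_EPS = sys.float_info.epsilon       # assumed analogue of _dtype_min_gt_0 for float64


class TruncStdNormalSketch:
    """Scalar sketch of N(0, 1) truncated to finite bounds a < b."""

    def __init__(self, a: float, b: float):
        if not a < b:
            raise ValueError("need a < b")
        self.a, self.b = a, b
        self.big_phi_a = _STD.cdf(a)
        self.Z = _STD.cdf(b) - self.big_phi_a  # Φ(b) − Φ(a)
        self.log_Z = math.log(self.Z)

    def cdf(self, x: float) -> float:
        # (Φ(x) − Φ(a)) / Z, clamped to [0, 1] outside the support
        return min(max((_STD.cdf(x) - self.big_phi_a) / self.Z, 0.0), 1.0)

    def icdf(self, p: float) -> float:
        # Invert the cdf above: Φ⁻¹(Φ(a) + p·Z)
        return _STD.inv_cdf(self.big_phi_a + p * self.Z)

    def log_prob(self, x: float) -> float:
        # log φ(x) − log Z inside [a, b], −∞ outside
        if not self.a <= x <= self.b:
            return -math.inf
        return -0.5 * math.log(2 * math.pi) - 0.5 * x * x - self.log_Z

    def sample(self, rng: random.Random) -> float:
        # Inverse-transform sampling; u stays strictly inside (0, 1) so icdf is finite
        u = rng.uniform(_EPS, 1.0 - _EPS)
        return self.icdf(u)
```

In torch, the same construction is differentiable with respect to a and b because icdf is a smooth function of them, which is what makes the sample re-parameterised.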
- class agedi.utils.truncated_normal.TruncatedNormal(loc: numbers.Number | torch.Tensor, scale: numbers.Number | torch.Tensor, a: numbers.Number | torch.Tensor, b: numbers.Number | torch.Tensor, validate_args: bool | None = None)
Bases: TruncatedStandardNormal
Truncated Normal distribution. Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- has_rsample = True
- _log_scale
- _mean
- _variance
- _to_std_rv(value: torch.Tensor) → torch.Tensor
Standardise value to the standard (zero-mean, unit-variance) domain.
- _from_std_rv(value: torch.Tensor) → torch.Tensor
Map value from the standard domain back to the original (loc/scale) domain.
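The standardisation is the usual affine map z = (x − loc)/scale and its inverse x = loc + scale·z. The sketch below uses hypothetical helper names on scalar floats (the torch methods operate elementwise with broadcasting). In the upstream torch_truncnorm code, the bounds a and b of TruncatedNormal are given in the original domain and standardised this way in the constructor; whether this copy does the same is an assumption worth checking.

```python
def to_std_rv(value: float, loc: float, scale: float) -> float:
    """Map x in the loc/scale domain to z = (x − loc) / scale."""
    return (value - loc) / scale


def from_std_rv(value: float, loc: float, scale: float) -> float:
    """Map z in the standard domain back to x = loc + scale·z."""
    return loc + scale * value
```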
- cdf(value: torch.Tensor) → torch.Tensor
Cumulative distribution function evaluated at value.
- icdf(value: torch.Tensor) → torch.Tensor
Inverse CDF (quantile function) evaluated at value.
- log_prob(value: torch.Tensor) → torch.Tensor
Log probability density evaluated at value.
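Under the standardisation z = (x − loc)/scale, the three overridden methods reduce to the standard-domain ones: cdf(x) = F_std(z), icdf(p) = loc + scale · F_std⁻¹(p), and log_prob(x) = log p_std(z) − log scale by change of variables (presumably the reason _log_scale is cached). The scalar sketch below assumes, as in the upstream torch_truncnorm code, that the truncation bounds are passed in the original domain; the class and its parameter names lo and hi are hypothetical.

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # N(0, 1)


class TruncNormalSketch:
    """Scalar sketch of N(loc, scale**2) truncated to [lo, hi] in the original domain."""

    def __init__(self, loc: float, scale: float, lo: float, hi: float):
        if scale <= 0 or not lo < hi:
            raise ValueError("need scale > 0 and lo < hi")
        self.loc, self.scale = loc, scale
        self.log_scale = math.log(scale)   # plays the role of _log_scale
        self.a = self._to_std(lo)          # standardised lower bound
        self.b = self._to_std(hi)          # standardised upper bound
        self.big_phi_a = _STD.cdf(self.a)
        self.Z = _STD.cdf(self.b) - self.big_phi_a  # Φ(b) − Φ(a)

    def _to_std(self, x: float) -> float:
        return (x - self.loc) / self.scale

    def _from_std(self, z: float) -> float:
        return z * self.scale + self.loc

    def cdf(self, x: float) -> float:
        # Standardise, then apply the truncated standard normal cdf
        z = self._to_std(x)
        return min(max((_STD.cdf(z) - self.big_phi_a) / self.Z, 0.0), 1.0)

    def icdf(self, p: float) -> float:
        # Quantile in the standard domain, mapped back with loc/scale
        return self._from_std(_STD.inv_cdf(self.big_phi_a + p * self.Z))

    def log_prob(self, x: float) -> float:
        z = self._to_std(x)
        if not self.a <= z <= self.b:
            return -math.inf
        # Change of variables: p_X(x) = p_Z(z) / scale
        return (-0.5 * math.log(2 * math.pi) - 0.5 * z * z
                - math.log(self.Z) - self.log_scale)
```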