agedi.utils.truncated_normal

BSD 3-Clause License

Copyright (c) 2020, Anton Obukhov
All rights reserved.

https://github.com/toshas/torch_truncnorm

Attributes

CONST_SQRT_2
CONST_INV_SQRT_2PI
CONST_INV_SQRT_2
CONST_LOG_INV_SQRT_2PI
CONST_LOG_SQRT_2PI_E

Classes

TruncatedStandardNormal

Truncated Standard Normal distribution

TruncatedNormal

Truncated Normal distribution

Module Contents

agedi.utils.truncated_normal.CONST_SQRT_2
agedi.utils.truncated_normal.CONST_INV_SQRT_2PI
agedi.utils.truncated_normal.CONST_INV_SQRT_2
agedi.utils.truncated_normal.CONST_LOG_INV_SQRT_2PI
agedi.utils.truncated_normal.CONST_LOG_SQRT_2PI_E
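Only the names of the module constants are listed. Assuming agedi keeps the definitions from the upstream torch_truncnorm module (an assumption, not confirmed by this page), their values can be reproduced with the standard library:

```python
import math

# Assumed to mirror the upstream torch_truncnorm definitions.
CONST_SQRT_2 = math.sqrt(2)                                   # √2
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)               # 1/√(2π), peak height of φ
CONST_INV_SQRT_2 = 1 / math.sqrt(2)                           # 1/√2, scales x inside erf
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)         # log(1/√(2π))
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)   # log √(2πe)
```

CONST_LOG_SQRT_2PI_E equals the differential entropy of an untruncated standard normal, which is why it is the starting point of the entropy formula.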
class agedi.utils.truncated_normal.TruncatedStandardNormal(a: numbers.Number | torch.Tensor, b: numbers.Number | torch.Tensor, validate_args: bool | None = None)

Bases: torch.distributions.Distribution

Truncated standard normal distribution: N(0, 1) restricted to the interval [a, b] and renormalised. Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

arg_constraints
has_rsample = True
_dtype_min_gt_0
_dtype_max_lt_1
_little_phi_a
_little_phi_b
_big_phi_a
_big_phi_b
_Z
_log_Z
_lpbb_m_lpaa_d_Z
_mean
_variance
_entropy
support()

Return the support interval [a, b].

property mean: torch.Tensor

Mean of the truncated standard normal distribution.

property variance: torch.Tensor

Variance of the truncated standard normal distribution.

property entropy: torch.Tensor

Differential entropy of the truncated standard normal distribution.

property auc: torch.Tensor

Area under the untruncated standard normal density between a and b, i.e. the normalisation constant Z = Φ(b) − Φ(a).
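The properties above have closed forms in terms of φ, Φ and Z. The sketch below evaluates them for scalar bounds using only the math module. It mirrors the cached attributes _Z, _lpbb_m_lpaa_d_Z, _mean, _variance and _entropy as they are defined in upstream torch_truncnorm, which is assumed, not confirmed, to match agedi; the function names are hypothetical:

```python
import math

INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)


def little_phi(x: float) -> float:
    """Standard normal density φ(x); evaluates to 0.0 at ±inf."""
    return math.exp(-0.5 * x * x) * INV_SQRT_2PI


def big_phi(x: float) -> float:
    """Standard normal CDF Φ(x) via the error function."""
    return 0.5 * (1 + math.erf(x / math.sqrt(2)))


def x_phi(x: float) -> float:
    """x·φ(x), using its limit 0 at infinite bounds instead of inf * 0 = nan."""
    return 0.0 if math.isinf(x) else x * little_phi(x)


def truncated_std_normal_stats(a: float, b: float) -> dict:
    """Scalar Z, mean, variance and entropy of N(0, 1) truncated to [a, b]."""
    z = big_phi(b) - big_phi(a)                   # auc: Z = Φ(b) − Φ(a)
    d = (little_phi(b) - little_phi(a)) / z       # (φ(b) − φ(a)) / Z
    m = (x_phi(b) - x_phi(a)) / z                 # (bφ(b) − aφ(a)) / Z
    return {
        "auc": z,
        "mean": -d,                               # −(φ(b) − φ(a)) / Z
        "variance": 1 - m - d * d,
        "entropy": LOG_SQRT_2PI_E + math.log(z) - 0.5 * m,
    }
```

For a = 0, b = ∞ this reproduces the half-normal values: mean √(2/π), variance 1 − 2/π and entropy ½·log(πe/2). The x_phi helper exists because a·φ(a) evaluates to inf · 0 = nan in floating point when a bound is infinite.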

static _little_phi(x: torch.Tensor) → torch.Tensor

Standard normal probability density function φ(x).

static _big_phi(x: torch.Tensor) → torch.Tensor

Standard normal cumulative distribution function Φ(x).

static _inv_big_phi(x: torch.Tensor) → torch.Tensor

Inverse of the standard normal CDF Φ⁻¹(x).
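The three static helpers correspond to the textbook formulas φ(x) = exp(−x²/2)/√(2π), Φ(x) = (1 + erf(x/√2))/2 and Φ⁻¹(p) = √2 · erfinv(2p − 1). A scalar sketch follows. Python's math module has no erfinv, so the inverse here is swapped for statistics.NormalDist.inv_cdf, a different numerical method that computes the same function:

```python
import math
from statistics import NormalDist

INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
INV_SQRT_2 = 1 / math.sqrt(2)


def little_phi(x: float) -> float:
    # φ(x) = exp(−x²/2) / √(2π)
    return math.exp(-0.5 * x * x) * INV_SQRT_2PI


def big_phi(x: float) -> float:
    # Φ(x) = (1 + erf(x/√2)) / 2
    return 0.5 * (1 + math.erf(x * INV_SQRT_2))


def inv_big_phi(p: float) -> float:
    # Φ⁻¹(p) = √2 · erfinv(2p − 1). The stdlib has no erfinv, so this sketch
    # delegates to NormalDist.inv_cdf, which requires 0 < p < 1.
    return NormalDist().inv_cdf(p)
```

Because Φ⁻¹ diverges at 0 and 1, callers should keep p strictly inside (0, 1); the _dtype_min_gt_0 and _dtype_max_lt_1 attributes listed above appear to serve that purpose in the tensor implementation.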

cdf(value: torch.Tensor) → torch.Tensor

Cumulative distribution function evaluated at value.

icdf(value: torch.Tensor) → torch.Tensor

Inverse CDF (quantile function) evaluated at value.
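Truncation only rescales the untruncated CDF: F(x) = (Φ(x) − Φ(a)) / Z for a ≤ x ≤ b, and inverting gives F⁻¹(p) = Φ⁻¹(Φ(a) + p·Z). A scalar sketch with hypothetical function names, using statistics.NormalDist for Φ and Φ⁻¹:

```python
from statistics import NormalDist

_STD = NormalDist()  # untruncated N(0, 1)


def truncated_cdf(x: float, a: float, b: float) -> float:
    """F(x) = (Φ(x) − Φ(a)) / (Φ(b) − Φ(a)), meaningful for a ≤ x ≤ b."""
    z = _STD.cdf(b) - _STD.cdf(a)
    return (_STD.cdf(x) - _STD.cdf(a)) / z


def truncated_icdf(p: float, a: float, b: float) -> float:
    """F⁻¹(p) = Φ⁻¹(Φ(a) + p · Z), for 0 < p < 1."""
    z = _STD.cdf(b) - _STD.cdf(a)
    return _STD.inv_cdf(_STD.cdf(a) + p * z)
```

The sketch does not clamp x to [a, b]; outside the support the returned value is not a valid probability.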

log_prob(value: torch.Tensor) → torch.Tensor

Log probability density evaluated at value.
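Inside the support the density is φ(x)/Z, so log f(x) = log(1/√(2π)) − log Z − x²/2; CONST_LOG_INV_SQRT_2PI and _log_Z are presumably the cached pieces of this expression. A scalar sketch with the hypothetical name truncated_log_prob (returning −inf outside [a, b] is a choice made here, not documented behaviour):

```python
import math

LOG_INV_SQRT_2PI = -0.5 * math.log(2 * math.pi)


def big_phi(x: float) -> float:
    return 0.5 * (1 + math.erf(x / math.sqrt(2)))


def truncated_log_prob(x: float, a: float, b: float) -> float:
    """log f(x) = log(1/√(2π)) − log Z − x²/2 inside [a, b]."""
    if not a <= x <= b:
        return -math.inf  # sketch choice: zero density outside the support
    log_z = math.log(big_phi(b) - big_phi(a))
    return LOG_INV_SQRT_2PI - log_z - 0.5 * x * x
```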

rsample(sample_shape: torch.Size = torch.Size()) → torch.Tensor

Draw a reparameterised sample of shape sample_shape + batch_shape, so that gradients can propagate back to the distribution parameters.

class agedi.utils.truncated_normal.TruncatedNormal(loc: numbers.Number | torch.Tensor, scale: numbers.Number | torch.Tensor, a: numbers.Number | torch.Tensor, b: numbers.Number | torch.Tensor, validate_args: bool | None = None)

Bases: TruncatedStandardNormal

Truncated normal distribution: N(loc, scale²) restricted to the interval [a, b] and renormalised. Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

has_rsample = True
_log_scale
_mean
_variance
_to_std_rv(value: torch.Tensor) → torch.Tensor

Map value to standard-normal coordinates, (value − loc) / scale.

_from_std_rv(value: torch.Tensor) → torch.Tensor

Map value from standard-normal coordinates back to the original domain, value · scale + loc.
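The affine maps and the rescaled moments can be sketched as follows. In upstream torch_truncnorm, a and b are passed in the same units as loc, standardised with the same arithmetic as _to_std_rv before the parent constructor runs, and _mean, _variance and the entropy are then rescaled; agedi is assumed, not confirmed, to do the same. The function names are hypothetical:

```python
import math


def to_std_rv(value: float, loc: float, scale: float) -> float:
    """(value − loc) / scale: original units → standard-normal coordinates."""
    return (value - loc) / scale


def from_std_rv(value: float, loc: float, scale: float) -> float:
    """value · scale + loc: standard-normal coordinates → original units."""
    return value * scale + loc


def affine_stats(std_mean: float, std_var: float, std_entropy: float,
                 loc: float, scale: float) -> tuple[float, float, float]:
    """Rescale moments of the standardised truncated variable Y (scale > 0)."""
    mean = std_mean * scale + loc              # E[loc + scale·Y]
    variance = std_var * scale ** 2            # Var[loc + scale·Y]
    entropy = std_entropy + math.log(scale)    # differential entropy shifts by log(scale)
    return mean, variance, entropy
```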

cdf(value: torch.Tensor) → torch.Tensor

Cumulative distribution function evaluated at value.

icdf(value: torch.Tensor) → torch.Tensor

Inverse CDF (quantile function) evaluated at value.

log_prob(value: torch.Tensor) → torch.Tensor

Log probability density evaluated at value.
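Each TruncatedNormal method can delegate to the standard case: cdf(x) = F_std(_to_std_rv(x)), icdf(p) = _from_std_rv(F_std⁻¹(p)), and log_prob subtracts log(scale), the _log_scale attribute, as the Jacobian of the affine map. The hypothetical scalar class ScalarTruncatedNormal below ties these pieces together; it is a sketch, not the library's implementation:

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # untruncated N(0, 1)


class ScalarTruncatedNormal:
    """Hypothetical scalar mirror of TruncatedNormal; a, b are in original units."""

    def __init__(self, loc: float, scale: float, a: float, b: float) -> None:
        self.loc, self.scale = loc, scale
        self.a = (a - loc) / scale          # standardised lower bound
        self.b = (b - loc) / scale          # standardised upper bound
        self.cdf_a = _STD.cdf(self.a)
        self.z = _STD.cdf(self.b) - self.cdf_a

    def cdf(self, x: float) -> float:
        y = (x - self.loc) / self.scale     # _to_std_rv
        return (_STD.cdf(y) - self.cdf_a) / self.z

    def icdf(self, p: float) -> float:
        y = _STD.inv_cdf(self.cdf_a + p * self.z)
        return y * self.scale + self.loc    # _from_std_rv

    def log_prob(self, x: float) -> float:
        y = (x - self.loc) / self.scale
        std_log_prob = -0.5 * math.log(2 * math.pi) - math.log(self.z) - 0.5 * y * y
        return std_log_prob - math.log(self.scale)  # Jacobian of the affine map
```

Keeping the standardised bounds and Z on the instance means every method reduces to the standard-normal formulas plus one affine step, which is the same division of labour the class hierarchy above suggests.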