agedi.utils.truncated_normal
============================

.. py:module:: agedi.utils.truncated_normal

.. autoapi-nested-parse::

   BSD 3-Clause License

   Copyright (c) 2020, Anton Obukhov
   All rights reserved.

   https://github.com/toshas/torch_truncnorm



Attributes
----------

.. autoapisummary::

   agedi.utils.truncated_normal.CONST_SQRT_2
   agedi.utils.truncated_normal.CONST_INV_SQRT_2PI
   agedi.utils.truncated_normal.CONST_INV_SQRT_2
   agedi.utils.truncated_normal.CONST_LOG_INV_SQRT_2PI
   agedi.utils.truncated_normal.CONST_LOG_SQRT_2PI_E


Classes
-------

.. autoapisummary::

   agedi.utils.truncated_normal.TruncatedStandardNormal
   agedi.utils.truncated_normal.TruncatedNormal


Module Contents
---------------

.. py:data:: CONST_SQRT_2

.. py:data:: CONST_INV_SQRT_2PI

.. py:data:: CONST_INV_SQRT_2

.. py:data:: CONST_LOG_INV_SQRT_2PI

.. py:data:: CONST_LOG_SQRT_2PI_E

.. py:class:: TruncatedStandardNormal(a: Union[numbers.Number, torch.Tensor], b: Union[numbers.Number, torch.Tensor], validate_args: Optional[bool] = None)

   Bases: :py:obj:`torch.distributions.Distribution`


   Truncated Standard Normal distribution
   https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf


   .. py:attribute:: arg_constraints


   .. py:attribute:: has_rsample
      :value: True


   .. py:attribute:: _dtype_min_gt_0


   .. py:attribute:: _dtype_max_lt_1


   .. py:attribute:: _little_phi_a


   .. py:attribute:: _little_phi_b


   .. py:attribute:: _big_phi_a


   .. py:attribute:: _big_phi_b


   .. py:attribute:: _Z


   .. py:attribute:: _log_Z


   .. py:attribute:: _lpbb_m_lpaa_d_Z


   .. py:attribute:: _mean


   .. py:attribute:: _variance


   .. py:attribute:: _entropy


   .. py:method:: support()

      Return the support interval ``[a, b]``.


   .. py:property:: mean
      :type: torch.Tensor

      Mean of the truncated standard normal distribution.


   .. py:property:: variance
      :type: torch.Tensor

      Variance of the truncated standard normal distribution.


   .. py:property:: entropy
      :type: torch.Tensor

      Differential entropy of the truncated standard normal distribution.


   .. py:property:: auc
      :type: torch.Tensor

      Normalisation constant Z = Φ(b) − Φ(a).


   .. py:method:: _little_phi(x: torch.Tensor) -> torch.Tensor
      :staticmethod:

      Standard normal probability density function φ(x).


   .. py:method:: _big_phi(x: torch.Tensor) -> torch.Tensor
      :staticmethod:

      Standard normal cumulative distribution function Φ(x).


   .. py:method:: _inv_big_phi(x: torch.Tensor) -> torch.Tensor
      :staticmethod:

      Inverse of the standard normal CDF Φ⁻¹(x).


   .. py:method:: cdf(value: torch.Tensor) -> torch.Tensor

      Cumulative distribution function evaluated at *value*.


   .. py:method:: icdf(value: torch.Tensor) -> torch.Tensor

      Inverse CDF (quantile function) evaluated at *value*.


   .. py:method:: log_prob(value: torch.Tensor) -> torch.Tensor

      Log probability density evaluated at *value*.


   .. py:method:: rsample(sample_shape: torch.Size = torch.Size()) -> torch.Tensor

      Draw a re-parameterised sample of the given shape.


.. py:class:: TruncatedNormal(loc: Union[numbers.Number, torch.Tensor], scale: Union[numbers.Number, torch.Tensor], a: Union[numbers.Number, torch.Tensor], b: Union[numbers.Number, torch.Tensor], validate_args: Optional[bool] = None)

   Bases: :py:obj:`TruncatedStandardNormal`


   Truncated Normal distribution
   https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf


   .. py:attribute:: has_rsample
      :value: True


   .. py:attribute:: _log_scale


   .. py:attribute:: _mean


   .. py:attribute:: _variance


   .. py:method:: _to_std_rv(value: torch.Tensor) -> torch.Tensor

      Standardise *value* to the standard (zero-mean, unit-variance) domain.


   .. py:method:: _from_std_rv(value: torch.Tensor) -> torch.Tensor

      Map *value* from the standard domain back to the original (loc/scale) domain.


   .. py:method:: cdf(value: torch.Tensor) -> torch.Tensor

      Cumulative distribution function evaluated at *value*.


   .. py:method:: icdf(value: torch.Tensor) -> torch.Tensor

      Inverse CDF (quantile function) evaluated at *value*.


   .. py:method:: log_prob(value: torch.Tensor) -> torch.Tensor

      Log probability density evaluated at *value*.
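
The quantities documented above (``auc``, ``mean``, ``variance``, ``entropy``, ``cdf``, ``icdf``, ``log_prob`` and sampling) follow the standard closed forms for a normal distribution truncated to ``[a, b]``, as described in the linked Burkardt notes. The block below is a minimal, standard-library-only sketch of those formulas for scalar inputs. It is meant for checking values by hand and is not the ``agedi`` implementation. ``TruncatedNormalSketch`` and its methods are hypothetical names. The sketch uses ``statistics.NormalDist`` for φ, Φ and Φ⁻¹ in place of torch's ``erf``/``erfinv``. It also assumes, following the upstream torch_truncnorm code rather than anything stated on this page, that ``TruncatedNormal`` takes ``a`` and ``b`` in the original (loc/scale) domain. Setting ``loc=0`` and ``scale=1`` gives the ``TruncatedStandardNormal`` case.

.. code-block:: python

   import math
   import random
   import sys
   from statistics import NormalDist

   _STD = NormalDist()  # standard normal: pdf is φ, cdf is Φ, inv_cdf is Φ⁻¹
   _EPS = sys.float_info.epsilon


   def _x_phi(x: float) -> float:
       """Return x·φ(x), using its limit 0 at ±inf to avoid inf * 0 = nan."""
       return 0.0 if math.isinf(x) else x * _STD.pdf(x)


   class TruncatedNormalSketch:
       """Hypothetical scalar N(loc, scale²) truncated to [a, b]; not the agedi class.

       loc=0, scale=1 gives the TruncatedStandardNormal case.
       """

       def __init__(self, loc: float, scale: float, a: float, b: float) -> None:
           if scale <= 0 or not a < b:
               raise ValueError("need scale > 0 and a < b")
           self.loc, self.scale = loc, scale
           # Assumption: a and b are in the original domain, so standardise them.
           self.alpha = (a - loc) / scale
           self.beta = (b - loc) / scale
           self.z = _STD.cdf(self.beta) - _STD.cdf(self.alpha)  # auc: Z = Φ(β) − Φ(α)
           d = (_STD.pdf(self.alpha) - _STD.pdf(self.beta)) / self.z
           e = (_x_phi(self.alpha) - _x_phi(self.beta)) / self.z
           # Standard-domain mean/variance/entropy, mapped through loc and scale.
           self.mean = loc + scale * d
           self.variance = scale ** 2 * (1.0 + e - d * d)
           self.entropy = (0.5 * math.log(2 * math.pi * math.e)
                           + math.log(self.z) + 0.5 * e + math.log(scale))

       def _to_std(self, x: float) -> float:
           return (x - self.loc) / self.scale

       def cdf(self, x: float) -> float:
           # (Φ(x') − Φ(α)) / Z, clamped to [0, 1] outside the support.
           p = (_STD.cdf(self._to_std(x)) - _STD.cdf(self.alpha)) / self.z
           return min(max(p, 0.0), 1.0)

       def icdf(self, u: float) -> float:
           # Φ⁻¹(Φ(α) + u·Z) in the standard domain, then back to loc/scale.
           return self.loc + self.scale * _STD.inv_cdf(_STD.cdf(self.alpha) + u * self.z)

       def log_prob(self, x: float) -> float:
           # log φ(x') − log Z − log scale inside the support, −inf outside.
           xs = self._to_std(x)
           if not self.alpha <= xs <= self.beta:
               return -math.inf
           return (-0.5 * xs * xs - 0.5 * math.log(2 * math.pi)
                   - math.log(self.z) - math.log(self.scale))

       def sample(self, n: int, rng: random.Random) -> list:
           # Inverse-transform sampling with u kept inside [eps, 1 − eps].
           return [self.icdf(min(max(rng.random(), _EPS), 1 - _EPS)) for _ in range(n)]

For example, ``TruncatedNormalSketch(1.0, 2.0, 0.0, 3.0).sample(5, random.Random(0))`` draws five values in ``[0, 3]`` by pushing clipped uniforms through the inverse CDF. As far as the upstream torch_truncnorm code shows, ``rsample`` produces its re-parameterised samples the same way. The sketch differs from a torch ``Distribution`` in two ways. It handles one scalar at a time, and ``log_prob`` returns ``-inf`` outside the support instead of relying on argument validation.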