Browse Source

Added truncated normal tests

Added some tests for TruncatedNormal, mainly for the unbounded case
Corrected mathematical and precision errors in TruncatedNormal
pull/344/head
BenHewins 11 years ago
parent
commit
ad0cc0c7e8
  1. 89
      src/Numerics/Distributions/TruncatedNormal.cs
  2. 2
      src/UnitTests/DistributionTests/CommonDistributionTests.cs
  3. 198
      src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs
  4. 1
      src/UnitTests/UnitTests.csproj

89
src/Numerics/Distributions/TruncatedNormal.cs

@ -40,15 +40,31 @@ namespace MathNet.Numerics.Distributions {
/// For more details about this distribution, see
/// <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">Wikipedia - Truncated normal distribution</a>
/// </summary>
public class TruncatedNormal : IContinuousDistribution {
public class TruncatedNormal : IContinuousDistribution
{
System.Random _random;
readonly double _mean;
readonly double _stdDev;
/// <summary>
/// Mean of the untruncated normal distribution.
/// </summary>
readonly double _mu;
/// <summary>
/// Standard deviation of the uncorrected normal distribution.
/// </summary>
readonly double _sigma;
readonly double _lowerBound;
readonly double _upperBound;
readonly Normal _uncorrectedNormal;
readonly Normal _standardNormal = new Normal(0.0, 1.0);
/// <summary>
/// Position in the standard normal distribution of the lower bound.
/// </summary>
readonly double _alpha;
/// <summary>
/// Position in the standard normal distribution of the upper bound.
/// </summary>
readonly double _beta;
/// <summary>
/// The total density of the uncorrected normal distribution which is within the lower and upper bounds.
/// Referred to as "Z" in the wikipedia equations. Z = Φ(UpperBound) - Φ(LowerBound).
@ -61,7 +77,7 @@ namespace MathNet.Numerics.Distributions {
/// normal distribution.
/// </summary>
/// <param name="mean">The mean (μ) of the untruncated distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ 0.</param>
/// <param name="stddev">The standard deviation (σ) of the untruncated distribution. Range: σ > 0.</param>
/// <param name="lowerBound">The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity.</param>
/// <param name="upperBound">The inclusive upper bound of the truncated distribution. Must be larger than <paramref name="lowerBound"/>.
/// Default is double.PositiveInfinity.</param>
@ -76,7 +92,7 @@ namespace MathNet.Numerics.Distributions {
/// be initialized with the default <seealso cref="System.Random"/> random number generator.
/// </summary>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ 0.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
/// <param name="randomSource">The random number generator which is used to draw random samples.</param>
public TruncatedNormal(double mean, double stddev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
{
@ -86,28 +102,30 @@ namespace MathNet.Numerics.Distributions {
}
_random = randomSource ?? SystemRandomSource.Default;
_mean = mean;
_stdDev = stddev;
_mu = mean;
_sigma = stddev;
_lowerBound = lowerBound;
_upperBound = upperBound;
_uncorrectedNormal = Normal.WithMeanStdDev(_mean, _stdDev);
_cumulativeDensityWithinBounds = _uncorrectedNormal.CumulativeDistribution(_upperBound) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
_alpha = (_lowerBound - _mu) / _sigma;
_beta = (_upperBound - _mu) / _sigma;
_cumulativeDensityWithinBounds = _standardNormal.CumulativeDistribution(_beta) - _standardNormal.CumulativeDistribution(_alpha);
}
/// <summary>
/// Tests whether the provided values are valid parameters for this distribution.
/// </summary>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ 0.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound)
{
bool normalRequirements = Normal.IsValidParameterSet(mean, stddev);
bool normalRequirements = Normal.IsValidParameterSet(mean, stddev) && stddev > 0;
bool boundsAreOrdered = lowerBound < upperBound;
return normalRequirements && boundsAreOrdered;
}
public override string ToString() {
return "TruncatedNormal(μ = " + _mean + ", σ = " + _stdDev +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
}
/// <summary>
@ -117,11 +135,11 @@ namespace MathNet.Numerics.Distributions {
{
get
{
if (_mean < _lowerBound)
if (_mu < _lowerBound)
return _lowerBound;
if (_mean > _upperBound)
if (_mu > _upperBound)
return _upperBound;
return _mean;
return _mu;
}
}
@ -148,9 +166,9 @@ namespace MathNet.Numerics.Distributions {
{
get
{
var pdfDifference = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
var diffFromUncorrected = pdfDifference * _stdDev / _cumulativeDensityWithinBounds;
return _mean + diffFromUncorrected;
var pdfDifference = _standardNormal.Density(_alpha) - _standardNormal.Density(_beta);
var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds;
return _mu + diffFromUncorrected;
}
}
@ -161,24 +179,31 @@ namespace MathNet.Numerics.Distributions {
{
get
{
//Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access...
//TODO might need special handling for cases where either or both bounds are infinity
var densityAtLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _standardNormal.Density(_alpha);
var densityAtUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _standardNormal.Density(_beta);
var standardisedLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _alpha;
var standardisedUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _beta;
//Second term
var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
var secondNumerator = standardisedLower * densityAtLower - standardisedUpper * densityAtUpper;
var secordTerm = secondNumerator / _cumulativeDensityWithinBounds;
//Third term
var thirdNumerator = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
var thirdNumerator = densityAtLower - densityAtUpper;
var thirdTerm = (thirdNumerator / _cumulativeDensityWithinBounds) * (thirdNumerator / _cumulativeDensityWithinBounds);
var sumOfTerms = 1 + secordTerm + thirdTerm;
return _stdDev * _stdDev * sumOfTerms;
return _sigma * _sigma * sumOfTerms;
}
}
/// <summary>
/// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ 0.
/// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0.
/// </summary>
public double StdDev
{
@ -192,9 +217,9 @@ namespace MathNet.Numerics.Distributions {
{
get
{
var firstTerm = Constants.LogSqrt2PiE + Math.Log(_stdDev + _cumulativeDensityWithinBounds);
var firstTerm = Constants.LogSqrt2PiE + Math.Log(_sigma + _cumulativeDensityWithinBounds);
var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
var secondNumerator = _lowerBound * _standardNormal.Density(_alpha) - _upperBound * _standardNormal.Density(_beta);
var secondTerm = secondNumerator / (2 * _cumulativeDensityWithinBounds);
return firstTerm + secondTerm;
@ -240,7 +265,7 @@ namespace MathNet.Numerics.Distributions {
if (x < _lowerBound || _upperBound < x)
return 0d;
return _uncorrectedNormal.Density(x) / (_stdDev * _cumulativeDensityWithinBounds);
return _standardNormal.Density((x - _mu) / _sigma) / (_sigma * _cumulativeDensityWithinBounds);
}
/// <summary>
@ -251,11 +276,11 @@ namespace MathNet.Numerics.Distributions {
/// <seealso cref="PDFLn"/>
public double DensityLn(double x)
{
return Math.Log(Density(x));
}
return _standardNormal.DensityLn((x - _mu) / _sigma) - Math.Log(_sigma) - Math.Log(_cumulativeDensityWithinBounds);
}
//TODO: implement sampling, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/
// see implmentations listed on that page for examples.
// see implementations listed on that page for examples.
public double Sample()
{
@ -285,7 +310,7 @@ namespace MathNet.Numerics.Distributions {
if (x > _upperBound)
return 1d;
double cumulative = _uncorrectedNormal.CumulativeDistribution(x) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
double cumulative = _standardNormal.CumulativeDistribution((x - _mu) / _sigma) - _standardNormal.CumulativeDistribution(_alpha);
return cumulative / _cumulativeDensityWithinBounds;
}
@ -299,9 +324,9 @@ namespace MathNet.Numerics.Distributions {
public double InverseCumulativeDistribution(double p)
{
//TODO check that this is correct with someone.
var pUntruncated = p * _cumulativeDensityWithinBounds + _uncorrectedNormal.CumulativeDistribution(_lowerBound);
var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal.CumulativeDistribution(_alpha);
return _uncorrectedNormal.InverseCumulativeDistribution(pUntruncated);
return _standardNormal.InverseCumulativeDistribution(pUntruncated) * _sigma + _mu;
}
}

2
src/UnitTests/DistributionTests/CommonDistributionTests.cs

@ -89,6 +89,8 @@ namespace MathNet.Numerics.UnitTests.DistributionTests
new StudentT(0.0, 1.0, 5.0),
new Triangular(0, 1, 0.7),
new Weibull(1.0, 1.0),
new TruncatedNormal(0, 1.0, -5.0, 5.0), //Finite
new TruncatedNormal(0, 1.0, -5.0), //Semi-finite
};
[Test]

198
src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs

@ -0,0 +1,198 @@
// <copyright file="TruncatedNormalTests.cs" company="Math.NET">
// Math.NET Numerics, part of the Math.NET Project
// http://numerics.mathdotnet.com
// http://github.com/mathnet/mathnet-numerics
// http://mathnetnumerics.codeplex.com
//
// Copyright (c) 2009-2014 Math.NET
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
// files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use,
// copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
// </copyright>
using MathNet.Numerics.Distributions;
using NUnit.Framework;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
{
[TestFixture, Category("Distributions")]
public class TruncatedNormalTests
{
/// <summary>
/// Can create a truncated normal without bounds.
/// </summary>
[TestCase(0.0, 0.0)]
[TestCase(10.0, 0.1)]
[TestCase(-5.0, 1.0)]
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void CanCreateWithoutBounds(double mean, double stdDev)
{
var truncatedNormal = new TruncatedNormal(mean, stdDev);
Assert.IsTrue(double.IsNegativeInfinity(truncatedNormal.Minimum));
Assert.IsTrue(double.IsPositiveInfinity(truncatedNormal.Maximum));
Assert.AreEqual(mean, truncatedNormal.Mean);
Assert.AreEqual(stdDev, truncatedNormal.StdDev);
}
/// <summary>
/// Constructor fails with negative stdDev or incorrectly ordered bounds.
/// </summary>
/// <param name="mean"></param>
/// <param name="stdDev"></param>
/// <param name="lower"></param>
/// <param name="upper"></param>
[TestCase(0.0, -1.0,-10d, 10d)]
[TestCase(0.0, 1.0, 10d, 9d)]
public void TruncatedNormalCreateFailsWithBadParameters(double mean, double stdDev, double lower, double upper)
{
Assert.That(() => new TruncatedNormal(mean, stdDev, lower, upper), Throws.ArgumentException);
}
[TestCase(0.0, 1.0, double.NegativeInfinity, double.PositiveInfinity)]
[TestCase(0.0, 1.0, -5.0, 5.0)]
[TestCase(-6.0, 1.0, -5.0, 5.0)]
[TestCase(8.0, 1.0, -5.0, 5.0)]
[TestCase(15, 20.0, -20.0, 0.0)]
public void ValidateMode(double mean, double stdDev, double lower, double upper)
{
double mode;
if(mean < lower) {
mode = lower;
} else if(mean <= upper) {
mode = mean;
} else {
mode = upper;
}
var truncatedNormal = new TruncatedNormal(mean, stdDev, lower, upper);
Assert.AreEqual(mode, truncatedNormal.Mode);
}
/// <summary>
/// Validate cumulative distribution. Uses the same test cases as for the normal distribution
/// as they should be equivalent.
/// </summary>
/// <param name="x">Input X value.</param>
/// <param name="p">Expected value.</param>
[TestCase(Double.NegativeInfinity, 0.0)]
[TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)]
[TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)]
[TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)]
[TestCase(0.0, .0062096653257761351669781045741922211278977469230927036)]
[TestCase(4.0, .30853753872598689636229538939166226011639782444542207)]
[TestCase(5.0, .5)]
[TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
[TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
[TestCase(Double.PositiveInfinity, 1.0)]
public void ValidateCumulativeNoBounds(double x, double p) {
var truncatedNormal = new TruncatedNormal(5.0, 2.0);
AssertHelpers.AlmostEqualRelative(p, truncatedNormal.CumulativeDistribution(x), 14);
}
/// <summary>
/// Validate inverse cumulative distribution. Uses the same test cases as for the normal distribution
/// as they should be equivalent.
/// </summary>
/// <param name="x">Expected value</param>
/// <param name="p">Input quantile</param>
[TestCase(Double.NegativeInfinity, 0.0)]
[TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)]
[TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)]
[TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)]
[TestCase(0.0, 0.0062096653257761351669781045741922211278977469230927036)]
[TestCase(4.0, .30853753872598689636229538939166226011639782444542207)]
[TestCase(5.0, .5)]
[TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
[TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
[TestCase(Double.PositiveInfinity, 1.0)]
public void ValidateInverseCumulativeNoBounds(double x, double p)
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0);
AssertHelpers.AlmostEqualRelative(x, truncatedNormal.InverseCumulativeDistribution(p), 14);
}
/// <summary>
/// Validate density when no bounds are specified. Uses same
/// test cases as the Normal distribution as should be equivalent in this case.
/// </summary>
/// <param name="mean">Mean value.</param>
/// <param name="sdev">Standard deviation value.</param>
[TestCase(10.0, 0.1)]
[TestCase(-5.0, 1.0)]
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void ValidateDensityNoBounds(double mean, double sdev) {
var n = new TruncatedNormal(mean, sdev);
for (var i = 0; i < 11; i++) {
var x = i - 5.0;
var d = (mean - x) / sdev;
var pdf = Math.Exp(-0.5 * d * d) / (sdev * Constants.Sqrt2Pi);
AssertHelpers.AlmostEqualRelative(pdf, n.Density(x), 14);
}
}
/// <summary>
/// Validate density log when no bounds are specified. Uses same
/// test cases as the Normal distribution as should be equivalent in this case.
/// </summary>
/// <param name="mean">Mean value.</param>
/// <param name="sdev">Standard deviation value.</param>
[TestCase(10.0, 0.1)]
[TestCase(-5.0, 1.0)]
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void ValidateDensityLnNoBounds(double mean, double sdev) {
var n = new TruncatedNormal(mean, sdev);
for (var i = 0; i < 11; i++) {
var x = i - 5.0;
var d = (mean - x) / sdev;
var pdfln = (-0.5 * (d * d)) - Math.Log(sdev) - Constants.LogSqrt2Pi;
AssertHelpers.AlmostEqualRelative(pdfln, n.DensityLn(x), 14);
}
}
[Test]
public void CanSample()
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
truncatedNormal.Sample();
}
/// <summary>
/// Can sample sequence.
/// </summary>
[Test]
public void CanSampleSequence() {
var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
var ied = truncatedNormal.Samples();
GC.KeepAlive(ied.Take(5).ToArray());
}
}
}

1
src/UnitTests/UnitTests.csproj

@ -109,6 +109,7 @@
<Compile Include="DistributionTests\Continuous\StableTests.cs" />
<Compile Include="DistributionTests\Continuous\StudentTTests.cs" />
<Compile Include="DistributionTests\Continuous\TriangularTests.cs" />
<Compile Include="DistributionTests\Continuous\TruncatedNormalTests.cs" />
<Compile Include="DistributionTests\Continuous\WeibullTests.cs" />
<Compile Include="DistributionTests\Discrete\BernoulliTests.cs" />
<Compile Include="DistributionTests\Discrete\BinomialTests.cs" />

Loading…
Cancel
Save