diff --git a/src/Numerics/Distributions/TruncatedNormal.cs b/src/Numerics/Distributions/TruncatedNormal.cs
index 72432c23..b8419593 100644
--- a/src/Numerics/Distributions/TruncatedNormal.cs
+++ b/src/Numerics/Distributions/TruncatedNormal.cs
@@ -40,15 +40,31 @@ namespace MathNet.Numerics.Distributions {
/// For more details about this distribution, see
/// Wikipedia - Truncated normal distribution
///
- public class TruncatedNormal : IContinuousDistribution {
+ public class TruncatedNormal : IContinuousDistribution
+ {
System.Random _random;
- readonly double _mean;
- readonly double _stdDev;
+ ///
+ /// Mean of the untruncated normal distribution.
+ ///
+ readonly double _mu;
+ ///
+ /// Standard deviation of the uncorrected normal distribution.
+ ///
+ readonly double _sigma;
readonly double _lowerBound;
readonly double _upperBound;
- readonly Normal _uncorrectedNormal;
+ readonly Normal _standardNormal = new Normal(0.0, 1.0);
+ ///
+ /// Position in the standard normal distribution of the lower bound.
+ ///
+ readonly double _alpha;
+ ///
+ /// Position in the standard normal distribution of the upper bound.
+ ///
+ readonly double _beta;
+
///
/// The total density of the uncorrected normal distribution which is within the lower and upper bounds.
/// Referred to as "Z" in the wikipedia equations. Z = Φ(UpperBound) - Φ(LowerBound).
@@ -61,7 +77,7 @@ namespace MathNet.Numerics.Distributions {
/// normal distribution.
///
/// The mean (μ) of the untruncated distribution.
- /// The standard deviation (σ) of the untruncated distribution. Range: σ ≥ 0.
+ /// The standard deviation (σ) of the untruncated distribution. Range: σ > 0.
/// The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity.
/// The inclusive upper bound of the truncated distribution. Must be larger than .
/// Default is double.PositiveInfinity.
@@ -76,7 +92,7 @@ namespace MathNet.Numerics.Distributions {
/// be initialized with the default random number generator.
///
/// The mean (μ) of the normal distribution.
- /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.
+ /// The standard deviation (σ) of the normal distribution. Range: σ > 0.
/// The random number generator which is used to draw random samples.
public TruncatedNormal(double mean, double stddev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
{
@@ -86,28 +102,30 @@ namespace MathNet.Numerics.Distributions {
}
_random = randomSource ?? SystemRandomSource.Default;
- _mean = mean;
- _stdDev = stddev;
+ _mu = mean;
+ _sigma = stddev;
_lowerBound = lowerBound;
_upperBound = upperBound;
- _uncorrectedNormal = Normal.WithMeanStdDev(_mean, _stdDev);
- _cumulativeDensityWithinBounds = _uncorrectedNormal.CumulativeDistribution(_upperBound) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
+ _alpha = (_lowerBound - _mu) / _sigma;
+ _beta = (_upperBound - _mu) / _sigma;
+
+ _cumulativeDensityWithinBounds = _standardNormal.CumulativeDistribution(_beta) - _standardNormal.CumulativeDistribution(_alpha);
}
///
/// Tests whether the provided values are valid parameters for this distribution.
///
/// The mean (μ) of the normal distribution.
- /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.
+ /// The standard deviation (σ) of the normal distribution. Range: σ > 0.
public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound)
{
- bool normalRequirements = Normal.IsValidParameterSet(mean, stddev);
+ bool normalRequirements = Normal.IsValidParameterSet(mean, stddev) && stddev > 0;
bool boundsAreOrdered = lowerBound < upperBound;
return normalRequirements && boundsAreOrdered;
}
public override string ToString() {
- return "TruncatedNormal(μ = " + _mean + ", σ = " + _stdDev +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
+ return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
}
///
@@ -117,11 +135,11 @@ namespace MathNet.Numerics.Distributions {
{
get
{
- if (_mean < _lowerBound)
+ if (_mu < _lowerBound)
return _lowerBound;
- if (_mean > _upperBound)
+ if (_mu > _upperBound)
return _upperBound;
- return _mean;
+ return _mu;
}
}
@@ -148,9 +166,9 @@ namespace MathNet.Numerics.Distributions {
{
get
{
- var pdfDifference = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
- var diffFromUncorrected = pdfDifference * _stdDev / _cumulativeDensityWithinBounds;
- return _mean + diffFromUncorrected;
+ var pdfDifference = _standardNormal.Density(_alpha) - _standardNormal.Density(_beta);
+ var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds;
+ return _mu + diffFromUncorrected;
}
}
@@ -161,24 +179,31 @@ namespace MathNet.Numerics.Distributions {
{
get
{
+ //Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access...
+
//TODO might need special handling for cases where either or both bounds are infinity
+ var densityAtLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _standardNormal.Density(_alpha);
+ var densityAtUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _standardNormal.Density(_beta);
+
+ var standardisedLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _alpha;
+ var standardisedUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _beta;
//Second term
- var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
+ var secondNumerator = standardisedLower * densityAtLower - standardisedUpper * densityAtUpper;
var secordTerm = secondNumerator / _cumulativeDensityWithinBounds;
//Third term
- var thirdNumerator = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound);
+ var thirdNumerator = densityAtLower - densityAtUpper;
var thirdTerm = (thirdNumerator / _cumulativeDensityWithinBounds) * (thirdNumerator / _cumulativeDensityWithinBounds);
var sumOfTerms = 1 + secordTerm + thirdTerm;
- return _stdDev * _stdDev * sumOfTerms;
+ return _sigma * _sigma * sumOfTerms;
}
}
///
- /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ ≥ 0.
+ /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0.
///
public double StdDev
{
@@ -192,9 +217,9 @@ namespace MathNet.Numerics.Distributions {
{
get
{
- var firstTerm = Constants.LogSqrt2PiE + Math.Log(_stdDev + _cumulativeDensityWithinBounds);
+ var firstTerm = Constants.LogSqrt2PiE + Math.Log(_sigma + _cumulativeDensityWithinBounds);
- var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound);
+ var secondNumerator = _lowerBound * _standardNormal.Density(_alpha) - _upperBound * _standardNormal.Density(_beta);
var secondTerm = secondNumerator / (2 * _cumulativeDensityWithinBounds);
return firstTerm + secondTerm;
@@ -240,7 +265,7 @@ namespace MathNet.Numerics.Distributions {
if (x < _lowerBound || _upperBound < x)
return 0d;
- return _uncorrectedNormal.Density(x) / (_stdDev * _cumulativeDensityWithinBounds);
+ return _standardNormal.Density((x - _mu) / _sigma) / (_sigma * _cumulativeDensityWithinBounds);
}
///
@@ -251,11 +276,11 @@ namespace MathNet.Numerics.Distributions {
///
public double DensityLn(double x)
{
- return Math.Log(Density(x));
- }
+ return _standardNormal.DensityLn((x - _mu) / _sigma) - Math.Log(_sigma) - Math.Log(_cumulativeDensityWithinBounds);
+ }
//TODO: implement sampling, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/
- // see implmentations listed on that page for examples.
+ // see implementations listed on that page for examples.
public double Sample()
{
@@ -285,7 +310,7 @@ namespace MathNet.Numerics.Distributions {
if (x > _upperBound)
return 1d;
- double cumulative = _uncorrectedNormal.CumulativeDistribution(x) - _uncorrectedNormal.CumulativeDistribution(_lowerBound);
+ double cumulative = _standardNormal.CumulativeDistribution((x - _mu) / _sigma) - _standardNormal.CumulativeDistribution(_alpha);
return cumulative / _cumulativeDensityWithinBounds;
}
@@ -299,9 +324,9 @@ namespace MathNet.Numerics.Distributions {
public double InverseCumulativeDistribution(double p)
{
//TODO check that this is correct with someone.
- var pUntruncated = p * _cumulativeDensityWithinBounds + _uncorrectedNormal.CumulativeDistribution(_lowerBound);
+ var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal.CumulativeDistribution(_alpha);
- return _uncorrectedNormal.InverseCumulativeDistribution(pUntruncated);
+ return _standardNormal.InverseCumulativeDistribution(pUntruncated) * _sigma + _mu;
}
}
diff --git a/src/UnitTests/DistributionTests/CommonDistributionTests.cs b/src/UnitTests/DistributionTests/CommonDistributionTests.cs
index 601c8b60..e0991dd6 100644
--- a/src/UnitTests/DistributionTests/CommonDistributionTests.cs
+++ b/src/UnitTests/DistributionTests/CommonDistributionTests.cs
@@ -89,6 +89,8 @@ namespace MathNet.Numerics.UnitTests.DistributionTests
new StudentT(0.0, 1.0, 5.0),
new Triangular(0, 1, 0.7),
new Weibull(1.0, 1.0),
+ new TruncatedNormal(0, 1.0, -5.0, 5.0), //Finite
+ new TruncatedNormal(0, 1.0, -5.0), //Semi-finite
};
[Test]
diff --git a/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs b/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs
new file mode 100644
index 00000000..7605c165
--- /dev/null
+++ b/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs
@@ -0,0 +1,198 @@
+//
+// Math.NET Numerics, part of the Math.NET Project
+// http://numerics.mathdotnet.com
+// http://github.com/mathnet/mathnet-numerics
+// http://mathnetnumerics.codeplex.com
+//
+// Copyright (c) 2009-2014 Math.NET
+//
+// Permission is hereby granted, free of charge, to any person
+// obtaining a copy of this software and associated documentation
+// files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use,
+// copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following
+// conditions:
+//
+// The above copyright notice and this permission notice shall be
+// included in all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+// OTHER DEALINGS IN THE SOFTWARE.
+//
+
+using MathNet.Numerics.Distributions;
+using NUnit.Framework;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
+{
+ [TestFixture, Category("Distributions")]
+ public class TruncatedNormalTests
+ {
+
+ ///
+ /// Can create a truncated normal without bounds.
+ ///
+ [TestCase(0.0, 0.0)]
+ [TestCase(10.0, 0.1)]
+ [TestCase(-5.0, 1.0)]
+ [TestCase(0.0, 10.0)]
+ [TestCase(10.0, 100.0)]
+ [TestCase(-5.0, Double.PositiveInfinity)]
+ public void CanCreateWithoutBounds(double mean, double stdDev)
+ {
+ var truncatedNormal = new TruncatedNormal(mean, stdDev);
+ Assert.IsTrue(double.IsNegativeInfinity(truncatedNormal.Minimum));
+ Assert.IsTrue(double.IsPositiveInfinity(truncatedNormal.Maximum));
+ Assert.AreEqual(mean, truncatedNormal.Mean);
+ Assert.AreEqual(stdDev, truncatedNormal.StdDev);
+ }
+
+ ///
+ /// Constructor fails with negative stdDev or incorrectly ordered bounds.
+ ///
+ ///
+ ///
+ ///
+ ///
+ [TestCase(0.0, -1.0,-10d, 10d)]
+ [TestCase(0.0, 1.0, 10d, 9d)]
+ public void TruncatedNormalCreateFailsWithBadParameters(double mean, double stdDev, double lower, double upper)
+ {
+ Assert.That(() => new TruncatedNormal(mean, stdDev, lower, upper), Throws.ArgumentException);
+ }
+
+ [TestCase(0.0, 1.0, double.NegativeInfinity, double.PositiveInfinity)]
+ [TestCase(0.0, 1.0, -5.0, 5.0)]
+ [TestCase(-6.0, 1.0, -5.0, 5.0)]
+ [TestCase(8.0, 1.0, -5.0, 5.0)]
+ [TestCase(15, 20.0, -20.0, 0.0)]
+ public void ValidateMode(double mean, double stdDev, double lower, double upper)
+ {
+ double mode;
+ if(mean < lower) {
+ mode = lower;
+ } else if(mean <= upper) {
+ mode = mean;
+ } else {
+ mode = upper;
+ }
+ var truncatedNormal = new TruncatedNormal(mean, stdDev, lower, upper);
+ Assert.AreEqual(mode, truncatedNormal.Mode);
+ }
+
+ ///
+ /// Validate cumulative distribution. Uses the same test cases as for the normal distribution
+ /// as they should be equivalent.
+ ///
+ /// Input X value.
+ /// Expected value.
+ [TestCase(Double.NegativeInfinity, 0.0)]
+ [TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)]
+ [TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)]
+ [TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)]
+ [TestCase(0.0, .0062096653257761351669781045741922211278977469230927036)]
+ [TestCase(4.0, .30853753872598689636229538939166226011639782444542207)]
+ [TestCase(5.0, .5)]
+ [TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
+ [TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
+ [TestCase(Double.PositiveInfinity, 1.0)]
+ public void ValidateCumulativeNoBounds(double x, double p) {
+ var truncatedNormal = new TruncatedNormal(5.0, 2.0);
+ AssertHelpers.AlmostEqualRelative(p, truncatedNormal.CumulativeDistribution(x), 14);
+ }
+
+ ///
+ /// Validate inverse cumulative distribution. Uses the same test cases as for the normal distribution
+ /// as they should be equivalent.
+ ///
+ /// Expected value
+ /// Input quantile
+ [TestCase(Double.NegativeInfinity, 0.0)]
+ [TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)]
+ [TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)]
+ [TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)]
+ [TestCase(0.0, 0.0062096653257761351669781045741922211278977469230927036)]
+ [TestCase(4.0, .30853753872598689636229538939166226011639782444542207)]
+ [TestCase(5.0, .5)]
+ [TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
+ [TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
+ [TestCase(Double.PositiveInfinity, 1.0)]
+ public void ValidateInverseCumulativeNoBounds(double x, double p)
+ {
+ var truncatedNormal = new TruncatedNormal(5.0, 2.0);
+ AssertHelpers.AlmostEqualRelative(x, truncatedNormal.InverseCumulativeDistribution(p), 14);
+ }
+
+ ///
+ /// Validate density when no bounds are specified. Uses same
+ /// test cases as the Normal distribution as should be equivalent in this case.
+ ///
+ /// Mean value.
+ /// Standard deviation value.
+ [TestCase(10.0, 0.1)]
+ [TestCase(-5.0, 1.0)]
+ [TestCase(0.0, 10.0)]
+ [TestCase(10.0, 100.0)]
+ [TestCase(-5.0, Double.PositiveInfinity)]
+ public void ValidateDensityNoBounds(double mean, double sdev) {
+ var n = new TruncatedNormal(mean, sdev);
+ for (var i = 0; i < 11; i++) {
+ var x = i - 5.0;
+ var d = (mean - x) / sdev;
+ var pdf = Math.Exp(-0.5 * d * d) / (sdev * Constants.Sqrt2Pi);
+ AssertHelpers.AlmostEqualRelative(pdf, n.Density(x), 14);
+ }
+ }
+
+ ///
+ /// Validate density log when no bounds are specified. Uses same
+ /// test cases as the Normal distribution as should be equivalent in this case.
+ ///
+ /// Mean value.
+ /// Standard deviation value.
+ [TestCase(10.0, 0.1)]
+ [TestCase(-5.0, 1.0)]
+ [TestCase(0.0, 10.0)]
+ [TestCase(10.0, 100.0)]
+ [TestCase(-5.0, Double.PositiveInfinity)]
+ public void ValidateDensityLnNoBounds(double mean, double sdev) {
+ var n = new TruncatedNormal(mean, sdev);
+ for (var i = 0; i < 11; i++) {
+ var x = i - 5.0;
+ var d = (mean - x) / sdev;
+ var pdfln = (-0.5 * (d * d)) - Math.Log(sdev) - Constants.LogSqrt2Pi;
+ AssertHelpers.AlmostEqualRelative(pdfln, n.DensityLn(x), 14);
+ }
+ }
+
+ [Test]
+ public void CanSample()
+ {
+ var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
+ truncatedNormal.Sample();
+ }
+
+
+ ///
+ /// Can sample sequence.
+ ///
+ [Test]
+ public void CanSampleSequence() {
+ var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
+ var ied = truncatedNormal.Samples();
+ GC.KeepAlive(ied.Take(5).ToArray());
+ }
+ }
+}
diff --git a/src/UnitTests/UnitTests.csproj b/src/UnitTests/UnitTests.csproj
index 91b8405d..d1c1dddd 100644
--- a/src/UnitTests/UnitTests.csproj
+++ b/src/UnitTests/UnitTests.csproj
@@ -109,6 +109,7 @@
+