diff --git a/src/Numerics/Distributions/TruncatedNormal.cs b/src/Numerics/Distributions/TruncatedNormal.cs index 72432c23..b8419593 100644 --- a/src/Numerics/Distributions/TruncatedNormal.cs +++ b/src/Numerics/Distributions/TruncatedNormal.cs @@ -40,15 +40,31 @@ namespace MathNet.Numerics.Distributions { /// For more details about this distribution, see /// Wikipedia - Truncated normal distribution /// - public class TruncatedNormal : IContinuousDistribution { + public class TruncatedNormal : IContinuousDistribution + { System.Random _random; - readonly double _mean; - readonly double _stdDev; + /// + /// Mean of the untruncated normal distribution. + /// + readonly double _mu; + /// + /// Standard deviation of the uncorrected normal distribution. + /// + readonly double _sigma; readonly double _lowerBound; readonly double _upperBound; - readonly Normal _uncorrectedNormal; + readonly Normal _standardNormal = new Normal(0.0, 1.0); + /// + /// Position in the standard normal distribution of the lower bound. + /// + readonly double _alpha; + /// + /// Position in the standard normal distribution of the upper bound. + /// + readonly double _beta; + /// /// The total density of the uncorrected normal distribution which is within the lower and upper bounds. /// Referred to as "Z" in the wikipedia equations. Z = Φ(UpperBound) - Φ(LowerBound). @@ -61,7 +77,7 @@ namespace MathNet.Numerics.Distributions { /// normal distribution. /// /// The mean (μ) of the untruncated distribution. - /// The standard deviation (σ) of the untruncated distribution. Range: σ ≥ 0. + /// The standard deviation (σ) of the untruncated distribution. Range: σ > 0. /// The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity. /// The inclusive upper bound of the truncated distribution. Must be larger than . /// Default is double.PositiveInfinity. @@ -76,7 +92,7 @@ namespace MathNet.Numerics.Distributions { /// be initialized with the default random number generator. /// /// The mean (μ) of the normal distribution. - /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// The standard deviation (σ) of the normal distribution. Range: σ > 0. /// The random number generator which is used to draw random samples. public TruncatedNormal(double mean, double stddev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity) { @@ -86,28 +102,30 @@ namespace MathNet.Numerics.Distributions { } _random = randomSource ?? SystemRandomSource.Default; - _mean = mean; - _stdDev = stddev; + _mu = mean; + _sigma = stddev; _lowerBound = lowerBound; _upperBound = upperBound; - _uncorrectedNormal = Normal.WithMeanStdDev(_mean, _stdDev); - _cumulativeDensityWithinBounds = _uncorrectedNormal.CumulativeDistribution(_upperBound) - _uncorrectedNormal.CumulativeDistribution(_lowerBound); + _alpha = (_lowerBound - _mu) / _sigma; + _beta = (_upperBound - _mu) / _sigma; + + _cumulativeDensityWithinBounds = _standardNormal.CumulativeDistribution(_beta) - _standardNormal.CumulativeDistribution(_alpha); } /// /// Tests whether the provided values are valid parameters for this distribution. /// /// The mean (μ) of the normal distribution. - /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// The standard deviation (σ) of the normal distribution. Range: σ > 0. public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound) { - bool normalRequirements = Normal.IsValidParameterSet(mean, stddev); + bool normalRequirements = Normal.IsValidParameterSet(mean, stddev) && stddev > 0; bool boundsAreOrdered = lowerBound < upperBound; return normalRequirements && boundsAreOrdered; } public override string ToString() { - return "TruncatedNormal(μ = " + _mean + ", σ = " + _stdDev +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")"; + return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")"; } /// @@ -117,11 +135,11 @@ namespace MathNet.Numerics.Distributions { { get { - if (_mean < _lowerBound) + if (_mu < _lowerBound) return _lowerBound; - if (_mean > _upperBound) + if (_mu > _upperBound) return _upperBound; - return _mean; + return _mu; } } @@ -148,9 +166,9 @@ namespace MathNet.Numerics.Distributions { { get { - var pdfDifference = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound); - var diffFromUncorrected = pdfDifference * _stdDev / _cumulativeDensityWithinBounds; - return _mean + diffFromUncorrected; + var pdfDifference = _standardNormal.Density(_alpha) - _standardNormal.Density(_beta); + var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds; + return _mu + diffFromUncorrected; } } @@ -161,24 +179,31 @@ namespace MathNet.Numerics.Distributions { { get { + //Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access... + //TODO might need special handling for cases where either or both bounds are infinity + var densityAtLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _standardNormal.Density(_alpha); + var densityAtUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _standardNormal.Density(_beta); + + var standardisedLower = double.IsNegativeInfinity(_lowerBound) ? 0.0 : _alpha; + var standardisedUpper = double.IsPositiveInfinity(_upperBound) ? 0.0 : _beta; //Second term - var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound); + var secondNumerator = standardisedLower * densityAtLower - standardisedUpper * densityAtUpper; var secordTerm = secondNumerator / _cumulativeDensityWithinBounds; //Third term - var thirdNumerator = _uncorrectedNormal.Density(_lowerBound) - _uncorrectedNormal.Density(_upperBound); + var thirdNumerator = densityAtLower - densityAtUpper; var thirdTerm = (thirdNumerator / _cumulativeDensityWithinBounds) * (thirdNumerator / _cumulativeDensityWithinBounds); var sumOfTerms = 1 + secordTerm + thirdTerm; - return _stdDev * _stdDev * sumOfTerms; + return _sigma * _sigma * sumOfTerms; } } /// - /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ ≥ 0. + /// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0. /// public double StdDev { @@ -192,9 +217,9 @@ namespace MathNet.Numerics.Distributions { { get { - var firstTerm = Constants.LogSqrt2PiE + Math.Log(_stdDev + _cumulativeDensityWithinBounds); + var firstTerm = Constants.LogSqrt2PiE + Math.Log(_sigma + _cumulativeDensityWithinBounds); - var secondNumerator = _lowerBound * _uncorrectedNormal.Density(_lowerBound) - _upperBound * _uncorrectedNormal.Density(_upperBound); + var secondNumerator = _lowerBound * _standardNormal.Density(_alpha) - _upperBound * _standardNormal.Density(_beta); var secondTerm = secondNumerator / (2 * _cumulativeDensityWithinBounds); return firstTerm + secondTerm; @@ -240,7 +265,7 @@ namespace MathNet.Numerics.Distributions { if (x < _lowerBound || _upperBound < x) return 0d; - return _uncorrectedNormal.Density(x) / (_stdDev * _cumulativeDensityWithinBounds); + return _standardNormal.Density((x - _mu) / _sigma) / (_sigma * _cumulativeDensityWithinBounds); } /// @@ -251,11 +276,11 @@ namespace MathNet.Numerics.Distributions { /// public double DensityLn(double x) { - return Math.Log(Density(x)); - } + return _standardNormal.DensityLn((x - _mu) / _sigma) - Math.Log(_sigma) - Math.Log(_cumulativeDensityWithinBounds); + } //TODO: implement sampling, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/ - // see implmentations listed on that page for examples. + // see implementations listed on that page for examples. public double Sample() { @@ -285,7 +310,7 @@ namespace MathNet.Numerics.Distributions { if (x > _upperBound) return 1d; - double cumulative = _uncorrectedNormal.CumulativeDistribution(x) - _uncorrectedNormal.CumulativeDistribution(_lowerBound); + double cumulative = _standardNormal.CumulativeDistribution((x - _mu) / _sigma) - _standardNormal.CumulativeDistribution(_alpha); return cumulative / _cumulativeDensityWithinBounds; } @@ -299,9 +324,9 @@ namespace MathNet.Numerics.Distributions { public double InverseCumulativeDistribution(double p) { //TODO check that this is correct with someone. - var pUntruncated = p * _cumulativeDensityWithinBounds + _uncorrectedNormal.CumulativeDistribution(_lowerBound); + var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal.CumulativeDistribution(_alpha); - return _uncorrectedNormal.InverseCumulativeDistribution(pUntruncated); + return _standardNormal.InverseCumulativeDistribution(pUntruncated) * _sigma + _mu; } } diff --git a/src/UnitTests/DistributionTests/CommonDistributionTests.cs b/src/UnitTests/DistributionTests/CommonDistributionTests.cs index 601c8b60..e0991dd6 100644 --- a/src/UnitTests/DistributionTests/CommonDistributionTests.cs +++ b/src/UnitTests/DistributionTests/CommonDistributionTests.cs @@ -89,6 +89,8 @@ namespace MathNet.Numerics.UnitTests.DistributionTests new StudentT(0.0, 1.0, 5.0), new Triangular(0, 1, 0.7), new Weibull(1.0, 1.0), + new TruncatedNormal(0, 1.0, -5.0, 5.0), //Finite + new TruncatedNormal(0, 1.0, -5.0), //Semi-finite }; [Test] diff --git a/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs b/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs new file mode 100644 index 00000000..7605c165 --- /dev/null +++ b/src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs @@ -0,0 +1,198 @@ +// +// Math.NET Numerics, part of the Math.NET Project +// http://numerics.mathdotnet.com +// http://github.com/mathnet/mathnet-numerics +// http://mathnetnumerics.codeplex.com +// +// Copyright (c) 2009-2014 Math.NET +// +// Permission is hereby granted, free of charge, to any person +// obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, +// copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +// OTHER DEALINGS IN THE SOFTWARE. +// + +using MathNet.Numerics.Distributions; +using NUnit.Framework; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous +{ + [TestFixture, Category("Distributions")] + public class TruncatedNormalTests + { + + /// + /// Can create a truncated normal without bounds. + /// + [TestCase(0.0, 0.0)] + [TestCase(10.0, 0.1)] + [TestCase(-5.0, 1.0)] + [TestCase(0.0, 10.0)] + [TestCase(10.0, 100.0)] + [TestCase(-5.0, Double.PositiveInfinity)] + public void CanCreateWithoutBounds(double mean, double stdDev) + { + var truncatedNormal = new TruncatedNormal(mean, stdDev); + Assert.IsTrue(double.IsNegativeInfinity(truncatedNormal.Minimum)); + Assert.IsTrue(double.IsPositiveInfinity(truncatedNormal.Maximum)); + Assert.AreEqual(mean, truncatedNormal.Mean); + Assert.AreEqual(stdDev, truncatedNormal.StdDev); + } + + /// + /// Constructor fails with negative stdDev or incorrectly ordered bounds. + /// + /// + /// + /// + /// + [TestCase(0.0, -1.0,-10d, 10d)] + [TestCase(0.0, 1.0, 10d, 9d)] + public void TruncatedNormalCreateFailsWithBadParameters(double mean, double stdDev, double lower, double upper) + { + Assert.That(() => new TruncatedNormal(mean, stdDev, lower, upper), Throws.ArgumentException); + } + + [TestCase(0.0, 1.0, double.NegativeInfinity, double.PositiveInfinity)] + [TestCase(0.0, 1.0, -5.0, 5.0)] + [TestCase(-6.0, 1.0, -5.0, 5.0)] + [TestCase(8.0, 1.0, -5.0, 5.0)] + [TestCase(15, 20.0, -20.0, 0.0)] + public void ValidateMode(double mean, double stdDev, double lower, double upper) + { + double mode; + if(mean < lower) { + mode = lower; + } else if(mean <= upper) { + mode = mean; + } else { + mode = upper; + } + var truncatedNormal = new TruncatedNormal(mean, stdDev, lower, upper); + Assert.AreEqual(mode, truncatedNormal.Mode); + } + + /// + /// Validate cumulative distribution. Uses the same test cases as for the normal distribution + /// as they should be equivalent. + /// + /// Input X value. + /// Expected value. + [TestCase(Double.NegativeInfinity, 0.0)] + [TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)] + [TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)] + [TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)] + [TestCase(0.0, .0062096653257761351669781045741922211278977469230927036)] + [TestCase(4.0, .30853753872598689636229538939166226011639782444542207)] + [TestCase(5.0, .5)] + [TestCase(6.0, .69146246127401310363770461060833773988360217555457859)] + [TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)] + [TestCase(Double.PositiveInfinity, 1.0)] + public void ValidateCumulativeNoBounds(double x, double p) { + var truncatedNormal = new TruncatedNormal(5.0, 2.0); + AssertHelpers.AlmostEqualRelative(p, truncatedNormal.CumulativeDistribution(x), 14); + } + + /// + /// Validate inverse cumulative distribution. Uses the same test cases as for the normal distribution + /// as they should be equivalent. + /// + /// Expected value + /// Input quantile + [TestCase(Double.NegativeInfinity, 0.0)] + [TestCase(-5.0, 0.00000028665157187919391167375233287464535385442301361187883)] + [TestCase(-2.0, 0.0002326290790355250363499258867279847735487493358890356)] + [TestCase(-0.0, 0.0062096653257761351669781045741922211278977469230927036)] + [TestCase(0.0, 0.0062096653257761351669781045741922211278977469230927036)] + [TestCase(4.0, .30853753872598689636229538939166226011639782444542207)] + [TestCase(5.0, .5)] + [TestCase(6.0, .69146246127401310363770461060833773988360217555457859)] + [TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)] + [TestCase(Double.PositiveInfinity, 1.0)] + public void ValidateInverseCumulativeNoBounds(double x, double p) + { + var truncatedNormal = new TruncatedNormal(5.0, 2.0); + AssertHelpers.AlmostEqualRelative(x, truncatedNormal.InverseCumulativeDistribution(p), 14); + } + + /// + /// Validate density when no bounds are specified. Uses same + /// test cases as the Normal distribution as should be equivalent in this case. + /// + /// Mean value. + /// Standard deviation value. + [TestCase(10.0, 0.1)] + [TestCase(-5.0, 1.0)] + [TestCase(0.0, 10.0)] + [TestCase(10.0, 100.0)] + [TestCase(-5.0, Double.PositiveInfinity)] + public void ValidateDensityNoBounds(double mean, double sdev) { + var n = new TruncatedNormal(mean, sdev); + for (var i = 0; i < 11; i++) { + var x = i - 5.0; + var d = (mean - x) / sdev; + var pdf = Math.Exp(-0.5 * d * d) / (sdev * Constants.Sqrt2Pi); + AssertHelpers.AlmostEqualRelative(pdf, n.Density(x), 14); + } + } + + /// + /// Validate density log when no bounds are specified. Uses same + /// test cases as the Normal distribution as should be equivalent in this case. + /// + /// Mean value. + /// Standard deviation value. + [TestCase(10.0, 0.1)] + [TestCase(-5.0, 1.0)] + [TestCase(0.0, 10.0)] + [TestCase(10.0, 100.0)] + [TestCase(-5.0, Double.PositiveInfinity)] + public void ValidateDensityLnNoBounds(double mean, double sdev) { + var n = new TruncatedNormal(mean, sdev); + for (var i = 0; i < 11; i++) { + var x = i - 5.0; + var d = (mean - x) / sdev; + var pdfln = (-0.5 * (d * d)) - Math.Log(sdev) - Constants.LogSqrt2Pi; + AssertHelpers.AlmostEqualRelative(pdfln, n.DensityLn(x), 14); + } + } + + [Test] + public void CanSample() + { + var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0); + truncatedNormal.Sample(); + } + + + /// + /// Can sample sequence. + /// + [Test] + public void CanSampleSequence() { + var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0); + var ied = truncatedNormal.Samples(); + GC.KeepAlive(ied.Take(5).ToArray()); + } + } +} diff --git a/src/UnitTests/UnitTests.csproj b/src/UnitTests/UnitTests.csproj index 91b8405d..d1c1dddd 100644 --- a/src/UnitTests/UnitTests.csproj +++ b/src/UnitTests/UnitTests.csproj @@ -109,6 +109,7 @@ +