Browse Source

Statistics: TruncatedNormal minor formatting tweaks #344

truncatednormal
Christoph Ruegg 11 years ago
parent
commit
7e0b98c625
  1. 87
      src/Numerics/Distributions/TruncatedNormal.cs
  2. 4
      src/UnitTests/DistributionTests/CommonDistributionTests.cs
  3. 51
      src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs

87
src/Numerics/Distributions/TruncatedNormal.cs

@ -4,7 +4,7 @@
// http://github.com/mathnet/mathnet-numerics
// http://mathnetnumerics.codeplex.com
//
// Copyright (c) 2009-2013 Math.NET
// Copyright (c) 2009-2015 Math.NET
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
@ -42,24 +42,27 @@ namespace MathNet.Numerics.Distributions {
/// </summary>
public class TruncatedNormal : IContinuousDistribution
{
System.Random _random;
/// <summary>
/// Mean of the untruncated normal distribution.
/// </summary>
readonly double _mu;
readonly double _mu;
/// <summary>
/// Standard deviation of the uncorrected normal distribution.
/// </summary>
readonly double _sigma;
readonly double _lowerBound;
readonly double _upperBound;
readonly Normal _standardNormal = new Normal(0.0, 1.0);
/// <summary>
/// Position in the standard normal distribution of the lower bound.
/// </summary>
readonly double _alpha;
/// <summary>
/// Position in the standard normal distribution of the upper bound.
/// </summary>
@ -73,7 +76,7 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Initializes a new instance of the TruncatedNormal class. The distribution will
/// be initialized with the default <seealso cref="System.Random"/> random number generator. The mean
/// be initialized with the default <seealso cref="System.Random"/> random number generator. The mean
/// and standard deviation are that of the untruncated normal distribution.
/// </summary>
/// <param name="mean">The mean (μ) of the untruncated distribution.</param>
@ -81,10 +84,9 @@ namespace MathNet.Numerics.Distributions {
/// <param name="lowerBound">The inclusive lower bound of the truncated distribution. Default is double.NegativeInfinity.</param>
/// <param name="upperBound">The inclusive upper bound of the truncated distribution. Must be larger than <paramref name="lowerBound"/>.
/// Default is double.PositiveInfinity.</param>
public TruncatedNormal(double mean, double stddev, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
:this(mean, stddev, SystemRandomSource.Default, lowerBound, upperBound)
public TruncatedNormal(double mean, double stddev, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
: this(mean, stddev, SystemRandomSource.Default, lowerBound, upperBound)
{
}
/// <summary>
@ -98,9 +100,9 @@ namespace MathNet.Numerics.Distributions {
/// <param name="upperBound">The inclusive upper bound of the truncated distribution. Must be larger than <paramref name="lowerBound"/>.
/// Default is double.PositiveInfinity.</param>
public TruncatedNormal(double untruncatedMean, double untruncatedStdDev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
public TruncatedNormal(double untruncatedMean, double untruncatedStdDev, System.Random randomSource, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
{
if (!IsValidParameterSet(untruncatedMean, untruncatedStdDev, lowerBound, upperBound))
if (!IsValidParameterSet(untruncatedMean, untruncatedStdDev, lowerBound, upperBound))
{
throw new ArgumentException(Resources.InvalidDistributionParameters);
}
@ -121,23 +123,24 @@ namespace MathNet.Numerics.Distributions {
/// </summary>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ > 0.</param>
public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound)
public static bool IsValidParameterSet(double mean, double stddev, double lowerBound, double upperBound)
{
bool normalRequirements = Normal.IsValidParameterSet(mean, stddev) && stddev > 0;
bool boundsAreOrdered = lowerBound < upperBound;
return normalRequirements && boundsAreOrdered;
}
public override string ToString() {
public override string ToString()
{
return "TruncatedNormal(μ = " + _mu + ", σ = " + _sigma +", LowerBound = " + _lowerBound + ", UpperBound = " + _upperBound + ")";
}
/// <summary>
/// Gets the mode of the normal distribution.
/// </summary>
public double Mode
/// <summary>
/// Gets the mode of the truncated normal distribution.
/// </summary>
public double Mode
{
get
get
{
if (_mu < _lowerBound)
return _lowerBound;
@ -150,7 +153,7 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the minimum of the truncated normal distribution.
/// </summary>
public double Minimum
public double Minimum
{
get { return _lowerBound; }
}
@ -158,7 +161,7 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the maximum of the truncated normal distribution.
/// </summary>
public double Maximum
public double Maximum
{
get { return _upperBound; }
}
@ -166,9 +169,9 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the mean (μ) of the truncated normal distribution.
/// </summary>
public double Mean
public double Mean
{
get
get
{
var pdfDifference = _standardNormal.Density(_alpha) - _standardNormal.Density(_beta);
var diffFromUncorrected = pdfDifference * _sigma / _cumulativeDensityWithinBounds;
@ -177,11 +180,11 @@ namespace MathNet.Numerics.Distributions {
}
/// <summary>
/// Gets the variance of the truncated normal distribution.
/// Gets the variance of the truncated normal distribution.
/// </summary>
public double Variance
public double Variance
{
get
get
{
//Apparently "Barr and Sherrill (1999)" has a simpler expression for one sided truncations, if anyone has access...
@ -209,7 +212,7 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the standard deviation (σ) of the truncated normal distribution. Range: σ > 0.
/// </summary>
public double StdDev
public double StdDev
{
get { return Math.Sqrt(Variance); }
}
@ -217,9 +220,9 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the entropy of the truncated normal distribution.
/// </summary>
public double Entropy
public double Entropy
{
get
get
{
var firstTerm = Constants.LogSqrt2PiE + Math.Log(_sigma + _cumulativeDensityWithinBounds);
@ -232,7 +235,7 @@ namespace MathNet.Numerics.Distributions {
public double Skewness
{
get
get
{
throw new NotImplementedException();
}
@ -241,9 +244,9 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets the median of the truncated distribution.
/// </summary>
public double Median
public double Median
{
get
get
{
return InverseCumulativeDistribution(0.5);
}
@ -252,7 +255,7 @@ namespace MathNet.Numerics.Distributions {
/// <summary>
/// Gets or sets the random number generator which is used to draw random samples.
/// </summary>
public System.Random RandomSource
public System.Random RandomSource
{
get { return _random; }
set { _random = value ?? SystemRandomSource.Default; }
@ -264,7 +267,7 @@ namespace MathNet.Numerics.Distributions {
/// <param name="x">The location at which to compute the density.</param>
/// <returns>the density at <paramref name="x"/>.</returns>
/// <seealso cref="PDF"/>
public double Density(double x)
public double Density(double x)
{
if (x < _lowerBound || _upperBound < x)
return 0d;
@ -278,29 +281,30 @@ namespace MathNet.Numerics.Distributions {
/// <param name="x">The location at which to compute the log density.</param>
/// <returns>the log density at <paramref name="x"/>.</returns>
/// <seealso cref="PDFLn"/>
public double DensityLn(double x)
public double DensityLn(double x)
{
return _standardNormal.DensityLn((x - _mu) / _sigma) - Math.Log(_sigma) - Math.Log(_cumulativeDensityWithinBounds);
}
}
public double Sample()
public double Sample()
{
//TODO: implement sampling more efficiently/accurately, use method described by Mazet here: http://miv.u-strasbg.fr/mazet/rtnorm/
// see implementations listed on that page for examples.
return InverseCumulativeDistribution(RandomSource.NextDouble());
}
public void Samples(double[] values)
public void Samples(double[] values)
{
for(int i = 0; i < values.Length; i++) {
for(int i = 0; i < values.Length; i++)
{
values[i] = Sample();
}
}
public IEnumerable<double> Samples()
public IEnumerable<double> Samples()
{
while (true) {
while (true)
{
yield return Sample();
}
}
@ -311,7 +315,7 @@ namespace MathNet.Numerics.Distributions {
/// <param name="x">The location at which to compute the cumulative distribution function.</param>
/// <returns>the cumulative distribution at location <paramref name="x"/>.</returns>
/// <seealso cref="CDF"/>
public double CumulativeDistribution(double x)
public double CumulativeDistribution(double x)
{
if (x < _lowerBound)
return 0d;
@ -329,13 +333,12 @@ namespace MathNet.Numerics.Distributions {
/// <param name="p">The location at which to compute the inverse cumulative density.</param>
/// <returns>the inverse cumulative density at <paramref name="p"/>.</returns>
/// <seealso cref="InvCDF"/>
public double InverseCumulativeDistribution(double p)
public double InverseCumulativeDistribution(double p)
{
//TODO check that this is correct with someone.
var pUntruncated = p * _cumulativeDensityWithinBounds + _standardNormal.CumulativeDistribution(_alpha);
return _standardNormal.InverseCumulativeDistribution(pUntruncated) * _sigma + _mu;
}
}
}

4
src/UnitTests/DistributionTests/CommonDistributionTests.cs

@ -89,8 +89,8 @@ namespace MathNet.Numerics.UnitTests.DistributionTests
new StudentT(0.0, 1.0, 5.0),
new Triangular(0, 1, 0.7),
new Weibull(1.0, 1.0),
new TruncatedNormal(0, 1.0, -1.0, 1.5), //Finite
new TruncatedNormal(0, 1.0, -0.5), //Semi-finite
new TruncatedNormal(0, 1.0, -1.0, 1.5), //Finite
new TruncatedNormal(0, 1.0, -0.5), //Semi-finite
};
[Test]

51
src/UnitTests/DistributionTests/Continuous/TruncatedNormalTests.cs

@ -4,7 +4,7 @@
// http://github.com/mathnet/mathnet-numerics
// http://mathnetnumerics.codeplex.com
//
// Copyright (c) 2009-2014 Math.NET
// Copyright (c) 2009-2015 Math.NET
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
@ -31,16 +31,13 @@
using MathNet.Numerics.Distributions;
using NUnit.Framework;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
{
[TestFixture, Category("Distributions")]
public class TruncatedNormalTests
public class TruncatedNormalTests
{
/// <summary>
/// Can create a truncated normal without bounds.
/// </summary>
@ -50,7 +47,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void CanCreateWithoutBounds(double mean, double stdDev)
public void CanCreateWithoutBounds(double mean, double stdDev)
{
var truncatedNormal = new TruncatedNormal(mean, stdDev);
Assert.IsTrue(double.IsNegativeInfinity(truncatedNormal.Minimum));
@ -68,7 +65,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
/// <param name="upper"></param>
[TestCase(0.0, -1.0,-10d, 10d)]
[TestCase(0.0, 1.0, 10d, 9d)]
public void TruncatedNormalCreateFailsWithBadParameters(double mean, double stdDev, double lower, double upper)
public void TruncatedNormalCreateFailsWithBadParameters(double mean, double stdDev, double lower, double upper)
{
Assert.That(() => new TruncatedNormal(mean, stdDev, lower, upper), Throws.ArgumentException);
}
@ -78,7 +75,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(-6.0, 1.0, -5.0, 5.0)]
[TestCase(8.0, 1.0, -5.0, 5.0)]
[TestCase(15, 20.0, -20.0, 0.0)]
public void ValidateMode(double mean, double stdDev, double lower, double upper)
public void ValidateMode(double mean, double stdDev, double lower, double upper)
{
double mode;
if(mean < lower) {
@ -108,10 +105,10 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
[TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
[TestCase(Double.PositiveInfinity, 1.0)]
public void ValidateCumulativeNoBounds(double x, double p)
public void ValidateCumulativeNoBounds(double x, double p)
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0);
AssertHelpers.AlmostEqualRelative(p, truncatedNormal.CumulativeDistribution(x), 14);
AssertHelpers.AlmostEqualRelative(p, truncatedNormal.CumulativeDistribution(x), 14);
}
/// <summary>
@ -130,7 +127,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(6.0, .69146246127401310363770461060833773988360217555457859)]
[TestCase(10.0, 0.9937903346742238648330218954258077788721022530769078)]
[TestCase(Double.PositiveInfinity, 1.0)]
public void ValidateInverseCumulativeNoBounds(double x, double p)
public void ValidateInverseCumulativeNoBounds(double x, double p)
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0);
AssertHelpers.AlmostEqualRelative(x, truncatedNormal.InverseCumulativeDistribution(p), 14);
@ -147,7 +144,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void ValidateDensityNoBounds(double mean, double sdev)
public void ValidateDensityNoBounds(double mean, double sdev)
{
var n = new TruncatedNormal(mean, sdev);
for (var i = 0; i < 11; i++) {
@ -159,7 +156,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
}
/// <summary>
/// Validate density when only one bound is are specified.
/// Validate density when only one bound is are specified.
/// </summary>
/// <param name="mean">Mean value.</param>
/// <param name="sdev">Standard deviation value.</param>
@ -168,11 +165,11 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(0.0, 10.0, -10.0)]
[TestCase(10.0, 100.0, 15.0)]
[TestCase(-5.0, Double.PositiveInfinity, -5.0)]
public void ValidateDensitySemiFinite(double mean, double sdev, double lowerBound)
public void ValidateDensitySemiFinite(double mean, double sdev, double lowerBound)
{
var truncatedNormal = new TruncatedNormal(mean, sdev, lowerBound);
var normal = new Normal(mean, sdev);
for (var i = 0; i < 11; i++)
for (var i = 0; i < 11; i++)
{
var x = i - 5.0;
double density;
@ -180,7 +177,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
{
density = 0d;
}
else
else
{
var d = (mean - x) / sdev;
var pdf = Math.Exp(-0.5 * d * d) / (sdev * Constants.Sqrt2Pi);
@ -200,19 +197,19 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(0.0, 10.0, -10.0, 15.0)]
[TestCase(10.0, 100.0, 15.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity, -5.0, 0.0)]
public void ValidateDensityFinite(double mean, double sdev, double lowerBound, double upperBound)
public void ValidateDensityFinite(double mean, double sdev, double lowerBound, double upperBound)
{
var truncatedNormal = new TruncatedNormal(mean, sdev, lowerBound, upperBound);
var normal = new Normal(mean, sdev);
for (var i = 0; i < 11; i++)
for (var i = 0; i < 11; i++)
{
var x = i - 5.0;
double density;
if (x < lowerBound || upperBound < x)
if (x < lowerBound || upperBound < x)
{
density = 0d;
}
else
}
else
{
var d = (mean - x) / sdev;
var pdf = Math.Exp(-0.5 * d * d) / (sdev * Constants.Sqrt2Pi);
@ -224,7 +221,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
/// <summary>
/// Validate density log when no bounds are specified. Uses same
/// Validate density log when no bounds are specified. Uses same
/// test cases as the Normal distribution as should be equivalent in this case.
/// </summary>
/// <param name="mean">Mean value.</param>
@ -234,10 +231,10 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
[TestCase(0.0, 10.0)]
[TestCase(10.0, 100.0)]
[TestCase(-5.0, Double.PositiveInfinity)]
public void ValidateDensityLnNoBounds(double mean, double sdev)
public void ValidateDensityLnNoBounds(double mean, double sdev)
{
var n = new TruncatedNormal(mean, sdev);
for (var i = 0; i < 11; i++)
for (var i = 0; i < 11; i++)
{
var x = i - 5.0;
var d = (mean - x) / sdev;
@ -247,7 +244,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
}
[Test]
public void CanSample()
public void CanSample()
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
truncatedNormal.Sample();
@ -258,7 +255,7 @@ namespace MathNet.Numerics.UnitTests.DistributionTests.Continuous
/// Can sample sequence.
/// </summary>
[Test]
public void CanSampleSequence()
public void CanSampleSequence()
{
var truncatedNormal = new TruncatedNormal(5.0, 2.0, -10, 10.0);
var ied = truncatedNormal.Samples();

Loading…
Cancel
Save