Browse Source

RootFinding: align algorithms, no need for objects

pull/121/head
Christoph Ruegg 13 years ago
parent
commit
fd18c32f25
  1. 93
      src/Numerics/RootFinding/Algorithms/Bisection.cs
  2. 54
      src/Numerics/RootFinding/Algorithms/Brent.cs
  3. 4
      src/Numerics/RootFinding/FloatingPointRoots.cs
  4. 5
      src/UnitTests/RootFindingTests/BisectionTest.cs

93
src/Numerics/RootFinding/Algorithms/Bisection.cs

@ -32,76 +32,62 @@ using System;
namespace MathNet.Numerics.RootFinding.Algorithms
{
public class Bisection
public static class Bisection
{
public Bisection(double objective_tolerance = 1e-5, double x_tolerance = 1e-5, double lower_expansion_factor = -1.0, double upper_expansion_factor = -1.0, int max_expansion_steps = 10)
/// <summary>Find a solution of the equation f(x)=0.</summary>
public static double FindRoot(Func<double, double> f, double lowerBound, double upperBound, double fTolerance = 1e-5, double xTolerance = 1e-5, double lowerExpansionFactor = -1.0, double upperExpansionFactor = -1.0, int maxExpansionSteps = 10)
{
ObjectiveTolerance = objective_tolerance;
XTolerance = x_tolerance;
LowerExpansionFactor = lower_expansion_factor;
UpperExpansionFactor = upper_expansion_factor;
MaxExpansionSteps = max_expansion_steps;
}
public double ObjectiveTolerance { get; set; }
public double XTolerance { get; set; }
public double LowerExpansionFactor { get; set; }
public double UpperExpansionFactor { get; set; }
public int MaxExpansionSteps { get; set; }
double fmin = f(lowerBound);
double fmax = f(upperBound);
public double FindRoot(Func<double, double> objective_function, double lower_bound, double upper_bound)
{
double lower_val = objective_function(lower_bound);
double upper_val = objective_function(upper_bound);
if (fmin == 0.0)
return lowerBound;
if (fmax == 0.0)
return upperBound;
if (lower_val == 0.0)
return lower_bound;
if (upper_val == 0.0)
return upper_bound;
ValidateEvaluation(fmin, lowerBound);
ValidateEvaluation(fmax, upperBound);
ValidateEvaluation(lower_val, lower_bound);
ValidateEvaluation(upper_val, upper_bound);
if (Math.Sign(lower_val) == Math.Sign(upper_val) && LowerExpansionFactor <= 1.0 && UpperExpansionFactor <= 1.0)
if (Math.Sign(fmin) == Math.Sign(fmax) && lowerExpansionFactor <= 1.0 && upperExpansionFactor <= 1.0)
throw new Exception("Bounds do not necessarily span a root, and StepExpansionFactor is not set to expand the interval in this case.");
int expansion_steps = 0;
while (Math.Sign(lower_val) == Math.Sign(upper_val) && expansion_steps < MaxExpansionSteps)
int expansionSteps = 0;
while (Math.Sign(fmin) == Math.Sign(fmax) && expansionSteps < maxExpansionSteps)
{
double range = upper_bound - lower_bound;
if (UpperExpansionFactor <= 0.0 || (LowerExpansionFactor > 0.0 && Math.Abs(lower_val) < Math.Abs(upper_val)))
double range = upperBound - lowerBound;
if (upperExpansionFactor <= 0.0 || (lowerExpansionFactor > 0.0 && Math.Abs(fmin) < Math.Abs(fmax)))
{
lower_bound = upper_bound - LowerExpansionFactor*range;
lower_val = objective_function(lower_bound);
ValidateEvaluation(lower_val, lower_bound);
lowerBound = upperBound - lowerExpansionFactor * range;
fmin = f(lowerBound);
ValidateEvaluation(fmin, lowerBound);
}
else
{
upper_bound = lower_bound + UpperExpansionFactor*range;
upper_val = objective_function(upper_bound);
ValidateEvaluation(upper_val, upper_bound);
upperBound = lowerBound + upperExpansionFactor * range;
fmax = f(upperBound);
ValidateEvaluation(fmax, upperBound);
}
expansion_steps += 1;
expansionSteps += 1;
}
if (expansion_steps == MaxExpansionSteps)
if (expansionSteps == maxExpansionSteps)
throw new NonConvergenceException();
while (Math.Abs(upper_val - lower_val) > 0.5*ObjectiveTolerance || Math.Abs(upper_bound - lower_bound) > 0.5*XTolerance)
while (Math.Abs(fmax - fmin) > 0.5 * fTolerance || Math.Abs(upperBound - lowerBound) > 0.5 * xTolerance)
{
double midpoint = 0.5*(upper_bound + lower_bound);
double midval = objective_function(midpoint);
double midpoint = 0.5*(upperBound + lowerBound);
double midval = f(midpoint);
ValidateEvaluation(midval, midpoint);
if (Math.Sign(midval) == Math.Sign(lower_val))
if (Math.Sign(midval) == Math.Sign(fmin))
{
lower_bound = midpoint;
lower_val = midval;
lowerBound = midpoint;
fmin = midval;
}
else if (Math.Sign(midval) == Math.Sign(upper_val))
else if (Math.Sign(midval) == Math.Sign(fmax))
{
upper_bound = midpoint;
upper_val = midval;
upperBound = midpoint;
fmax = midval;
}
else
{
@ -109,18 +95,15 @@ namespace MathNet.Numerics.RootFinding.Algorithms
}
}
return 0.5*(lower_bound + upper_bound);
return 0.5*(lowerBound + upperBound);
}
void ValidateEvaluation(double output, double input)
static void ValidateEvaluation(double output, double input)
{
if (!IsFinite(output))
if (Double.IsInfinity(output) || Double.IsInfinity(output))
{
throw new Exception(String.Format("Objective function returned non-finite result: f({0}) = {1}", input, output));
}
static bool IsFinite(double x)
{
return !(Double.IsInfinity(x) || Double.IsNaN(x));
}
}
}
}

54
src/Numerics/RootFinding/Algorithms/Brent.cs

@ -36,8 +36,8 @@ namespace MathNet.Numerics.RootFinding.Algorithms
{
/// <summary>Find a solution of the equation f(x)=0.</summary>
/// <param name="f">The function to find roots from.</param>
/// <param name="xmin">The low value of the range where the root is supposed to be.</param>
/// <param name="xmax">The high value of the range where the root is supposed to be.</param>
/// <param name="lowerBound">The low value of the range where the root is supposed to be.</param>
/// <param name="upperBound">The high value of the range where the root is supposed to be.</param>
/// <param name="accuracy">Desired accuracy. The root will be refined until the accuracy or the maximum number of iterations is reached.</param>
/// <param name="maxIterations">Maximum number of iterations. Usually 100.</param>
/// <returns>Returns the root with the specified accuracy.</returns>
@ -46,58 +46,58 @@ namespace MathNet.Numerics.RootFinding.Algorithms
/// Implementation inspired by Press, Teukolsky, Vetterling, and Flannery, "Numerical Recipes in C", 2nd edition, Cambridge University Press
/// </remarks>
/// <exception cref="NonConvergenceException"></exception>
public static double FindRoot(Func<double, double> f, double xmin, double xmax, double accuracy = 1e-8, int maxIterations = 100)
public static double FindRoot(Func<double, double> f, double lowerBound, double upperBound, double accuracy = 1e-8, int maxIterations = 100)
{
double fxmin = f(xmin);
double fxmax = f(xmax);
double root = xmax;
double froot = fxmax;
double fmin = f(lowerBound);
double fmax = f(upperBound);
double root = upperBound;
double froot = fmax;
double d = 0.0, e = 0.0;
for (int i = 0; i <= maxIterations; i++)
{
// adjust bounds
if (Math.Sign(froot) == Math.Sign(fxmax))
if (Math.Sign(froot) == Math.Sign(fmax))
{
xmax = xmin;
fxmax = fxmin;
e = d = root - xmin;
upperBound = lowerBound;
fmax = fmin;
e = d = root - lowerBound;
}
if (Math.Abs(fxmax) < Math.Abs(froot))
if (Math.Abs(fmax) < Math.Abs(froot))
{
xmin = root;
root = xmax;
xmax = xmin;
fxmin = froot;
froot = fxmax;
fxmax = fxmin;
lowerBound = root;
root = upperBound;
upperBound = lowerBound;
fmin = froot;
froot = fmax;
fmax = fmin;
}
// convergence check
double xAcc1 = 2.0*Precision.DoubleMachinePrecision*Math.Abs(root) + 0.5*accuracy;
double xMid = (xmax - root)/2.0;
double xMid = (upperBound - root)/2.0;
if (Math.Abs(xMid) <= xAcc1 || froot.AlmostEqualWithAbsoluteError(0, froot, accuracy))
{
return root;
}
if (Math.Abs(e) >= xAcc1 && Math.Abs(fxmin) > Math.Abs(froot))
if (Math.Abs(e) >= xAcc1 && Math.Abs(fmin) > Math.Abs(froot))
{
// Attempt inverse quadratic interpolation
double s = froot/fxmin;
double s = froot/fmin;
double p;
double q;
if (xmin.AlmostEqual(xmax))
if (lowerBound.AlmostEqual(upperBound))
{
p = 2.0*xMid*s;
q = 1.0 - s;
}
else
{
q = fxmin/fxmax;
double r = froot/fxmax;
p = s*(2.0*xMid*q*(q - r) - (root - xmin)*(r - 1.0));
q = fmin/fmax;
double r = froot/fmax;
p = s*(2.0*xMid*q*(q - r) - (root - lowerBound)*(r - 1.0));
q = (q - 1.0)*(r - 1.0)*(s - 1.0);
}
@ -126,8 +126,8 @@ namespace MathNet.Numerics.RootFinding.Algorithms
d = xMid;
e = d;
}
xmin = root;
fxmin = froot;
lowerBound = root;
fmin = froot;
if (Math.Abs(d) > xAcc1)
{
root += d;

4
src/Numerics/RootFinding/FloatingPointRoots.cs

@ -35,9 +35,9 @@ namespace MathNet.Numerics.RootFinding
{
public static class FloatingPointRoots
{
public static double Find(Func<double, double> f, double xmin, double xmax)
public static double OfFunction(Func<double, double> f, double lowerBound, double upperBound)
{
return Brent.FindRoot(f, xmin, xmax, 1e-8, 100);
return Brent.FindRoot(f, lowerBound, upperBound, 1e-8, 100);
}
}
}

5
src/UnitTests/RootFindingTests/BisectionTest.cs

@ -40,14 +40,13 @@ namespace MathNet.Numerics.UnitTests.RootFindingTests
[Test]
public void MultipleRoots()
{
var algorithm = new Bisection(0.001, 0.001);
var f1 = new Func<double, double>(x => (x - 3)*(x - 4));
double r1 = algorithm.FindRoot(f1, 2.1, 3.9);
double r1 = Bisection.FindRoot(f1, 2.1, 3.9, 0.001, 0.001);
Assert.That(Math.Abs(f1(r1)), Is.LessThan(0.001));
Assert.That(Math.Abs(r1 - 3.0), Is.LessThan(0.001));
var f2 = new Func<double, double>(x => (x - 3)*(x - 4));
double r2 = algorithm.FindRoot(f1, 2.1, 3.4);
double r2 = Bisection.FindRoot(f1, 2.1, 3.4, 0.001, 0.001);
Assert.That(Math.Abs(f2(r2)), Is.LessThan(0.001));
Assert.That(Math.Abs(r2 - 3.0), Is.LessThan(0.001));
}

Loading…
Cancel
Save