From fd18c32f25a8f47e4f9daa7b785ce533e6737715 Mon Sep 17 00:00:00 2001 From: Christoph Ruegg Date: Fri, 3 May 2013 13:55:18 +0200 Subject: [PATCH] RootFinding: align algorithms, no need for objects --- .../RootFinding/Algorithms/Bisection.cs | 93 ++++++++----------- src/Numerics/RootFinding/Algorithms/Brent.cs | 54 +++++------ .../RootFinding/FloatingPointRoots.cs | 4 +- .../RootFindingTests/BisectionTest.cs | 5 +- 4 files changed, 69 insertions(+), 87 deletions(-) diff --git a/src/Numerics/RootFinding/Algorithms/Bisection.cs b/src/Numerics/RootFinding/Algorithms/Bisection.cs index ccc491e0..6941f1b2 100644 --- a/src/Numerics/RootFinding/Algorithms/Bisection.cs +++ b/src/Numerics/RootFinding/Algorithms/Bisection.cs @@ -32,76 +32,62 @@ using System; namespace MathNet.Numerics.RootFinding.Algorithms { - public class Bisection + public static class Bisection { - public Bisection(double objective_tolerance = 1e-5, double x_tolerance = 1e-5, double lower_expansion_factor = -1.0, double upper_expansion_factor = -1.0, int max_expansion_steps = 10) + /// Find a solution of the equation f(x)=0. + public static double FindRoot(Func f, double lowerBound, double upperBound, double fTolerance = 1e-5, double xTolerance = 1e-5, double lowerExpansionFactor = -1.0, double upperExpansionFactor = -1.0, int maxExpansionSteps = 10) { - ObjectiveTolerance = objective_tolerance; - XTolerance = x_tolerance; - LowerExpansionFactor = lower_expansion_factor; - UpperExpansionFactor = upper_expansion_factor; - MaxExpansionSteps = max_expansion_steps; - } - - public double ObjectiveTolerance { get; set; } - public double XTolerance { get; set; } - public double LowerExpansionFactor { get; set; } - public double UpperExpansionFactor { get; set; } - public int MaxExpansionSteps { get; set; } + double fmin = f(lowerBound); + double fmax = f(upperBound); - public double FindRoot(Func objective_function, double lower_bound, double upper_bound) - { - double lower_val = objective_function(lower_bound); - double upper_val = objective_function(upper_bound); + if (fmin == 0.0) + return lowerBound; + if (fmax == 0.0) + return upperBound; - if (lower_val == 0.0) - return lower_bound; - if (upper_val == 0.0) - return upper_bound; + ValidateEvaluation(fmin, lowerBound); + ValidateEvaluation(fmax, upperBound); - ValidateEvaluation(lower_val, lower_bound); - ValidateEvaluation(upper_val, upper_bound); - - if (Math.Sign(lower_val) == Math.Sign(upper_val) && LowerExpansionFactor <= 1.0 && UpperExpansionFactor <= 1.0) + if (Math.Sign(fmin) == Math.Sign(fmax) && lowerExpansionFactor <= 1.0 && upperExpansionFactor <= 1.0) throw new Exception("Bounds do not necessarily span a root, and StepExpansionFactor is not set to expand the interval in this case."); - int expansion_steps = 0; - while (Math.Sign(lower_val) == Math.Sign(upper_val) && expansion_steps < MaxExpansionSteps) + int expansionSteps = 0; + while (Math.Sign(fmin) == Math.Sign(fmax) && expansionSteps < maxExpansionSteps) { - double range = upper_bound - lower_bound; - if (UpperExpansionFactor <= 0.0 || (LowerExpansionFactor > 0.0 && Math.Abs(lower_val) < Math.Abs(upper_val))) + double range = upperBound - lowerBound; + if (upperExpansionFactor <= 0.0 || (lowerExpansionFactor > 0.0 && Math.Abs(fmin) < Math.Abs(fmax))) { - lower_bound = upper_bound - LowerExpansionFactor*range; - lower_val = objective_function(lower_bound); - ValidateEvaluation(lower_val, lower_bound); + lowerBound = upperBound - lowerExpansionFactor * range; + fmin = f(lowerBound); + ValidateEvaluation(fmin, lowerBound); } else { - upper_bound = lower_bound + UpperExpansionFactor*range; - upper_val = objective_function(upper_bound); - ValidateEvaluation(upper_val, upper_bound); + upperBound = lowerBound + upperExpansionFactor * range; + fmax = f(upperBound); + ValidateEvaluation(fmax, upperBound); } - expansion_steps += 1; + expansionSteps += 1; } - if (expansion_steps == MaxExpansionSteps) + if (expansionSteps == maxExpansionSteps) throw new NonConvergenceException(); - while (Math.Abs(upper_val - lower_val) > 0.5*ObjectiveTolerance || Math.Abs(upper_bound - lower_bound) > 0.5*XTolerance) + while (Math.Abs(fmax - fmin) > 0.5 * fTolerance || Math.Abs(upperBound - lowerBound) > 0.5 * xTolerance) { - double midpoint = 0.5*(upper_bound + lower_bound); - double midval = objective_function(midpoint); + double midpoint = 0.5*(upperBound + lowerBound); + double midval = f(midpoint); ValidateEvaluation(midval, midpoint); - if (Math.Sign(midval) == Math.Sign(lower_val)) + if (Math.Sign(midval) == Math.Sign(fmin)) { - lower_bound = midpoint; - lower_val = midval; + lowerBound = midpoint; + fmin = midval; } - else if (Math.Sign(midval) == Math.Sign(upper_val)) + else if (Math.Sign(midval) == Math.Sign(fmax)) { - upper_bound = midpoint; - upper_val = midval; + upperBound = midpoint; + fmax = midval; } else { @@ -109,18 +95,15 @@ namespace MathNet.Numerics.RootFinding.Algorithms } } - return 0.5*(lower_bound + upper_bound); + return 0.5*(lowerBound + upperBound); } - void ValidateEvaluation(double output, double input) + static void ValidateEvaluation(double output, double input) { - if (!IsFinite(output)) + if (Double.IsInfinity(output) || Double.IsInfinity(output)) + { throw new Exception(String.Format("Objective function returned non-finite result: f({0}) = {1}", input, output)); - } - - static bool IsFinite(double x) - { - return !(Double.IsInfinity(x) || Double.IsNaN(x)); + } } } } diff --git a/src/Numerics/RootFinding/Algorithms/Brent.cs b/src/Numerics/RootFinding/Algorithms/Brent.cs index dcfb8e9f..5ef3aa29 100644 --- a/src/Numerics/RootFinding/Algorithms/Brent.cs +++ b/src/Numerics/RootFinding/Algorithms/Brent.cs @@ -36,8 +36,8 @@ namespace MathNet.Numerics.RootFinding.Algorithms { /// Find a solution of the equation f(x)=0. /// The function to find roots from. - /// The low value of the range where the root is supposed to be. - /// The high value of the range where the root is supposed to be. + /// The low value of the range where the root is supposed to be. + /// The high value of the range where the root is supposed to be. /// Desired accuracy. The root will be refined until the accuracy or the maximum number of iterations is reached. /// Maximum number of iterations. Usually 100. /// Returns the root with the specified accuracy. @@ -46,58 +46,58 @@ namespace MathNet.Numerics.RootFinding.Algorithms /// Implementation inspired by Press, Teukolsky, Vetterling, and Flannery, "Numerical Recipes in C", 2nd edition, Cambridge University Press /// /// - public static double FindRoot(Func f, double xmin, double xmax, double accuracy = 1e-8, int maxIterations = 100) + public static double FindRoot(Func f, double lowerBound, double upperBound, double accuracy = 1e-8, int maxIterations = 100) { - double fxmin = f(xmin); - double fxmax = f(xmax); - double root = xmax; - double froot = fxmax; + double fmin = f(lowerBound); + double fmax = f(upperBound); + double root = upperBound; + double froot = fmax; double d = 0.0, e = 0.0; for (int i = 0; i <= maxIterations; i++) { // adjust bounds - if (Math.Sign(froot) == Math.Sign(fxmax)) + if (Math.Sign(froot) == Math.Sign(fmax)) { - xmax = xmin; - fxmax = fxmin; - e = d = root - xmin; + upperBound = lowerBound; + fmax = fmin; + e = d = root - lowerBound; } - if (Math.Abs(fxmax) < Math.Abs(froot)) + if (Math.Abs(fmax) < Math.Abs(froot)) { - xmin = root; - root = xmax; - xmax = xmin; - fxmin = froot; - froot = fxmax; - fxmax = fxmin; + lowerBound = root; + root = upperBound; + upperBound = lowerBound; + fmin = froot; + froot = fmax; + fmax = fmin; } // convergence check double xAcc1 = 2.0*Precision.DoubleMachinePrecision*Math.Abs(root) + 0.5*accuracy; - double xMid = (xmax - root)/2.0; + double xMid = (upperBound - root)/2.0; if (Math.Abs(xMid) <= xAcc1 || froot.AlmostEqualWithAbsoluteError(0, froot, accuracy)) { return root; } - if (Math.Abs(e) >= xAcc1 && Math.Abs(fxmin) > Math.Abs(froot)) + if (Math.Abs(e) >= xAcc1 && Math.Abs(fmin) > Math.Abs(froot)) { // Attempt inverse quadratic interpolation - double s = froot/fxmin; + double s = froot/fmin; double p; double q; - if (xmin.AlmostEqual(xmax)) + if (lowerBound.AlmostEqual(upperBound)) { p = 2.0*xMid*s; q = 1.0 - s; } else { - q = fxmin/fxmax; - double r = froot/fxmax; - p = s*(2.0*xMid*q*(q - r) - (root - xmin)*(r - 1.0)); + q = fmin/fmax; + double r = froot/fmax; + p = s*(2.0*xMid*q*(q - r) - (root - lowerBound)*(r - 1.0)); q = (q - 1.0)*(r - 1.0)*(s - 1.0); } @@ -126,8 +126,8 @@ namespace MathNet.Numerics.RootFinding.Algorithms d = xMid; e = d; } - xmin = root; - fxmin = froot; + lowerBound = root; + fmin = froot; if (Math.Abs(d) > xAcc1) { root += d; diff --git a/src/Numerics/RootFinding/FloatingPointRoots.cs b/src/Numerics/RootFinding/FloatingPointRoots.cs index 0c24ce3c..7e67492b 100644 --- a/src/Numerics/RootFinding/FloatingPointRoots.cs +++ b/src/Numerics/RootFinding/FloatingPointRoots.cs @@ -35,9 +35,9 @@ namespace MathNet.Numerics.RootFinding { public static class FloatingPointRoots { - public static double Find(Func f, double xmin, double xmax) + public static double OfFunction(Func f, double lowerBound, double upperBound) { - return Brent.FindRoot(f, xmin, xmax, 1e-8, 100); + return Brent.FindRoot(f, lowerBound, upperBound, 1e-8, 100); } } } diff --git a/src/UnitTests/RootFindingTests/BisectionTest.cs b/src/UnitTests/RootFindingTests/BisectionTest.cs index 7f2cf8db..c99452e9 100644 --- a/src/UnitTests/RootFindingTests/BisectionTest.cs +++ b/src/UnitTests/RootFindingTests/BisectionTest.cs @@ -40,14 +40,13 @@ namespace MathNet.Numerics.UnitTests.RootFindingTests [Test] public void MultipleRoots() { - var algorithm = new Bisection(0.001, 0.001); var f1 = new Func(x => (x - 3)*(x - 4)); - double r1 = algorithm.FindRoot(f1, 2.1, 3.9); + double r1 = Bisection.FindRoot(f1, 2.1, 3.9, 0.001, 0.001); Assert.That(Math.Abs(f1(r1)), Is.LessThan(0.001)); Assert.That(Math.Abs(r1 - 3.0), Is.LessThan(0.001)); var f2 = new Func(x => (x - 3)*(x - 4)); - double r2 = algorithm.FindRoot(f1, 2.1, 3.4); + double r2 = Bisection.FindRoot(f1, 2.1, 3.4, 0.001, 0.001); Assert.That(Math.Abs(f2(r2)), Is.LessThan(0.001)); Assert.That(Math.Abs(r2 - 3.0), Is.LessThan(0.001)); }