From f36f7581d9f8488c88f5134c6a422b9f6b6bd330 Mon Sep 17 00:00:00 2001 From: Aixile Date: Thu, 18 Aug 2016 23:01:17 +0900 Subject: [PATCH] Fix recursion issue of CacheObliviousMatrixMultiply --- .../ManagedLinearAlgebraProvider.Complex.cs | 2 +- .../ManagedLinearAlgebraProvider.Complex32.cs | 2 +- .../ManagedLinearAlgebraProvider.Double.cs | 64 +++++++++---------- .../ManagedLinearAlgebraProvider.Single.cs | 2 +- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs index 5f78e840..30d887d7 100644 --- a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs +++ b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs @@ -666,7 +666,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra /// Indicates if this is the first recursion. static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, Complex alpha, Complex[] matrixA, int shiftArow, int shiftAcol, Complex[] matrixB, int shiftBrow, int shiftBcol, Complex[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first) { - if (m + n <= Control.ParallelizeOrder) + if (m + n <= Control.ParallelizeOrder || m == 1 || n == 1 || k == 1) { if ((int) transposeA > 111 && (int) transposeB > 111) { diff --git a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs index f2152d56..a0351631 100644 --- a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs +++ b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs @@ -663,7 +663,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra /// Indicates if this is the first recursion. static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, Complex32 alpha, Complex32[] matrixA, int shiftArow, int shiftAcol, Complex32[] matrixB, int shiftBrow, int shiftBcol, Complex32[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first) { - if (m + n <= Control.ParallelizeOrder) + if (m + n <= Control.ParallelizeOrder || m == 1 || n == 1 || k == 1) { if ((int) transposeA > 111 && (int) transposeB > 111) { diff --git a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs index 7b0bb469..05a0da68 100644 --- a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs +++ b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs @@ -175,7 +175,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra for (var index = 0; index < y.Length; index++) { - sum += y[index]*x[index]; + sum += y[index] * x[index]; } return sum; @@ -366,7 +366,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra var s = 0.0; for (var i = 0; i < rows; i++) { - s += Math.Abs(matrix[(j*rows) + i]); + s += Math.Abs(matrix[(j * rows) + i]); } norm1 = Math.Max(norm1, s); } @@ -383,25 +383,25 @@ namespace MathNet.Numerics.Providers.LinearAlgebra return normMax; case Norm.InfinityNorm: var r = new double[rows]; - for (var j = 0; j < columns; j++) - { - for (var i = 0; i < rows; i++) - { + for (var j = 0; j < columns; j++) + { + for (var i = 0; i < rows; i++) + { r[i] += Math.Abs(matrix[(j * rows) + i]); - } - } + } + } // TODO: reuse var max = r[0]; - for (int i = 0; i < r.Length; i++) - { - if (r[i] > max) - { - max = r[i]; - } - } - return max; + for (int i = 0; i < r.Length; i++) + { + if (r[i] > max) + { + max = r[i]; + } + } + return max; case Norm.FrobeniusNorm: - var aat = new double[rows*rows]; + var aat = new double[rows * rows]; MatrixMultiplyWithUpdate(Transpose.DontTranspose, Transpose.Transpose, 1.0, matrix, rows, columns, matrix, rows, columns, 0.0, aat); var normF = 0d; for (var i = 0; i < rows; i++) @@ -444,12 +444,12 @@ namespace MathNet.Numerics.Providers.LinearAlgebra throw new ArgumentNullException("result"); } - if (rowsX*columnsX != x.Length) + if (rowsX * columnsX != x.Length) { throw new ArgumentException("x.Length != xRows * xColumns"); } - if (rowsY*columnsY != y.Length) + if (rowsY * columnsY != y.Length) { throw new ArgumentException("y.Length != yRows * yColumns"); } @@ -459,7 +459,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra throw new ArgumentException("xColumns != yRows"); } - if (rowsX*columnsY != result.Length) + if (rowsX * columnsY != result.Length) { throw new ArgumentException("xRows * yColumns != result.Length"); } @@ -470,7 +470,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra double[] xdata; if (ReferenceEquals(x, result)) { - xdata = (double[]) x.Clone(); + xdata = (double[])x.Clone(); } else { @@ -480,7 +480,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra double[] ydata; if (ReferenceEquals(y, result)) { - ydata = (double[]) y.Clone(); + ydata = (double[])y.Clone(); } else { @@ -523,14 +523,14 @@ namespace MathNet.Numerics.Providers.LinearAlgebra throw new ArgumentNullException("b"); } - if ((int) transposeA > 111 && (int) transposeB > 111) + if ((int)transposeA > 111 && (int)transposeB > 111) { if (rowsA != columnsB) { throw new ArgumentOutOfRangeException(); } - if (columnsA*rowsB != c.Length) + if (columnsA * rowsB != c.Length) { throw new ArgumentOutOfRangeException(); } @@ -539,14 +539,14 @@ namespace MathNet.Numerics.Providers.LinearAlgebra n = rowsB; k = rowsA; } - else if ((int) transposeA > 111) + else if ((int)transposeA > 111) { if (rowsA != rowsB) { throw new ArgumentOutOfRangeException(); } - if (columnsA*columnsB != c.Length) + if (columnsA * columnsB != c.Length) { throw new ArgumentOutOfRangeException(); } @@ -555,14 +555,14 @@ namespace MathNet.Numerics.Providers.LinearAlgebra n = columnsB; k = rowsA; } - else if ((int) transposeB > 111) + else if ((int)transposeB > 111) { if (columnsA != columnsB) { throw new ArgumentOutOfRangeException(); } - if (rowsA*rowsB != c.Length) + if (rowsA * rowsB != c.Length) { throw new ArgumentOutOfRangeException(); } @@ -578,7 +578,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra throw new ArgumentOutOfRangeException(); } - if (rowsA*columnsB != c.Length) + if (rowsA * columnsB != c.Length) { throw new ArgumentOutOfRangeException(); } @@ -600,7 +600,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra double[] adata; if (ReferenceEquals(a, c)) { - adata = (double[]) a.Clone(); + adata = (double[])a.Clone(); } else { @@ -610,7 +610,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra double[] bdata; if (ReferenceEquals(b, c)) { - bdata = (double[]) b.Clone(); + bdata = (double[])b.Clone(); } else { @@ -658,7 +658,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra /// Indicates if this is the first recursion. static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, double alpha, double[] matrixA, int shiftArow, int shiftAcol, double[] matrixB, int shiftBrow, int shiftBcol, double[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first) { - if (m + n <= Control.ParallelizeOrder) + if (m + n <= Control.ParallelizeOrder || m == 1 || n == 1 || k == 1) { if ((int) transposeA > 111 && (int) transposeB > 111) { diff --git a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs index b9b73e4e..5435eeda 100644 --- a/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs +++ b/src/Numerics/Providers/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs @@ -658,7 +658,7 @@ namespace MathNet.Numerics.Providers.LinearAlgebra /// Indicates if this is the first recursion. static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, float alpha, float[] matrixA, int shiftArow, int shiftAcol, float[] matrixB, int shiftBrow, int shiftBcol, float[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first) { - if (m + n <= Control.ParallelizeOrder) + if (m + n <= Control.ParallelizeOrder || m == 1 || n == 1 || k == 1) { if ((int) transposeA > 111 && (int) transposeB > 111) {