Browse Source

Change MatrixMultiplyWithUpdate to cache-oblivious multiplication.

pull/36/head
Abratiychuk 15 years ago
committed by Marcus Cuda
parent
commit
37526cb8d9
  1. 429
      src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs
  2. 429
      src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs
  3. 380
      src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs
  4. 380
      src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs
  5. 26
      src/Numerics/Control.cs

429
src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Complex.cs

@ -461,9 +461,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
/// <param name="c">The c matrix.</param>
public virtual void MatrixMultiplyWithUpdate(Transpose transposeA, Transpose transposeB, Complex alpha, Complex[] a, int rowsA, int columnsA, Complex[] b, int rowsB, int columnsB, Complex beta, Complex[] c)
{
// Choose nonsensical values for the number of rows in c; fill them in depending
// on the operations on a and b.
int rowsC;
int m; // The number of rows of matrix op(A) and of the matrix C.
int n; // The number of columns of matrix op(B) and of the matrix C.
int k; // The number of columns of matrix op(A) and the rows of the matrix op(B).
// First check some basic requirement on the parameters of the matrix multiplication.
if (a == null)
@ -488,7 +488,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = rowsB;
k = rowsA;
}
else if ((int)transposeA > 111)
{
@ -502,7 +504,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = columnsB;
k = rowsA;
}
else if ((int)transposeB > 111)
{
@ -516,7 +520,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = rowsB;
k = columnsA;
}
else
{
@ -530,7 +536,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = columnsB;
k = columnsA;
}
if (alpha.IsZero() && beta.IsZero())
@ -562,268 +570,271 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
bdata = b;
}
if (alpha.IsOne())
if (beta.IsZero())
{
if (beta.IsZero())
Array.Clear(c, 0, c.Length);
}
else if (!beta.IsOne())
{
Control.LinearAlgebraProvider.ScaleArray(beta, c, c);
}
if (alpha.IsZero())
{
return;
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, adata, 0, 0, bdata, 0, 0, c, 0, 0, m, n, k, m, n, k, true);
}
/// <summary>
/// Cache-Oblivious Matrix Multiplication
/// </summary>
/// <param name="transposeA">if set to <c>true</c> transpose matrix A.</param>
/// <param name="transposeB">if set to <c>true</c> transpose matrix B.</param>
/// <param name="alpha">The value to scale the matrix A with.</param>
/// <param name="matrixA">The matrix A.</param>
/// <param name="shiftArow">Row-shift of the left matrix</param>
/// <param name="shiftAcol">Column-shift of the left matrix</param>
/// <param name="matrixB">The matrix B.</param>
/// <param name="shiftBrow">Row-shift of the right matrix</param>
/// <param name="shiftBcol">Column-shift of the right matrix</param>
/// <param name="result">The matrix C.</param>
/// <param name="shiftCrow">Row-shift of the result matrix</param>
/// <param name="shiftCcol">Column-shift of the result matrix</param>
/// <param name="m">The number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="n">The number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="k">The number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="constM">The constant number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="constN">The constant number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="constK">The constant number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="first">Indicates if this is the first recursion.</param>
private static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, Complex alpha, Complex[] matrixA, int shiftArow, int shiftAcol, Complex[] matrixB, int shiftBrow, int shiftBcol, Complex[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first)
{
if (m + n + k <= Control.ParallelizeOrder)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
if ((int)transposeA > 112 && (int)transposeB > 112)
{
CommonParallel.For(
0,
columnsA,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else if ((int)transposeA > 111)
else if ((int)transposeA > 112)
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else if ((int)transposeB > 111)
else if ((int)transposeB > 112)
{
CommonParallel.For(
0,
rowsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
else if ((int)transposeA > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (c[jIndex + i] * beta) + s;
}
});
}
else if ((int)transposeA > 111)
if ((int)transposeA > 112)
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s + (c[jIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
}
else
{
if ((int)transposeA > 111 && (int)transposeB > 111)
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
if ((int)transposeB > 112)
{
for (var m1 = 0; m1 < m; m1++)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != columnsB; l++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
c[jIndex + i] = (c[jIndex + i] * beta) + (alpha * s);
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
});
}
else if ((int)transposeA > 111)
{
CommonParallel.For(
0,
columnsB,
j =>
}
}
else
{
for (var m1 = 0; m1 < m; m1++)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var iIndex = i * rowsA;
Complex s = 0;
for (var l = 0; l != rowsA; l++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
});
}
}
}
else if ((int)transposeB > 111)
else
{
CommonParallel.For(
0,
rowsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (alpha * s) + (c[jIndex + i] * beta);
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
{
// divide and conquer
int m2 = m / 2, n2 = n / 2, k2 = k / 2;
if (first)
{
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false));
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false));
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
Complex s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false);
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false);
}
}
}

429
src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Complex32.cs

@ -456,9 +456,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
/// <param name="c">The c matrix.</param>
public virtual void MatrixMultiplyWithUpdate(Transpose transposeA, Transpose transposeB, Complex32 alpha, Complex32[] a, int rowsA, int columnsA, Complex32[] b, int rowsB, int columnsB, Complex32 beta, Complex32[] c)
{
// Choose nonsensical values for the number of rows in c; fill them in depending
// on the operations on a and b.
int rowsC;
int m; // The number of rows of matrix op(A) and of the matrix C.
int n; // The number of columns of matrix op(B) and of the matrix C.
int k; // The number of columns of matrix op(A) and the rows of the matrix op(B).
// First check some basic requirement on the parameters of the matrix multiplication.
if (a == null)
@ -483,7 +483,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = rowsB;
k = rowsA;
}
else if ((int)transposeA > 111)
{
@ -497,7 +499,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = columnsB;
k = rowsA;
}
else if ((int)transposeB > 111)
{
@ -511,7 +515,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = rowsB;
k = columnsA;
}
else
{
@ -525,7 +531,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = columnsB;
k = columnsA;
}
if (alpha.IsZero() && beta.IsZero())
@ -557,268 +565,271 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
bdata = b;
}
if (alpha.IsOne())
if (beta.IsZero())
{
if (beta.IsZero())
Array.Clear(c, 0, c.Length);
}
else if (!beta.IsOne())
{
Control.LinearAlgebraProvider.ScaleArray(beta, c, c);
}
if (alpha.IsZero())
{
return;
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, adata, 0, 0, bdata, 0, 0, c, 0, 0, m, n, k, m, n, k, true);
}
/// <summary>
/// Cache-Oblivious Matrix Multiplication
/// </summary>
/// <param name="transposeA">if set to <c>true</c> transpose matrix A.</param>
/// <param name="transposeB">if set to <c>true</c> transpose matrix B.</param>
/// <param name="alpha">The value to scale the matrix A with.</param>
/// <param name="matrixA">The matrix A.</param>
/// <param name="shiftArow">Row-shift of the left matrix</param>
/// <param name="shiftAcol">Column-shift of the left matrix</param>
/// <param name="matrixB">The matrix B.</param>
/// <param name="shiftBrow">Row-shift of the right matrix</param>
/// <param name="shiftBcol">Column-shift of the right matrix</param>
/// <param name="result">The matrix C.</param>
/// <param name="shiftCrow">Row-shift of the result matrix</param>
/// <param name="shiftCcol">Column-shift of the result matrix</param>
/// <param name="m">The number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="n">The number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="k">The number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="constM">The constant number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="constN">The constant number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="constK">The constant number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="first">Indicates if this is the first recursion.</param>
private static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, Complex32 alpha, Complex32[] matrixA, int shiftArow, int shiftAcol, Complex32[] matrixB, int shiftBrow, int shiftBcol, Complex32[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first)
{
if (m + n + k <= Control.ParallelizeOrder)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
if ((int)transposeA > 112 && (int)transposeB > 112)
{
CommonParallel.For(
0,
columnsA,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else if ((int)transposeA > 111)
else if ((int)transposeA > 112)
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else if ((int)transposeB > 111)
else if ((int)transposeB > 112)
{
CommonParallel.For(
0,
rowsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
else if ((int)transposeA > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
if ((int)transposeA > 112)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (c[jIndex + i] * beta) + s;
}
});
}
else if ((int)transposeA > 111)
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol].Conjugate() *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s + (c[jIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
}
else
{
if ((int)transposeA > 111 && (int)transposeB > 111)
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
if ((int)transposeB > 112)
{
for (var m1 = 0; m1 < m; m1++)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != columnsB; l++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos].Conjugate();
}
c[jIndex + i] = (c[jIndex + i] * beta) + (alpha * s);
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
});
}
else if ((int)transposeA > 111)
{
CommonParallel.For(
0,
columnsB,
j =>
}
}
else
{
for (var m1 = 0; m1 < m; m1++)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var iIndex = i * rowsA;
Complex32 s = 0;
for (var l = 0; l != rowsA; l++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
});
}
}
}
else if ((int)transposeB > 111)
else
{
CommonParallel.For(
0,
rowsB,
j =>
for (var m1 = 0; m1 < m; m1++)
{
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
var matBcolPos = n1 + shiftBcol;
var sum = Complex32.Zero;
for (var k1 = 0; k1 < k; ++k1)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (alpha * s) + (c[jIndex + i] * beta);
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
{
// divide and conquer
int m2 = m / 2, n2 = n / 2, k2 = k / 2;
if (first)
{
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false));
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false));
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
Complex32 s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false);
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false);
}
}
}

380
src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Double.cs

@ -455,9 +455,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
/// <param name="c">The c matrix.</param>
public virtual void MatrixMultiplyWithUpdate(Transpose transposeA, Transpose transposeB, double alpha, double[] a, int rowsA, int columnsA, double[] b, int rowsB, int columnsB, double beta, double[] c)
{
// Choose nonsensical values for the number of rows in c; fill them in depending
// on the operations on a and b.
int rowsC;
int m; // The number of rows of matrix op(A) and of the matrix C.
int n; // The number of columns of matrix op(B) and of the matrix C.
int k; // The number of columns of matrix op(A) and the rows of the matrix op(B).
// First check some basic requirement on the parameters of the matrix multiplication.
if (a == null)
@ -482,7 +482,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = rowsB;
k = rowsA;
}
else if ((int)transposeA > 111)
{
@ -496,7 +498,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = columnsB;
k = rowsA;
}
else if ((int)transposeB > 111)
{
@ -510,7 +514,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = rowsB;
k = columnsA;
}
else
{
@ -524,7 +530,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = columnsB;
k = columnsA;
}
if (alpha == 0.0 && beta == 0.0)
@ -556,268 +564,162 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
bdata = b;
}
if (alpha == 1.0)
if (beta == 0.0)
{
if (beta == 0.0)
Array.Clear(c, 0, c.Length);
}
else if (beta != 1.0)
{
Control.LinearAlgebraProvider.ScaleArray(beta, c, c);
}
if (alpha == 0.0)
{
return;
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, adata, 0, 0, bdata, 0, 0, c, 0, 0, m, n, k, m, n, k, true);
}
/// <summary>
/// Cache-Oblivious Matrix Multiplication
/// </summary>
/// <param name="transposeA">The transpose operation applied to matrix A, i.e. which form op(A) takes.</param>
/// <param name="transposeB">The transpose operation applied to matrix B, i.e. which form op(B) takes.</param>
/// <param name="alpha">The value to scale the matrix A with.</param>
/// <param name="matrixA">The matrix A.</param>
/// <param name="shiftArow">Row-shift of the left matrix</param>
/// <param name="shiftAcol">Column-shift of the left matrix</param>
/// <param name="matrixB">The matrix B.</param>
/// <param name="shiftBrow">Row-shift of the right matrix</param>
/// <param name="shiftBcol">Column-shift of the right matrix</param>
/// <param name="result">The matrix C.</param>
/// <param name="shiftCrow">Row-shift of the result matrix</param>
/// <param name="shiftCcol">Column-shift of the result matrix</param>
/// <param name="m">The number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="n">The number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="k">The number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="constM">The constant number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="constN">The constant number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="constK">The constant number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="first">Indicates if this is the first recursion.</param>
private static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, double alpha, double[] matrixA, int shiftArow, int shiftAcol, double[] matrixB, int shiftBrow, int shiftBcol, double[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first)
{
if (m + n + k <= Control.ParallelizeOrder)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
}
});
}
else if ((int)transposeA > 111)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
}
});
}
else if ((int)transposeB > 111)
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
rowsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
double sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jIndex + i] = s;
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
else
}
else if ((int)transposeA > 111)
{
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
double sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
c[jcIndex + i] = s;
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
else if ((int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (c[jIndex + i] * beta) + s;
}
});
}
else if ((int)transposeA > 111)
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
double sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jIndex + i] = s + (c[jIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
else
}
else
{
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
double sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
// divide and conquer
int m2 = m / 2, n2 = n / 2, k2 = k / 2;
c[jIndex + i] = (c[jIndex + i] * beta) + (alpha * s);
}
});
}
else if ((int)transposeA > 111)
if (first)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
double s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (alpha * s) + (c[jIndex + i] * beta);
}
});
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false));
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false));
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
double s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false);
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false);
}
}
}

380
src/Numerics/Algorithms/LinearAlgebra/ManagedLinearAlgebraProvider.Single.cs

@ -451,9 +451,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
/// <param name="c">The c matrix.</param>
public virtual void MatrixMultiplyWithUpdate(Transpose transposeA, Transpose transposeB, float alpha, float[] a, int rowsA, int columnsA, float[] b, int rowsB, int columnsB, float beta, float[] c)
{
// Choose nonsensical values for the number of rows in c; fill them in depending
// on the operations on a and b.
int rowsC;
int m; // The number of rows of matrix op(A) and of the matrix C.
int n; // The number of columns of matrix op(B) and of the matrix C.
int k; // The number of columns of matrix op(A) and the rows of the matrix op(B).
// First check some basic requirement on the parameters of the matrix multiplication.
if (a == null)
@ -478,7 +478,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = rowsB;
k = rowsA;
}
else if ((int)transposeA > 111)
{
@ -492,7 +494,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = columnsA;
m = columnsA;
n = columnsB;
k = rowsA;
}
else if ((int)transposeB > 111)
{
@ -506,7 +510,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = rowsB;
k = columnsA;
}
else
{
@ -520,7 +526,9 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
throw new ArgumentOutOfRangeException();
}
rowsC = rowsA;
m = rowsA;
n = columnsB;
k = columnsA;
}
if (alpha == 0.0 && beta == 0.0)
@ -552,268 +560,162 @@ namespace MathNet.Numerics.Algorithms.LinearAlgebra
bdata = b;
}
if (alpha == 1.0)
if (beta == 0.0f)
{
if (beta == 0.0)
Array.Clear(c, 0, c.Length);
}
else if (beta != 1.0f)
{
Control.LinearAlgebraProvider.ScaleArray(beta, c, c);
}
if (alpha == 0.0f)
{
return;
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, adata, 0, 0, bdata, 0, 0, c, 0, 0, m, n, k, m, n, k, true);
}
/// <summary>
/// Cache-Oblivious Matrix Multiplication
/// </summary>
/// <param name="transposeA">The transpose operation applied to matrix A, i.e. which form op(A) takes.</param>
/// <param name="transposeB">The transpose operation applied to matrix B, i.e. which form op(B) takes.</param>
/// <param name="alpha">The value to scale the matrix A with.</param>
/// <param name="matrixA">The matrix A.</param>
/// <param name="shiftArow">Row-shift of the left matrix</param>
/// <param name="shiftAcol">Column-shift of the left matrix</param>
/// <param name="matrixB">The matrix B.</param>
/// <param name="shiftBrow">Row-shift of the right matrix</param>
/// <param name="shiftBcol">Column-shift of the right matrix</param>
/// <param name="result">The matrix C.</param>
/// <param name="shiftCrow">Row-shift of the result matrix</param>
/// <param name="shiftCcol">Column-shift of the result matrix</param>
/// <param name="m">The number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="n">The number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="k">The number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="constM">The constant number of rows of matrix op(A) and of the matrix C.</param>
/// <param name="constN">The constant number of columns of matrix op(B) and of the matrix C.</param>
/// <param name="constK">The constant number of columns of matrix op(A) and the rows of the matrix op(B).</param>
/// <param name="first">Indicates if this is the first recursion.</param>
private static void CacheObliviousMatrixMultiply(Transpose transposeA, Transpose transposeB, float alpha, float[] matrixA, int shiftArow, int shiftAcol, float[] matrixB, int shiftBrow, int shiftBcol, float[] result, int shiftCrow, int shiftCcol, int m, int n, int k, int constM, int constN, int constK, bool first)
{
if (m + n + k <= Control.ParallelizeOrder)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = s;
}
});
}
else if ((int)transposeA > 111)
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s;
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
float sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jIndex + i] = s;
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
else
}
else if ((int)transposeA > 111)
{
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
float sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
sum += matrixA[(matArowPos * constK) + k1 + shiftAcol] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
c[jcIndex + i] = s;
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
else
else if ((int)transposeB > 111)
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (c[jIndex + i] * beta) + s;
}
});
}
else if ((int)transposeA > 111)
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
}
});
}
else if ((int)transposeB > 111)
{
CommonParallel.For(
0,
rowsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
float sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[((k1 + shiftBrow) * constN) + matBcolPos];
}
c[jIndex + i] = s + (c[jIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
else
}
else
{
for (var m1 = 0; m1 < m; m1++)
{
CommonParallel.For(
0,
columnsB,
j =>
var matArowPos = m1 + shiftArow;
var matCrowPos = m1 + shiftCrow;
for (var n1 = 0; n1 < n; ++n1)
{
var matBcolPos = n1 + shiftBcol;
float sum = 0;
for (var k1 = 0; k1 < k; ++k1)
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
sum += matrixA[((k1 + shiftAcol) * constM) + matArowPos] *
matrixB[(matBcolPos * constK) + k1 + shiftBrow];
}
c[jcIndex + i] = s + (c[jcIndex + i] * beta);
}
});
result[((n1 + shiftCcol) * constM) + matCrowPos] += alpha * sum;
}
}
}
}
else
{
if ((int)transposeA > 111 && (int)transposeB > 111)
{
CommonParallel.For(
0,
columnsA,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsB; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != columnsB; l++)
{
s += adata[iIndex + l] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (c[jIndex + i] * beta) + (alpha * s);
}
});
}
else if ((int)transposeA > 111)
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != columnsA; i++)
{
var iIndex = i * rowsA;
float s = 0;
for (var l = 0; l != rowsA; l++)
{
s += adata[iIndex + l] * bdata[jbIndex + l];
}
// divide and conquer
int m2 = m / 2, n2 = n / 2, k2 = k / 2;
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
}
else if ((int)transposeB > 111)
if (first)
{
CommonParallel.For(
0,
rowsB,
j =>
{
var jIndex = j * rowsC;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[(l * rowsB) + j];
}
c[jIndex + i] = (alpha * s) + (c[jIndex + i] * beta);
}
});
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false));
CommonParallel.Invoke(
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false),
() => CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false));
}
else
{
CommonParallel.For(
0,
columnsB,
j =>
{
var jcIndex = j * rowsC;
var jbIndex = j * rowsB;
for (var i = 0; i != rowsA; i++)
{
float s = 0;
for (var l = 0; l != columnsA; l++)
{
s += adata[(l * rowsA) + i] * bdata[jbIndex + l];
}
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k2, constM, constN, constK, false);
c[jcIndex + i] = (alpha * s) + (c[jcIndex + i] * beta);
}
});
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow, shiftCcol, m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow, shiftCcol + n2, m2, n - n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol, matrixB, shiftBrow, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol, result, shiftCrow + m2, shiftCcol, m - m2, n2, k - k2, constM, constN, constK, false);
CacheObliviousMatrixMultiply(transposeA, transposeB, alpha, matrixA, shiftArow + m2, shiftAcol + k2, matrixB, shiftBrow + k2, shiftBcol + n2, result, shiftCrow + m2, shiftCcol + n2, m - m2, n - n2, k - k2, constM, constN, constK, false);
}
}
}

26
src/Numerics/Control.cs

@ -51,6 +51,11 @@ namespace MathNet.Numerics
/// </summary>
private static int _blockSize = 512;
/// <summary>
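/// Default size threshold for the matrix multiply in the linear algebra provider: sub-problems whose m + n + k is at or below this value are multiplied directly instead of being split further.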
/// </summary>
private static int _parallelizeOrder = 64;
/// <summary>
/// Initializes static members of the Control class.
/// </summary>
@ -135,5 +140,26 @@ namespace MathNet.Numerics
}
}
}
/// <summary>
/// Gets or sets the order of the matrix at which the linear algebra provider
/// starts calculating a matrix multiply in parallel threads.
/// </summary>
/// <value>The order. Default is 64.</value>
/// <remarks>
/// Only values greater than 2 are accepted. Any other value is ignored and the
/// current setting is kept.
/// </remarks>
public static int ParallelizeOrder
{
    get
    {
        return _parallelizeOrder;
    }

    set
    {
        // Validate the incoming value, not the stored field: checking the field
        // would accept any value while the setting is valid, and would lock the
        // property permanently once a value of 2 or less had been stored.
        if (value > 2)
        {
            _parallelizeOrder = value;
        }
    }
}
}
}

Loading…
Cancel
Save