diff --git a/src/Numerics/LinearAlgebra/Complex/DiagonalMatrix.cs b/src/Numerics/LinearAlgebra/Complex/DiagonalMatrix.cs index c89e3e5b..d0e3781d 100644 --- a/src/Numerics/LinearAlgebra/Complex/DiagonalMatrix.cs +++ b/src/Numerics/LinearAlgebra/Complex/DiagonalMatrix.cs @@ -194,39 +194,6 @@ namespace MathNet.Numerics.LinearAlgebra.Complex i => new Complex(distribution.Sample(), distribution.Sample()))); } - /// - /// Adds another matrix to this matrix. - /// - /// The matrix to add to this matrix. - /// The result of the addition. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Add(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Add(other, result); - return result; - } - /// /// Adds another matrix to this matrix. /// @@ -249,39 +216,6 @@ namespace MathNet.Numerics.LinearAlgebra.Complex } } - /// - /// Subtracts another matrix from this matrix. - /// - /// The matrix to subtract. - /// The result of the subtraction. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Subtract(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DenseMatrix(RowCount, ColumnCount); - } - else - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - - Subtract(other, result); - return result; - } - /// /// Subtracts another matrix from this matrix. /// @@ -398,72 +332,44 @@ namespace MathNet.Numerics.LinearAlgebra.Complex /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void Multiply(Matrix other, Matrix result) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (result == null) - { - throw new ArgumentNullException("result"); - } - - if (ColumnCount != other.RowCount) - { - throw DimensionsDontMatch(this, other); - } - - if (result.RowCount != RowCount || result.ColumnCount != other.ColumnCount) - { - throw DimensionsDontMatch(this, result); - } - - var m = other as DiagonalMatrix; - var r = result as DiagonalMatrix; - - if (m == null || r == null) - { - base.Multiply(other, result); - } - else - { - var thisDataCopy = new Complex[r._data.Length]; - var otherDataCopy = new Complex[r._data.Length]; - Array.Copy(_data, thisDataCopy, (r._data.Length > _data.Length) ? _data.Length : r._data.Length); - Array.Copy(m._data, otherDataCopy, (r._data.Length > m._data.Length) ? m._data.Length : r._data.Length); - - Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, r._data); - } - } - - /// - /// Multiplies this matrix with another matrix and returns the result. - /// - /// The matrix to multiply with. - /// If this.Columns != other.Rows. - /// If the other matrix is . - /// The result of multiplication. - public override Matrix Multiply(Matrix other) + protected override void DoMultiply(Matrix other, Matrix result) { - if (other == null) + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) { - throw new ArgumentNullException("other"); + var thisDataCopy = new Complex[diagonalResult._data.Length]; + var otherDataCopy = new Complex[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; } - if (ColumnCount != other.RowCount) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - throw DimensionsDontMatch(this, other); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.RowCount, RowCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } + return; } - var result = other.CreateMatrix(RowCount, other.ColumnCount); - Multiply(other, result); - return result; + base.DoMultiply(other, result); } /// @@ -596,22 +502,88 @@ namespace MathNet.Numerics.LinearAlgebra.Complex /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void TransposeAndMultiply(Matrix other, Matrix result) + protected override void DoTransposeAndMultiply(Matrix other, Matrix result) { - var otherDiagonal = other as DiagonalMatrix; - var resultDiagonal = result as DiagonalMatrix; + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new Complex[diagonalResult._data.Length]; + var otherDataCopy = new Complex[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } + + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.ColumnCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.ColumnCount, RowCount - denseOther.ColumnCount, 0, denseOther.RowCount); + } + int index = 0; + for (int j = 0; j < d; j++) + { + for (int i = 0; i < denseOther.RowCount; i++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + } + return; + } + + base.DoTransposeAndMultiply(other, result); + } - if (otherDiagonal == null || resultDiagonal == null) + /// + /// Multiplies the transpose of this matrix with another matrix and places the results into the result matrix. + /// + /// The matrix to multiply with. + /// The result of the multiplication. + protected override void DoTransposeThisAndMultiply(Matrix other, Matrix result) + { + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) { - base.TransposeAndMultiply(other, result); + var thisDataCopy = new Complex[diagonalResult._data.Length]; + var otherDataCopy = new Complex[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } + + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, ColumnCount); + if (d < ColumnCount) + { + result.ClearSubMatrix(denseOther.RowCount, ColumnCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } return; } - Multiply(otherDiagonal.Transpose(), result); + base.DoTransposeThisAndMultiply(other, result); } /// diff --git a/src/Numerics/LinearAlgebra/Complex32/DiagonalMatrix.cs b/src/Numerics/LinearAlgebra/Complex32/DiagonalMatrix.cs index 2213a10e..91ccf693 100644 --- a/src/Numerics/LinearAlgebra/Complex32/DiagonalMatrix.cs +++ b/src/Numerics/LinearAlgebra/Complex32/DiagonalMatrix.cs @@ -189,39 +189,6 @@ namespace MathNet.Numerics.LinearAlgebra.Complex32 i => new Complex32((float) distribution.Sample(), (float) distribution.Sample()))); } - /// - /// Adds another matrix to this matrix. - /// - /// The matrix to add to this matrix. - /// The result of the addition. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Add(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Add(other, result); - return result; - } - /// /// Adds another matrix to this matrix. /// @@ -244,39 +211,6 @@ namespace MathNet.Numerics.LinearAlgebra.Complex32 } } - /// - /// Subtracts another matrix from this matrix. - /// - /// The matrix to subtract. - /// The result of the subtraction. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Subtract(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Subtract(other, result); - return result; - } - /// /// Subtracts another matrix from this matrix. /// @@ -356,8 +290,6 @@ namespace MathNet.Numerics.LinearAlgebra.Complex32 /// /// The scalar to multiply the matrix with. /// The matrix to store the result of the multiplication. - /// If the result matrix is . - /// If the result matrix's dimensions are not the same as this matrix. protected override void DoMultiply(Complex32 scalar, Matrix result) { if (scalar.IsZero()) @@ -393,72 +325,44 @@ namespace MathNet.Numerics.LinearAlgebra.Complex32 /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void Multiply(Matrix other, Matrix result) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (result == null) - { - throw new ArgumentNullException("result"); - } - - if (ColumnCount != other.RowCount) - { - throw DimensionsDontMatch(this, other); - } - - if (result.RowCount != RowCount || result.ColumnCount != other.ColumnCount) - { - throw DimensionsDontMatch(this, other); - } - - var m = other as DiagonalMatrix; - var r = result as DiagonalMatrix; - - if (m == null || r == null) - { - base.Multiply(other, result); - } - else - { - var thisDataCopy = new Complex32[r._data.Length]; - var otherDataCopy = new Complex32[r._data.Length]; - Array.Copy(_data, thisDataCopy, (r._data.Length > _data.Length) ? _data.Length : r._data.Length); - Array.Copy(m._data, otherDataCopy, (r._data.Length > m._data.Length) ? m._data.Length : r._data.Length); - - Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, r._data); - } - } - - /// - /// Multiplies this matrix with another matrix and returns the result. - /// - /// The matrix to multiply with. - /// If this.Columns != other.Rows. - /// If the other matrix is . - /// The result of multiplication. - public override Matrix Multiply(Matrix other) + protected override void DoMultiply(Matrix other, Matrix result) { - if (other == null) + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) { - throw new ArgumentNullException("other"); + var thisDataCopy = new Complex32[diagonalResult._data.Length]; + var otherDataCopy = new Complex32[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; } - if (ColumnCount != other.RowCount) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - throw DimensionsDontMatch(this, other); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.RowCount, RowCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } + return; } - var result = other.CreateMatrix(RowCount, other.ColumnCount); - Multiply(other, result); - return result; + base.DoMultiply(other, result); } /// @@ -591,22 +495,88 @@ namespace MathNet.Numerics.LinearAlgebra.Complex32 /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void TransposeAndMultiply(Matrix other, Matrix result) + protected override void DoTransposeAndMultiply(Matrix other, Matrix result) { - var otherDiagonal = other as DiagonalMatrix; - var resultDiagonal = result as DiagonalMatrix; + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new Complex32[diagonalResult._data.Length]; + var otherDataCopy = new Complex32[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } - if (otherDiagonal == null || resultDiagonal == null) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - base.TransposeAndMultiply(other, result); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.ColumnCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.ColumnCount, RowCount - denseOther.ColumnCount, 0, denseOther.RowCount); + } + int index = 0; + for (int j = 0; j < d; j++) + { + for (int i = 0; i < denseOther.RowCount; i++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + } + return; + } + + base.DoTransposeAndMultiply(other, result); + } + + /// + /// Multiplies the transpose of this matrix with another matrix and places the results into the result matrix. + /// + /// The matrix to multiply with. + /// The result of the multiplication. + protected override void DoTransposeThisAndMultiply(Matrix other, Matrix result) + { + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new Complex32[diagonalResult._data.Length]; + var otherDataCopy = new Complex32[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } + + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, ColumnCount); + if (d < ColumnCount) + { + result.ClearSubMatrix(denseOther.RowCount, ColumnCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } return; } - Multiply(otherDiagonal.Transpose(), result); + base.DoTransposeThisAndMultiply(other, result); } /// diff --git a/src/Numerics/LinearAlgebra/Double/DiagonalMatrix.cs b/src/Numerics/LinearAlgebra/Double/DiagonalMatrix.cs index 872d85b3..2e0b8216 100644 --- a/src/Numerics/LinearAlgebra/Double/DiagonalMatrix.cs +++ b/src/Numerics/LinearAlgebra/Double/DiagonalMatrix.cs @@ -188,39 +188,6 @@ namespace MathNet.Numerics.LinearAlgebra.Double i => distribution.Sample())); } - /// - /// Adds another matrix to this matrix. - /// - /// The matrix to add to this matrix. - /// The result of the addition. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Add(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Add(other, result); - return result; - } - /// /// Adds another matrix to this matrix. /// @@ -243,39 +210,6 @@ namespace MathNet.Numerics.LinearAlgebra.Double } } - /// - /// Subtracts another matrix from this matrix. - /// - /// The matrix to subtract. - /// The result of the subtraction. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Subtract(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Subtract(other, result); - return result; - } - /// /// Subtracts another matrix from this matrix. /// @@ -392,67 +326,44 @@ namespace MathNet.Numerics.LinearAlgebra.Double /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void Multiply(Matrix other, Matrix result) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (result == null) - { - throw new ArgumentNullException("result"); - } - - if (ColumnCount != other.RowCount || result.RowCount != RowCount || result.ColumnCount != other.ColumnCount) - { - throw DimensionsDontMatch(this, other, result); - } - - var m = other as DiagonalMatrix; - var r = result as DiagonalMatrix; - - if (m == null || r == null) - { - base.Multiply(other, result); - } - else - { - var thisDataCopy = new double[r._data.Length]; - var otherDataCopy = new double[r._data.Length]; - Buffer.BlockCopy(_data, 0, thisDataCopy, 0, (r._data.Length > _data.Length) ? _data.Length * Constants.SizeOfDouble : r._data.Length * Constants.SizeOfDouble); - Buffer.BlockCopy(m._data, 0, otherDataCopy, 0, (r._data.Length > m._data.Length) ? m._data.Length * Constants.SizeOfDouble : r._data.Length * Constants.SizeOfDouble); - - Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, r._data); - } - } - - /// - /// Multiplies this matrix with another matrix and returns the result. - /// - /// The matrix to multiply with. - /// If this.Columns != other.Rows. - /// If the other matrix is . - /// The result of multiplication. - public override Matrix Multiply(Matrix other) + protected override void DoMultiply(Matrix other, Matrix result) { - if (other == null) + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) { - throw new ArgumentNullException("other"); + var thisDataCopy = new double[diagonalResult._data.Length]; + var otherDataCopy = new double[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; } - if (ColumnCount != other.RowCount) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - throw DimensionsDontMatch(this, other); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.RowCount, RowCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } + return; } - var result = other.CreateMatrix(RowCount, other.ColumnCount); - Multiply(other, result); - return result; + base.DoMultiply(other, result); } /// @@ -585,22 +496,88 @@ namespace MathNet.Numerics.LinearAlgebra.Double /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void TransposeAndMultiply(Matrix other, Matrix result) + protected override void DoTransposeAndMultiply(Matrix other, Matrix result) + { + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new double[diagonalResult._data.Length]; + var otherDataCopy = new double[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } + + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.ColumnCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.ColumnCount, RowCount - denseOther.ColumnCount, 0, denseOther.RowCount); + } + int index = 0; + for (int j = 0; j < d; j++) + { + for (int i = 0; i < denseOther.RowCount; i++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + } + return; + } + + base.DoTransposeAndMultiply(other, result); + } + + /// + /// Multiplies the transpose of this matrix with another matrix and places the results into the result matrix. + /// + /// The matrix to multiply with. + /// The result of the multiplication. + protected override void DoTransposeThisAndMultiply(Matrix other, Matrix result) { - var otherDiagonal = other as DiagonalMatrix; - var resultDiagonal = result as DiagonalMatrix; + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new double[diagonalResult._data.Length]; + var otherDataCopy = new double[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } - if (otherDiagonal == null || resultDiagonal == null) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - base.TransposeAndMultiply(other, result); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, ColumnCount); + if (d < ColumnCount) + { + result.ClearSubMatrix(denseOther.RowCount, ColumnCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } return; } - Multiply(otherDiagonal.Transpose(), result); + base.DoTransposeThisAndMultiply(other, result); } /// diff --git a/src/Numerics/LinearAlgebra/Matrix.Arithmetic.cs b/src/Numerics/LinearAlgebra/Matrix.Arithmetic.cs index bb5fcebf..bd36f09a 100644 --- a/src/Numerics/LinearAlgebra/Matrix.Arithmetic.cs +++ b/src/Numerics/LinearAlgebra/Matrix.Arithmetic.cs @@ -204,7 +204,7 @@ namespace MathNet.Numerics.LinearAlgebra return Clone(); } - var result = CreateMatrix(RowCount, ColumnCount); + var result = Build.SameType(this, RowCount, ColumnCount); DoAdd(scalar, result); return result; } @@ -237,14 +237,14 @@ namespace MathNet.Numerics.LinearAlgebra /// The matrix to add to this matrix. /// The result of the addition. /// If the two matrices don't have the same dimensions. - public virtual Matrix Add(Matrix other) + public Matrix Add(Matrix other) { if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) { throw DimensionsDontMatch(this, other); } - var result = CreateMatrix(RowCount, ColumnCount); + var result = Build.SameType(this, other, RowCount, ColumnCount); DoAdd(other, result); return result; } @@ -282,7 +282,7 @@ namespace MathNet.Numerics.LinearAlgebra return Clone(); } - var result = CreateMatrix(RowCount, ColumnCount); + var result = Build.SameType(this, RowCount, ColumnCount); DoSubtract(scalar, result); return result; } @@ -316,7 +316,7 @@ namespace MathNet.Numerics.LinearAlgebra /// A new matrix containing the subtraction of the scalar and this matrix. public Matrix SubtractFrom(T scalar) { - var result = CreateMatrix(RowCount, ColumnCount); + var result = Build.SameType(this, RowCount, ColumnCount); DoSubtractFrom(scalar, result); return result; } @@ -343,14 +343,14 @@ namespace MathNet.Numerics.LinearAlgebra /// The matrix to subtract. /// The result of the subtraction. /// If the two matrices don't have the same dimensions. - public virtual Matrix Subtract(Matrix other) + public Matrix Subtract(Matrix other) { if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) { throw DimensionsDontMatch(this, other); } - var result = CreateMatrix(RowCount, ColumnCount); + var result = Build.SameType(this, other, RowCount, ColumnCount); DoSubtract(other, result); return result; } @@ -633,7 +633,7 @@ namespace MathNet.Numerics.LinearAlgebra /// The result of the multiplication. /// If this.Columns != other.Rows. /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public virtual void Multiply(Matrix other, Matrix result) + public void Multiply(Matrix other, Matrix result) { if (ColumnCount != other.RowCount || result.RowCount != RowCount || result.ColumnCount != other.ColumnCount) { @@ -642,7 +642,7 @@ namespace MathNet.Numerics.LinearAlgebra if (ReferenceEquals(this, result) || ReferenceEquals(other, result)) { - var tmp = result.CreateMatrix(result.RowCount, result.ColumnCount); + var tmp = Build.SameType(result, result.RowCount, result.ColumnCount); DoMultiply(other, tmp); tmp.CopyTo(result); } @@ -658,14 +658,14 @@ namespace MathNet.Numerics.LinearAlgebra /// The matrix to multiply with. /// If this.Columns != other.Rows. /// The result of the multiplication. - public virtual Matrix Multiply(Matrix other) + public Matrix Multiply(Matrix other) { if (ColumnCount != other.RowCount) { throw DimensionsDontMatch(this, other); } - var result = CreateMatrix(RowCount, other.ColumnCount); + var result = Build.SameType(this, other, RowCount, other.ColumnCount); DoMultiply(other, result); return result; } @@ -686,7 +686,7 @@ namespace MathNet.Numerics.LinearAlgebra if (ReferenceEquals(this, result) || ReferenceEquals(other, result)) { - var tmp = result.CreateMatrix(result.RowCount, result.ColumnCount); + var tmp = Build.SameType(result, result.RowCount, result.ColumnCount); DoTransposeAndMultiply(other, tmp); tmp.CopyTo(result); } @@ -709,7 +709,7 @@ namespace MathNet.Numerics.LinearAlgebra throw DimensionsDontMatch(this, other); } - var result = CreateMatrix(RowCount, other.RowCount); + var result = Build.SameType(this, other, RowCount, other.RowCount); DoTransposeAndMultiply(other, result); return result; } @@ -727,9 +727,9 @@ namespace MathNet.Numerics.LinearAlgebra throw DimensionsDontMatch(this, rightSide, "rightSide"); } - var ret = CreateVector(ColumnCount); - DoTransposeThisAndMultiply(rightSide, ret); - return ret; + var result = CreateVector(ColumnCount); + DoTransposeThisAndMultiply(rightSide, result); + return result; } /// @@ -779,7 +779,7 @@ namespace MathNet.Numerics.LinearAlgebra if (ReferenceEquals(this, result) || ReferenceEquals(other, result)) { - var tmp = result.CreateMatrix(result.RowCount, result.ColumnCount); + var tmp = Build.SameType(result, result.RowCount, result.ColumnCount); DoTransposeThisAndMultiply(other, tmp); tmp.CopyTo(result); } @@ -802,7 +802,7 @@ namespace MathNet.Numerics.LinearAlgebra throw DimensionsDontMatch(this, other); } - var result = CreateMatrix(ColumnCount, other.ColumnCount); + var result = Build.SameType(this, other, ColumnCount, other.ColumnCount); DoTransposeThisAndMultiply(other, result); return result; } diff --git a/src/Numerics/LinearAlgebra/Single/DiagonalMatrix.cs b/src/Numerics/LinearAlgebra/Single/DiagonalMatrix.cs index 42d5b944..381a3f41 100644 --- a/src/Numerics/LinearAlgebra/Single/DiagonalMatrix.cs +++ b/src/Numerics/LinearAlgebra/Single/DiagonalMatrix.cs @@ -188,39 +188,6 @@ namespace MathNet.Numerics.LinearAlgebra.Single i => (float) distribution.Sample())); } - /// - /// Adds another matrix to this matrix. - /// - /// The matrix to add to this matrix. - /// The result of the addition. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Add(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Add(other, result); - return result; - } - /// /// Adds another matrix to this matrix. /// @@ -243,39 +210,6 @@ namespace MathNet.Numerics.LinearAlgebra.Single } } - /// - /// Subtracts another matrix from this matrix. - /// - /// The matrix to subtract. - /// The result of the subtraction. - /// If the other matrix is . - /// If the two matrices don't have the same dimensions. - public override Matrix Subtract(Matrix other) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (other.RowCount != RowCount || other.ColumnCount != ColumnCount) - { - throw DimensionsDontMatch(this, other, "other"); - } - - Matrix result; - if (other is DiagonalMatrix) - { - result = new DiagonalMatrix(RowCount, ColumnCount); - } - else - { - result = new DenseMatrix(RowCount, ColumnCount); - } - - Subtract(other, result); - return result; - } - /// /// Subtracts another matrix from this matrix. /// @@ -392,67 +326,44 @@ namespace MathNet.Numerics.LinearAlgebra.Single /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void Multiply(Matrix other, Matrix result) - { - if (other == null) - { - throw new ArgumentNullException("other"); - } - - if (result == null) - { - throw new ArgumentNullException("result"); - } - - if (ColumnCount != other.RowCount || result.RowCount != RowCount || result.ColumnCount != other.ColumnCount) - { - throw DimensionsDontMatch(this, other, result); - } - - var m = other as DiagonalMatrix; - var r = result as DiagonalMatrix; - - if (m == null || r == null) - { - base.Multiply(other, result); - } - else - { - var thisDataCopy = new float[r._data.Length]; - var otherDataCopy = new float[r._data.Length]; - Buffer.BlockCopy(_data, 0, thisDataCopy, 0, (r._data.Length > _data.Length) ? _data.Length * Constants.SizeOfFloat : r._data.Length * Constants.SizeOfFloat); - Buffer.BlockCopy(m._data, 0, otherDataCopy, 0, (r._data.Length > m._data.Length) ? m._data.Length * Constants.SizeOfFloat : r._data.Length * Constants.SizeOfFloat); - - Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, r._data); - } - } - - /// - /// Multiplies this matrix with another matrix and returns the result. - /// - /// The matrix to multiply with. - /// If this.Columns != other.Rows. - /// If the other matrix is . - /// The result of multiplication. - public override Matrix Multiply(Matrix other) + protected override void DoMultiply(Matrix other, Matrix result) { - if (other == null) + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) { - throw new ArgumentNullException("other"); + var thisDataCopy = new float[diagonalResult._data.Length]; + var otherDataCopy = new float[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; } - if (ColumnCount != other.RowCount) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - throw DimensionsDontMatch(this, other); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.RowCount, RowCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } + return; } - var result = other.CreateMatrix(RowCount, other.ColumnCount); - Multiply(other, result); - return result; + base.DoMultiply(other, result); } /// @@ -585,22 +496,88 @@ namespace MathNet.Numerics.LinearAlgebra.Single /// /// The matrix to multiply with. /// The result of the multiplication. - /// If the other matrix is . - /// If the result matrix is . - /// If this.Columns != other.Rows. - /// If the result matrix's dimensions are not the this.Rows x other.Columns. - public override void TransposeAndMultiply(Matrix other, Matrix result) + protected override void DoTransposeAndMultiply(Matrix other, Matrix result) + { + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new float[diagonalResult._data.Length]; + var otherDataCopy = new float[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } + + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.ColumnCount, RowCount); + if (d < RowCount) + { + result.ClearSubMatrix(denseOther.ColumnCount, RowCount - denseOther.ColumnCount, 0, denseOther.RowCount); + } + int index = 0; + for (int j = 0; j < d; j++) + { + for (int i = 0; i < denseOther.RowCount; i++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + } + return; + } + + base.DoTransposeAndMultiply(other, result); + } + + /// + /// Multiplies the transpose of this matrix with another matrix and places the results into the result matrix. + /// + /// The matrix to multiply with. + /// The result of the multiplication. + protected override void DoTransposeThisAndMultiply(Matrix other, Matrix result) { - var otherDiagonal = other as DiagonalMatrix; - var resultDiagonal = result as DiagonalMatrix; + var diagonalOther = other as DiagonalMatrix; + var diagonalResult = result as DiagonalMatrix; + if (diagonalOther != null && diagonalResult != null) + { + var thisDataCopy = new float[diagonalResult._data.Length]; + var otherDataCopy = new float[diagonalResult._data.Length]; + Array.Copy(_data, thisDataCopy, (diagonalResult._data.Length > _data.Length) ? _data.Length : diagonalResult._data.Length); + Array.Copy(diagonalOther._data, otherDataCopy, (diagonalResult._data.Length > diagonalOther._data.Length) ? diagonalOther._data.Length : diagonalResult._data.Length); + Control.LinearAlgebraProvider.PointWiseMultiplyArrays(thisDataCopy, otherDataCopy, diagonalResult._data); + return; + } - if (otherDiagonal == null || resultDiagonal == null) + var denseOther = other.Storage as DenseColumnMajorMatrixStorage; + if (denseOther != null) { - base.TransposeAndMultiply(other, result); + var dense = denseOther.Data; + var diagonal = _data; + var d = Math.Min(denseOther.RowCount, ColumnCount); + if (d < ColumnCount) + { + result.ClearSubMatrix(denseOther.RowCount, ColumnCount - denseOther.RowCount, 0, denseOther.ColumnCount); + } + int index = 0; + for (int i = 0; i < denseOther.ColumnCount; i++) + { + for (int j = 0; j < d; j++) + { + result.At(j, i, dense[index]*diagonal[j]); + index++; + } + index += (denseOther.RowCount - d); + } return; } - Multiply(otherDiagonal.Transpose(), result); + base.DoTransposeThisAndMultiply(other, result); } /// diff --git a/src/UnitTests/LinearAlgebraTests/Complex/DiagonalMatrixTests.cs b/src/UnitTests/LinearAlgebraTests/Complex/DiagonalMatrixTests.cs index 55de88b8..928e5195 100644 --- a/src/UnitTests/LinearAlgebraTests/Complex/DiagonalMatrixTests.cs +++ b/src/UnitTests/LinearAlgebraTests/Complex/DiagonalMatrixTests.cs @@ -436,5 +436,56 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests.Complex Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 10, 2d)).Equals(tall.Transpose().Multiply(2d).Append(Matrix.Build.Dense(3, 2)))); Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 2, 2d)).Equals(tall.Transpose().Multiply(2d).SubMatrix(0, 3, 0, 2))); } + + [Test] + public void DiagonalDenseMatrixMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2d)*wide).Equals(wide.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(5, 3, 2d)*wide).Equals(wide.Multiply(2d).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 3, 2d)*wide).Equals(wide.Multiply(2d).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2d)*tall).Equals(tall.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(10, 8, 2d)*tall).Equals(tall.Multiply(2d).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 8, 2d)*tall).Equals(tall.Multiply(2d).SubMatrix(0, 2, 0, 3))); + } + + [Test] + public void DiagonalDenseMatrixTransposeAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(3).Multiply(2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(5, 3, 2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).Append(Matrix.Build.Dense(8, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 3, 2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).SubMatrix(0, 8, 0, 2).Transpose())); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(8).Multiply(2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(10, 8, 2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).Append(Matrix.Build.Dense(3, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 8, 2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).SubMatrix(0, 3, 0, 2).Transpose())); + } + + [Test] + public void DiagonalDenseMatrixTransposeThisAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 5, 2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 2, 2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 10, 2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 2, 2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d).SubMatrix(0, 2, 0, 3))); + } } } diff --git a/src/UnitTests/LinearAlgebraTests/Complex32/DiagonalMatrixTests.cs b/src/UnitTests/LinearAlgebraTests/Complex32/DiagonalMatrixTests.cs index c4423a74..9afe6de7 100644 --- a/src/UnitTests/LinearAlgebraTests/Complex32/DiagonalMatrixTests.cs +++ b/src/UnitTests/LinearAlgebraTests/Complex32/DiagonalMatrixTests.cs @@ -432,5 +432,56 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests.Complex32 Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 10, 2f)).Equals(tall.Transpose().Multiply(2f).Append(Matrix.Build.Dense(3, 2)))); Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 2, 2f)).Equals(tall.Transpose().Multiply(2f).SubMatrix(0, 3, 0, 2))); } + + [Test] + public void DiagonalDenseMatrixMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2f)*wide).Equals(wide.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(5, 3, 2f)*wide).Equals(wide.Multiply(2f).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 3, 2f)*wide).Equals(wide.Multiply(2f).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2f)*tall).Equals(tall.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(10, 8, 2f)*tall).Equals(tall.Multiply(2f).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 8, 2f)*tall).Equals(tall.Multiply(2f).SubMatrix(0, 2, 0, 3))); + } + + [Test] + public void DiagonalDenseMatrixTransposeAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(3).Multiply(2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(5, 3, 2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).Append(Matrix.Build.Dense(8, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 3, 2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).SubMatrix(0, 8, 0, 2).Transpose())); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(8).Multiply(2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(10, 8, 2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).Append(Matrix.Build.Dense(3, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 8, 2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).SubMatrix(0, 3, 0, 2).Transpose())); + } + + [Test] + public void DiagonalDenseMatrixTransposeThisAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 5, 2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 2, 2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 10, 2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 2, 2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f).SubMatrix(0, 2, 0, 3))); + } } } diff --git a/src/UnitTests/LinearAlgebraTests/Double/DiagonalMatrixTests.cs b/src/UnitTests/LinearAlgebraTests/Double/DiagonalMatrixTests.cs index bb27d8b5..7bf5b72a 100644 --- a/src/UnitTests/LinearAlgebraTests/Double/DiagonalMatrixTests.cs +++ b/src/UnitTests/LinearAlgebraTests/Double/DiagonalMatrixTests.cs @@ -463,5 +463,56 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests.Double Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 10, 2d)).Equals(tall.Transpose().Multiply(2d).Append(Matrix.Build.Dense(3, 2)))); Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 2, 2d)).Equals(tall.Transpose().Multiply(2d).SubMatrix(0, 3, 0, 2))); } + + [Test] + public void DiagonalDenseMatrixMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2d)*wide).Equals(wide.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(5, 3, 2d)*wide).Equals(wide.Multiply(2d).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 3, 2d)*wide).Equals(wide.Multiply(2d).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2d)*tall).Equals(tall.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(10, 8, 2d)*tall).Equals(tall.Multiply(2d).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 8, 2d)*tall).Equals(tall.Multiply(2d).SubMatrix(0, 2, 0, 3))); + } + + [Test] + public void DiagonalDenseMatrixTransposeAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(3).Multiply(2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(5, 3, 2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).Append(Matrix.Build.Dense(8, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 3, 2d).TransposeAndMultiply(tall).Equals(tall.Multiply(2d).SubMatrix(0, 8, 0, 2).Transpose())); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(8).Multiply(2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(10, 8, 2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).Append(Matrix.Build.Dense(3, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 8, 2d).TransposeAndMultiply(wide).Equals(wide.Multiply(2d).SubMatrix(0, 3, 0, 2).Transpose())); + } + + [Test] + public void DiagonalDenseMatrixTransposeThisAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 5, 2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 2, 2d).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2d).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 10, 2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 2, 2d).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2d).SubMatrix(0, 2, 0, 3))); + } } } diff --git a/src/UnitTests/LinearAlgebraTests/Single/DiagonalMatrixTests.cs b/src/UnitTests/LinearAlgebraTests/Single/DiagonalMatrixTests.cs index 6c253abd..118cc55e 100644 --- a/src/UnitTests/LinearAlgebraTests/Single/DiagonalMatrixTests.cs +++ b/src/UnitTests/LinearAlgebraTests/Single/DiagonalMatrixTests.cs @@ -430,5 +430,56 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests.Single Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 10, 2f)).Equals(tall.Transpose().Multiply(2f).Append(Matrix.Build.Dense(3, 2)))); Assert.IsTrue(tall.TransposeThisAndMultiply(Matrix.Build.Diagonal(8, 2, 2f)).Equals(tall.Transpose().Multiply(2f).SubMatrix(0, 3, 0, 2))); } + + [Test] + public void DiagonalDenseMatrixMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2f)*wide).Equals(wide.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(5, 3, 2f)*wide).Equals(wide.Multiply(2f).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 3, 2f)*wide).Equals(wide.Multiply(2f).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2f)*tall).Equals(tall.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(10, 8, 2f)*tall).Equals(tall.Multiply(2f).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(2, 8, 2f)*tall).Equals(tall.Multiply(2f).SubMatrix(0, 2, 0, 3))); + } + + [Test] + public void DiagonalDenseMatrixTransposeAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(3).Multiply(2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(5, 3, 2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).Append(Matrix.Build.Dense(8, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 3, 2f).TransposeAndMultiply(tall).Equals(tall.Multiply(2f).SubMatrix(0, 8, 0, 2).Transpose())); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue(Matrix.Build.DiagonalIdentity(8).Multiply(2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(10, 8, 2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).Append(Matrix.Build.Dense(3, 2)).Transpose())); + Assert.IsTrue(Matrix.Build.Diagonal(2, 8, 2f).TransposeAndMultiply(wide).Equals(wide.Multiply(2f).SubMatrix(0, 3, 0, 2).Transpose())); + } + + [Test] + public void DiagonalDenseMatrixTransposeThisAndMultiply() + { + var dist = new ContinuousUniform(-1.0, 1.0, new MersenneTwister()); + Assert.IsInstanceOf(Matrix.Build.DiagonalIdentity(3, 3)); + + var wide = Matrix.Build.Random(3, 8, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(3).Multiply(2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 5, 2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f).Stack(Matrix.Build.Dense(2, 8)))); + Assert.IsTrue((Matrix.Build.Diagonal(3, 2, 2f).TransposeThisAndMultiply(wide)).Equals(wide.Multiply(2f).SubMatrix(0, 2, 0, 8))); + + var tall = Matrix.Build.Random(8, 3, dist); + Assert.IsTrue((Matrix.Build.DiagonalIdentity(8).Multiply(2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 10, 2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f).Stack(Matrix.Build.Dense(2, 3)))); + Assert.IsTrue((Matrix.Build.Diagonal(8, 2, 2f).TransposeThisAndMultiply(tall)).Equals(tall.Multiply(2f).SubMatrix(0, 2, 0, 3))); + } } }