diff --git a/src/Numerics/LinearAlgebra/Matrix.cs b/src/Numerics/LinearAlgebra/Matrix.cs
index cf8b7a7d..223bcf96 100644
--- a/src/Numerics/LinearAlgebra/Matrix.cs
+++ b/src/Numerics/LinearAlgebra/Matrix.cs
@@ -1735,5 +1735,47 @@ namespace MathNet.Numerics.LinearAlgebra
{
return EnumerateColumns().Aggregate(f);
}
+
+ ///
+ /// Applies a function to update the status with each value pair of two matrices and returns the resulting status.
+ ///
+ public TState Fold2(Func f, TState state, Matrix other, Zeros zeros = Zeros.AllowSkip)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ return Storage.Fold2(other.Storage, f, state, zeros);
+ }
+
+ public Tuple Find(Func predicate, Zeros zeros = Zeros.AllowSkip)
+ {
+ return Storage.Find(predicate, zeros);
+ }
+
+ public Tuple Find2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ return Storage.Find2(other.Storage, predicate, zeros);
+ }
+
+ public bool Exists(Func predicate, Zeros zeros = Zeros.AllowSkip)
+ {
+ return Storage.Find(predicate, zeros) != null;
+ }
+
+ public bool Exists2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ return Storage.Find2(other.Storage, predicate, zeros) != null;
+ }
+
+ public bool ForAll(Func predicate, Zeros zeros = Zeros.AllowSkip)
+ {
+ return Storage.Find(x => !predicate(x), zeros) == null;
+ }
+
+ public bool ForAll2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ return Storage.Find2(other.Storage, (x, y) => !predicate(x, y), zeros) == null;
+ }
}
}
diff --git a/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs
index d6b8e5c1..265d12e1 100644
--- a/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs
+++ b/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs
@@ -907,5 +907,65 @@ namespace MathNet.Numerics.LinearAlgebra.Storage
target[j] = finalize(s, RowCount);
}
}
+
+ internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros)
+ {
+ var denseOther = other as DenseColumnMajorMatrixStorage;
+ if (denseOther != null)
+ {
+ TOther[] otherData = denseOther.Data;
+ for (int i = 0; i < Data.Length; i++)
+ {
+ state = f(state, Data[i], otherData[i]);
+ }
+ return state;
+ }
+
+ var diagonalOther = other as DiagonalMatrixStorage;
+ if (diagonalOther != null)
+ {
+ TOther[] otherData = diagonalOther.Data;
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+ int k = 0;
+ for (int j = 0; j < ColumnCount; j++)
+ {
+ for (int i = 0; i < RowCount; i++)
+ {
+ state = f(state, Data[k], i == j ? otherData[i] : otherZero);
+ k++;
+ }
+ }
+ return state;
+ }
+
+ var sparseOther = other as SparseCompressedRowMatrixStorage;
+ if (sparseOther != null)
+ {
+ int[] otherRowPointers = sparseOther.RowPointers;
+ int[] otherColumnIndices = sparseOther.ColumnIndices;
+ TOther[] otherValues = sparseOther.Values;
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+ int k = 0;
+ for (int row = 0; row < RowCount; row++)
+ {
+ for (int col = 0; col < ColumnCount; col++)
+ {
+ if (k < otherRowPointers[row + 1] && otherColumnIndices[k] == col)
+ {
+ state = f(state, Data[col*RowCount + row], otherValues[k++]);
+ }
+ else
+ {
+ state = f(state, Data[col*RowCount + row], otherZero);
+ }
+ }
+ }
+ return state;
+ }
+
+ // FALL BACK
+
+ return base.Fold2Unchecked(other, f, state, zeros);
+ }
}
}
diff --git a/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs
index 891e5249..8ec61019 100644
--- a/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs
+++ b/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs
@@ -1087,5 +1087,107 @@ namespace MathNet.Numerics.LinearAlgebra.Storage
}
}
}
+
+ internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros)
+ {
+ var denseOther = other as DenseColumnMajorMatrixStorage;
+ if (denseOther != null)
+ {
+ TOther[] otherData = denseOther.Data;
+ int k = 0;
+ for (int j = 0; j < ColumnCount; j++)
+ {
+ for (int i = 0; i < RowCount; i++)
+ {
+ state = f(state, i == j ? Data[i] : Zero, otherData[k]);
+ k++;
+ }
+ }
+ return state;
+ }
+
+ var diagonalOther = other as DiagonalMatrixStorage;
+ if (diagonalOther != null)
+ {
+ TOther[] otherData = diagonalOther.Data;
+ for (int i = 0; i < Data.Length; i++)
+ {
+ state = f(state, Data[i], otherData[i]);
+ }
+
+ // Do we really need to do this?
+ if (zeros == Zeros.Include)
+ {
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+ int count = RowCount*ColumnCount - Data.Length;
+ for (int i = 0; i < count; i++)
+ {
+ state = f(state, Zero, otherZero);
+ }
+ }
+
+ return state;
+ }
+
+ var sparseOther = other as SparseCompressedRowMatrixStorage;
+ if (sparseOther != null)
+ {
+ int[] otherRowPointers = sparseOther.RowPointers;
+ int[] otherColumnIndices = sparseOther.ColumnIndices;
+ TOther[] otherValues = sparseOther.Values;
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+
+ if (zeros == Zeros.Include)
+ {
+ int k = 0;
+ for (int row = 0; row < RowCount; row++)
+ {
+ for (int col = 0; col < ColumnCount; col++)
+ {
+ if (k < otherRowPointers[row + 1] && otherColumnIndices[k] == col)
+ {
+ state = f(state, row == col ? Data[row] : Zero, otherValues[k++]);
+ }
+ else
+ {
+ state = f(state, row == col ? Data[row] : Zero, otherZero);
+ }
+ }
+ }
+ return state;
+ }
+
+ for (int row = 0; row < RowCount; row++)
+ {
+ bool diagonal = false;
+
+ var startIndex = otherRowPointers[row];
+ var endIndex = otherRowPointers[row + 1];
+ for (var j = startIndex; j < endIndex; j++)
+ {
+ if (otherColumnIndices[j] == row)
+ {
+ diagonal = true;
+ state = f(state, Data[row], otherValues[j]);
+ }
+ else
+ {
+ state = f(state, Zero, otherValues[j]);
+ }
+ }
+
+ if (!diagonal && row < ColumnCount)
+ {
+ state = f(state, Data[row], otherZero);
+ }
+ }
+
+ return state;
+ }
+
+ // FALL BACK
+
+ return base.Fold2Unchecked(other, f, state, zeros);
+ }
}
}
diff --git a/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs
index 7d83e0f6..778aca5e 100644
--- a/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs
+++ b/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs
@@ -871,5 +871,36 @@ namespace MathNet.Numerics.LinearAlgebra.Storage
target[j] = finalize(s, RowCount);
}
}
+
+ public TState Fold2(MatrixStorage other, Func f, TState state, Zeros zeros)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ if (other == null)
+ {
+ throw new ArgumentNullException("other");
+ }
+
+ if (RowCount != other.RowCount || ColumnCount != other.ColumnCount)
+ {
+ var message = string.Format(Resources.ArgumentMatrixDimensions2, RowCount + "x" + ColumnCount, other.RowCount + "x" + other.ColumnCount);
+ throw new ArgumentException(message, "other");
+ }
+
+ return Fold2Unchecked(other, f, state, zeros);
+ }
+
+ internal virtual TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros)
+ where TOther : struct, IEquatable, IFormattable
+ {
+ for (int i = 0; i < RowCount; i++)
+ {
+ for (int j = 0; j < ColumnCount; j++)
+ {
+ state = f(state, At(i, j), other.At(i, j));
+ }
+ }
+
+ return state;
+ }
}
}
diff --git a/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs
index 9e037bab..83e3cd92 100644
--- a/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs
+++ b/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs
@@ -2065,5 +2065,127 @@ namespace MathNet.Numerics.LinearAlgebra.Storage
}
}
}
+
+ internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros)
+ {
+ var denseOther = other as DenseColumnMajorMatrixStorage;
+ if (denseOther != null)
+ {
+ TOther[] otherData = denseOther.Data;
+ int k = 0;
+ for (int row = 0; row < RowCount; row++)
+ {
+ for (int col = 0; col < ColumnCount; col++)
+ {
+ bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col;
+ state = f(state, available ? Values[k++] : Zero, otherData[col*RowCount + row]);
+ }
+ }
+ return state;
+ }
+
+ var diagonalOther = other as DiagonalMatrixStorage;
+ if (diagonalOther != null)
+ {
+ TOther[] otherData = diagonalOther.Data;
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+
+ if (zeros == Zeros.Include)
+ {
+ int k = 0;
+ for (int row = 0; row < RowCount; row++)
+ {
+ for (int col = 0; col < ColumnCount; col++)
+ {
+ bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col;
+ state = f(state, available ? Values[k++] : Zero, row == col ? otherData[row] : otherZero);
+ }
+ }
+ return state;
+ }
+
+ for (int row = 0; row < RowCount; row++)
+ {
+ bool diagonal = false;
+
+ var startIndex = RowPointers[row];
+ var endIndex = RowPointers[row + 1];
+ for (var j = startIndex; j < endIndex; j++)
+ {
+ if (ColumnIndices[j] == row)
+ {
+ diagonal = true;
+ state = f(state, Values[j], otherData[row]);
+ }
+ else
+ {
+ state = f(state, Values[j], otherZero);
+ }
+ }
+
+ if (!diagonal && row < ColumnCount)
+ {
+ state = f(state, Zero, otherData[row]);
+ }
+ }
+ return state;
+ }
+
+ var sparseOther = other as SparseCompressedRowMatrixStorage;
+ if (sparseOther != null)
+ {
+ int[] otherRowPointers = sparseOther.RowPointers;
+ int[] otherColumnIndices = sparseOther.ColumnIndices;
+ TOther[] otherValues = sparseOther.Values;
+ TOther otherZero = BuilderInstance.Matrix.Zero;
+
+ if (zeros == Zeros.Include)
+ {
+ int k = 0, otherk = 0;
+ for (int row = 0; row < RowCount; row++)
+ {
+ for (int col = 0; col < ColumnCount; col++)
+ {
+ bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col;
+ bool otherAvailable = otherk < otherRowPointers[row + 1] && otherColumnIndices[otherk] == col;
+ state = f(state, available ? Values[k++] : Zero, otherAvailable ? otherValues[otherk++] : otherZero);
+ }
+ }
+ return state;
+ }
+
+ for (int row = 0; row < RowCount; row++)
+ {
+ var startIndex = RowPointers[row];
+ var endIndex = RowPointers[row + 1];
+ var otherStartIndex = otherRowPointers[row];
+ var otherEndIndex = otherRowPointers[row + 1];
+
+ var j1 = startIndex;
+ var j2 = otherStartIndex;
+
+ while (j1 < endIndex || j2 < otherEndIndex)
+ {
+ if (j1 == endIndex || j2 < otherEndIndex && ColumnIndices[j1] > otherColumnIndices[j2])
+ {
+ state = f(state, Zero, otherValues[j2++]);
+ }
+ else if (j2 == otherEndIndex || ColumnIndices[j1] < otherColumnIndices[j2])
+ {
+ state = f(state, Values[j1++], otherZero);
+ }
+ else
+ {
+ state = f(state, Values[j1++], otherValues[j2++]);
+ }
+ }
+ }
+ return state;
+ }
+
+ // FALL BACK
+
+ return base.Fold2Unchecked(other, f, state, zeros);
+ }
}
}
diff --git a/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs b/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs
index 44ce7271..c39cab41 100644
--- a/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs
+++ b/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs
@@ -298,5 +298,20 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests
Assert.That(matrix.FoldByColumn((s, x) => s + 1.0, 0.0, Zeros.Include),
Is.EqualTo(Vector.Build.Dense(matrix.ColumnCount, matrix.RowCount)), "forced - full coverage");
}
+
+ [Theory]
+ public void CanFold2(Matrix matrix)
+ {
+ var other = -matrix;
+ other.Multiply(Operator.Convert(2), other);
+
+ // not forced
+ T sum = matrix.Fold2((s, x, y) => Operator.Add(Operator.Add(x, y), s), Operator.Zero, other, Zeros.AllowSkip);
+ Assert.That(sum, Is.EqualTo(Operator.Negate(matrix.Enumerate().Aggregate((a, b) => Operator.Add(a, b)))));
+
+ // forced
+ T sum2 = matrix.Fold2((s, x, y) => Operator.Add(Operator.Add(x, y), s), Operator.Zero, other, Zeros.Include);
+ Assert.That(sum2, Is.EqualTo(Operator.Negate(matrix.Enumerate().Aggregate((a, b) => Operator.Add(a, b)))));
+ }
}
}