From bd64c8e61251b3379fee2f60f226ec2bac92bb5c Mon Sep 17 00:00:00 2001 From: Christoph Ruegg Date: Sat, 7 Mar 2015 20:58:20 +0100 Subject: [PATCH] LA: Matrix.Fold2, Find/2, Exists/2, ForAll/2 --- src/Numerics/LinearAlgebra/Matrix.cs | 42 ++++++ .../Storage/DenseColumnMajorMatrixStorage.cs | 60 +++++++++ .../Storage/DiagonalMatrixStorage.cs | 102 +++++++++++++++ .../LinearAlgebra/Storage/MatrixStorage.cs | 31 +++++ .../SparseCompressedRowMatrixStorage.cs | 122 ++++++++++++++++++ .../MatrixStructureTheory.Functional.cs | 15 +++ 6 files changed, 372 insertions(+) diff --git a/src/Numerics/LinearAlgebra/Matrix.cs b/src/Numerics/LinearAlgebra/Matrix.cs index cf8b7a7d..223bcf96 100644 --- a/src/Numerics/LinearAlgebra/Matrix.cs +++ b/src/Numerics/LinearAlgebra/Matrix.cs @@ -1735,5 +1735,47 @@ namespace MathNet.Numerics.LinearAlgebra { return EnumerateColumns().Aggregate(f); } + + /// + /// Applies a function to update the status with each value pair of two matrices and returns the resulting status. + /// + public TState Fold2(Func f, TState state, Matrix other, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + return Storage.Fold2(other.Storage, f, state, zeros); + } + + public Tuple Find(Func predicate, Zeros zeros = Zeros.AllowSkip) + { + return Storage.Find(predicate, zeros); + } + + public Tuple Find2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + return Storage.Find2(other.Storage, predicate, zeros); + } + + public bool Exists(Func predicate, Zeros zeros = Zeros.AllowSkip) + { + return Storage.Find(predicate, zeros) != null; + } + + public bool Exists2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + return Storage.Find2(other.Storage, predicate, zeros) != null; + } + + public bool ForAll(Func predicate, Zeros zeros = Zeros.AllowSkip) + { + return Storage.Find(x => !predicate(x), zeros) == null; + } + + public bool ForAll2(Func predicate, Matrix other, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + return Storage.Find2(other.Storage, (x, y) => !predicate(x, y), zeros) == null; + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs index d6b8e5c1..265d12e1 100644 --- a/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/DenseColumnMajorMatrixStorage.cs @@ -907,5 +907,65 @@ namespace MathNet.Numerics.LinearAlgebra.Storage target[j] = finalize(s, RowCount); } } + + internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros) + { + var denseOther = other as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + TOther[] otherData = denseOther.Data; + for (int i = 0; i < Data.Length; i++) + { + state = f(state, Data[i], otherData[i]); + } + return state; + } + + var diagonalOther = other as DiagonalMatrixStorage; + if (diagonalOther != null) + { + TOther[] otherData = diagonalOther.Data; + TOther otherZero = BuilderInstance.Matrix.Zero; + int k = 0; + for (int j = 0; j < ColumnCount; j++) + { + for (int i = 0; i < RowCount; i++) + { + state = f(state, Data[k], i == j ? otherData[i] : otherZero); + k++; + } + } + return state; + } + + var sparseOther = other as SparseCompressedRowMatrixStorage; + if (sparseOther != null) + { + int[] otherRowPointers = sparseOther.RowPointers; + int[] otherColumnIndices = sparseOther.ColumnIndices; + TOther[] otherValues = sparseOther.Values; + TOther otherZero = BuilderInstance.Matrix.Zero; + int k = 0; + for (int row = 0; row < RowCount; row++) + { + for (int col = 0; col < ColumnCount; col++) + { + if (k < otherRowPointers[row + 1] && otherColumnIndices[k] == col) + { + state = f(state, Data[col*RowCount + row], otherValues[k++]); + } + else + { + state = f(state, Data[col*RowCount + row], otherZero); + } + } + } + return state; + } + + // FALL BACK + + return base.Fold2Unchecked(other, f, state, zeros); + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs index 891e5249..8ec61019 100644 --- a/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/DiagonalMatrixStorage.cs @@ -1087,5 +1087,107 @@ namespace MathNet.Numerics.LinearAlgebra.Storage } } } + + internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros) + { + var denseOther = other as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + TOther[] otherData = denseOther.Data; + int k = 0; + for (int j = 0; j < ColumnCount; j++) + { + for (int i = 0; i < RowCount; i++) + { + state = f(state, i == j ? Data[i] : Zero, otherData[k]); + k++; + } + } + return state; + } + + var diagonalOther = other as DiagonalMatrixStorage; + if (diagonalOther != null) + { + TOther[] otherData = diagonalOther.Data; + for (int i = 0; i < Data.Length; i++) + { + state = f(state, Data[i], otherData[i]); + } + + // Do we really need to do this? + if (zeros == Zeros.Include) + { + TOther otherZero = BuilderInstance.Matrix.Zero; + int count = RowCount*ColumnCount - Data.Length; + for (int i = 0; i < count; i++) + { + state = f(state, Zero, otherZero); + } + } + + return state; + } + + var sparseOther = other as SparseCompressedRowMatrixStorage; + if (sparseOther != null) + { + int[] otherRowPointers = sparseOther.RowPointers; + int[] otherColumnIndices = sparseOther.ColumnIndices; + TOther[] otherValues = sparseOther.Values; + TOther otherZero = BuilderInstance.Matrix.Zero; + + if (zeros == Zeros.Include) + { + int k = 0; + for (int row = 0; row < RowCount; row++) + { + for (int col = 0; col < ColumnCount; col++) + { + if (k < otherRowPointers[row + 1] && otherColumnIndices[k] == col) + { + state = f(state, row == col ? Data[row] : Zero, otherValues[k++]); + } + else + { + state = f(state, row == col ? Data[row] : Zero, otherZero); + } + } + } + return state; + } + + for (int row = 0; row < RowCount; row++) + { + bool diagonal = false; + + var startIndex = otherRowPointers[row]; + var endIndex = otherRowPointers[row + 1]; + for (var j = startIndex; j < endIndex; j++) + { + if (otherColumnIndices[j] == row) + { + diagonal = true; + state = f(state, Data[row], otherValues[j]); + } + else + { + state = f(state, Zero, otherValues[j]); + } + } + + if (!diagonal && row < ColumnCount) + { + state = f(state, Data[row], otherZero); + } + } + + return state; + } + + // FALL BACK + + return base.Fold2Unchecked(other, f, state, zeros); + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs index 7d83e0f6..778aca5e 100644 --- a/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/MatrixStorage.cs @@ -871,5 +871,36 @@ namespace MathNet.Numerics.LinearAlgebra.Storage target[j] = finalize(s, RowCount); } } + + public TState Fold2(MatrixStorage other, Func f, TState state, Zeros zeros) + where TOther : struct, IEquatable, IFormattable + { + if (other == null) + { + throw new ArgumentNullException("other"); + } + + if (RowCount != other.RowCount || ColumnCount != other.ColumnCount) + { + var message = string.Format(Resources.ArgumentMatrixDimensions2, RowCount + "x" + ColumnCount, other.RowCount + "x" + other.ColumnCount); + throw new ArgumentException(message, "other"); + } + + return Fold2Unchecked(other, f, state, zeros); + } + + internal virtual TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros) + where TOther : struct, IEquatable, IFormattable + { + for (int i = 0; i < RowCount; i++) + { + for (int j = 0; j < ColumnCount; j++) + { + state = f(state, At(i, j), other.At(i, j)); + } + } + + return state; + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs b/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs index 9e037bab..83e3cd92 100644 --- a/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/SparseCompressedRowMatrixStorage.cs @@ -2065,5 +2065,127 @@ namespace MathNet.Numerics.LinearAlgebra.Storage } } } + + internal override TState Fold2Unchecked(MatrixStorage other, Func f, TState state, Zeros zeros) + { + var denseOther = other as DenseColumnMajorMatrixStorage; + if (denseOther != null) + { + TOther[] otherData = denseOther.Data; + int k = 0; + for (int row = 0; row < RowCount; row++) + { + for (int col = 0; col < ColumnCount; col++) + { + bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col; + state = f(state, available ? Values[k++] : Zero, otherData[col*RowCount + row]); + } + } + return state; + } + + var diagonalOther = other as DiagonalMatrixStorage; + if (diagonalOther != null) + { + TOther[] otherData = diagonalOther.Data; + TOther otherZero = BuilderInstance.Matrix.Zero; + + if (zeros == Zeros.Include) + { + int k = 0; + for (int row = 0; row < RowCount; row++) + { + for (int col = 0; col < ColumnCount; col++) + { + bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col; + state = f(state, available ? Values[k++] : Zero, row == col ? otherData[row] : otherZero); + } + } + return state; + } + + for (int row = 0; row < RowCount; row++) + { + bool diagonal = false; + + var startIndex = RowPointers[row]; + var endIndex = RowPointers[row + 1]; + for (var j = startIndex; j < endIndex; j++) + { + if (ColumnIndices[j] == row) + { + diagonal = true; + state = f(state, Values[j], otherData[row]); + } + else + { + state = f(state, Values[j], otherZero); + } + } + + if (!diagonal && row < ColumnCount) + { + state = f(state, Zero, otherData[row]); + } + } + return state; + } + + var sparseOther = other as SparseCompressedRowMatrixStorage; + if (sparseOther != null) + { + int[] otherRowPointers = sparseOther.RowPointers; + int[] otherColumnIndices = sparseOther.ColumnIndices; + TOther[] otherValues = sparseOther.Values; + TOther otherZero = BuilderInstance.Matrix.Zero; + + if (zeros == Zeros.Include) + { + int k = 0, otherk = 0; + for (int row = 0; row < RowCount; row++) + { + for (int col = 0; col < ColumnCount; col++) + { + bool available = k < RowPointers[row + 1] && ColumnIndices[k] == col; + bool otherAvailable = otherk < otherRowPointers[row + 1] && otherColumnIndices[otherk] == col; + state = f(state, available ? Values[k++] : Zero, otherAvailable ? otherValues[otherk++] : otherZero); + } + } + return state; + } + + for (int row = 0; row < RowCount; row++) + { + var startIndex = RowPointers[row]; + var endIndex = RowPointers[row + 1]; + var otherStartIndex = otherRowPointers[row]; + var otherEndIndex = otherRowPointers[row + 1]; + + var j1 = startIndex; + var j2 = otherStartIndex; + + while (j1 < endIndex || j2 < otherEndIndex) + { + if (j1 == endIndex || j2 < otherEndIndex && ColumnIndices[j1] > otherColumnIndices[j2]) + { + state = f(state, Zero, otherValues[j2++]); + } + else if (j2 == otherEndIndex || ColumnIndices[j1] < otherColumnIndices[j2]) + { + state = f(state, Values[j1++], otherZero); + } + else + { + state = f(state, Values[j1++], otherValues[j2++]); + } + } + } + return state; + } + + // FALL BACK + + return base.Fold2Unchecked(other, f, state, zeros); + } } } diff --git a/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs b/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs index 44ce7271..c39cab41 100644 --- a/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs +++ b/src/UnitTests/LinearAlgebraTests/MatrixStructureTheory.Functional.cs @@ -298,5 +298,20 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests Assert.That(matrix.FoldByColumn((s, x) => s + 1.0, 0.0, Zeros.Include), Is.EqualTo(Vector.Build.Dense(matrix.ColumnCount, matrix.RowCount)), "forced - full coverage"); } + + [Theory] + public void CanFold2(Matrix matrix) + { + var other = -matrix; + other.Multiply(Operator.Convert(2), other); + + // not forced + T sum = matrix.Fold2((s, x, y) => Operator.Add(Operator.Add(x, y), s), Operator.Zero, other, Zeros.AllowSkip); + Assert.That(sum, Is.EqualTo(Operator.Negate(matrix.Enumerate().Aggregate((a, b) => Operator.Add(a, b))))); + + // forced + T sum2 = matrix.Fold2((s, x, y) => Operator.Add(Operator.Add(x, y), s), Operator.Zero, other, Zeros.Include); + Assert.That(sum2, Is.EqualTo(Operator.Negate(matrix.Enumerate().Aggregate((a, b) => Operator.Add(a, b))))); + } } }