diff --git a/src/FSharp/LinearAlgebra.Vector.fs b/src/FSharp/LinearAlgebra.Vector.fs index 2843a34d..917c2c90 100644 --- a/src/FSharp/LinearAlgebra.Vector.fs +++ b/src/FSharp/LinearAlgebra.Vector.fs @@ -169,6 +169,13 @@ module Vector = /// Zero-Zero value-pairs may be skipped (relevant mostly for sparse vectors). let inline map2SkipZeros f (u: #Vector<_>) (v: #Vector<_>) = u.Map2((fun x y -> f x y), v, Zeros.AllowSkip) + /// Folds two vectors by applying a function to update the status for each element pair. + let inline fold2 f status (u: #Vector<_>) (v: #Vector<_>) = u.Fold2((fun s x y -> f s x y), status, v, Zeros.Include) + + /// Folds two vectors by applying a function to update the status for each element pair. + /// Zero-Zero value-pairs may be skipped (relevant mostly for sparse vectors). + let inline fold2SkipZeros f status (u: #Vector<_>) (v: #Vector<_>) = u.Fold2((fun s x y -> f s x y), status, v, Zeros.AllowSkip) + /// Fold all entries of a vector in reverse order. diff --git a/src/Numerics/LinearAlgebra/Storage/DenseVectorStorage.cs b/src/Numerics/LinearAlgebra/Storage/DenseVectorStorage.cs index e92adce9..f1c6293f 100644 --- a/src/Numerics/LinearAlgebra/Storage/DenseVectorStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/DenseVectorStorage.cs @@ -442,5 +442,47 @@ namespace MathNet.Numerics.LinearAlgebra.Storage base.Map2ToUnchecked(target, other, f, zeros, existingData); } + + internal override TState Fold2Unchecked(VectorStorage other, Func f, TState state, Zeros zeros = Zeros.AllowSkip) + { + var denseOther = other as DenseVectorStorage; + if (denseOther != null) + { + var otherData = denseOther.Data; + for (int i = 0; i < Data.Length; i++) + { + state = f(state, Data[i], otherData[i]); + } + + return state; + } + + var sparseOther = other as SparseVectorStorage; + if (sparseOther != null) + { + int[] otherIndices = sparseOther.Indices; + TOther[] otherValues = sparseOther.Values; + int otherValueCount = sparseOther.ValueCount; + TOther otherZero = BuilderInstance.Vector.Zero; + + int k = 0; + for (int i = 0; i < Data.Length; i++) + { + if (k < otherValueCount && otherIndices[k] == i) + { + state = f(state, Data[i], otherValues[k]); + k++; + } + else + { + state = f(state, Data[i], otherZero); + } + } + + return state; + } + + return base.Fold2Unchecked(other, f, state, zeros); + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/SparseVectorStorage.cs b/src/Numerics/LinearAlgebra/Storage/SparseVectorStorage.cs index 7ddc7501..4145eb5e 100644 --- a/src/Numerics/LinearAlgebra/Storage/SparseVectorStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/SparseVectorStorage.cs @@ -919,5 +919,78 @@ namespace MathNet.Numerics.LinearAlgebra.Storage base.Map2ToUnchecked(target, other, f, zeros, existingData); } + + internal override TState Fold2Unchecked(VectorStorage other, Func f, TState state, Zeros zeros = Zeros.AllowSkip) + { + var sparseOther = other as SparseVectorStorage; + if (sparseOther != null) + { + int[] otherIndices = sparseOther.Indices; + TOther[] otherValues = sparseOther.Values; + int otherValueCount = sparseOther.ValueCount; + TOther otherZero = BuilderInstance.Vector.Zero; + + if (zeros == Zeros.Include) + { + int p = 0, q = 0; + for (int i = 0; i < Length; i++) + { + var left = p < ValueCount && Indices[p] == i ? Values[p++] : Zero; + var right = q < otherValueCount && otherIndices[q] == i ? otherValues[q++] : otherZero; + state = f(state, left, right); + } + } + else + { + int p = 0, q = 0; + while (p < ValueCount || q < otherValueCount) + { + if (q >= otherValueCount || p < ValueCount && Indices[p] < otherIndices[q]) + { + state = f(state, Values[p], otherZero); + p++; + } + else if (p >= ValueCount || q < otherValueCount && Indices[p] > otherIndices[q]) + { + state = f(state, Zero, otherValues[q]); + q++; + } + else + { + Debug.Assert(Indices[p] == otherIndices[q]); + state = f(state, Values[p], otherValues[q]); + p++; + q++; + } + } + } + + return state; + } + + var denseOther = other as DenseVectorStorage; + if (denseOther != null) + { + TOther[] otherData = denseOther.Data; + + int k = 0; + for (int i = 0; i < otherData.Length; i++) + { + if (k < ValueCount && Indices[k] == i) + { + state = f(state, Values[k], otherData[i]); + k++; + } + else + { + state = f(state, Zero, otherData[i]); + } + } + + return state; + } + + return base.Fold2Unchecked(other, f, state, zeros); + } } } diff --git a/src/Numerics/LinearAlgebra/Storage/VectorStorage.cs b/src/Numerics/LinearAlgebra/Storage/VectorStorage.cs index bb1edeb0..7b4b60d4 100644 --- a/src/Numerics/LinearAlgebra/Storage/VectorStorage.cs +++ b/src/Numerics/LinearAlgebra/Storage/VectorStorage.cs @@ -504,5 +504,32 @@ namespace MathNet.Numerics.LinearAlgebra.Storage target.At(i, f(At(i), other.At(i))); } } + + public TState Fold2(VectorStorage other, Func f, TState state, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + if (other == null) + { + throw new ArgumentNullException("other"); + } + + if (Length != other.Length) + { + throw new ArgumentException(Resources.ArgumentVectorsSameLength, "other"); + } + + return Fold2Unchecked(other, f, state, zeros); + } + + internal virtual TState Fold2Unchecked(VectorStorage other, Func f, TState state, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + for (int i = 0; i < Length; i++) + { + state = f(state, At(i), other.At(i)); + } + + return state; + } } } diff --git a/src/Numerics/LinearAlgebra/Vector.cs b/src/Numerics/LinearAlgebra/Vector.cs index 3cd2b96f..d01e8675 100644 --- a/src/Numerics/LinearAlgebra/Vector.cs +++ b/src/Numerics/LinearAlgebra/Vector.cs @@ -475,5 +475,14 @@ namespace MathNet.Numerics.LinearAlgebra Storage.Map2To(result.Storage, other.Storage, f, zeros, ExistingData.AssumeZeros); return result; } + + /// + /// Applies a function to update the status with each value pair of two vectors and returns the resulting status. + /// + public TState Fold2(Func f, TState state, Vector other, Zeros zeros = Zeros.AllowSkip) + where TOther : struct, IEquatable, IFormattable + { + return Storage.Fold2(other.Storage, f, state, zeros); + } } } diff --git a/src/UnitTests/LinearAlgebraTests/VectorStorageCombinatorsTests.cs b/src/UnitTests/LinearAlgebraTests/VectorStorageCombinatorsTests.cs index 4f6528c9..bd84fa34 100644 --- a/src/UnitTests/LinearAlgebraTests/VectorStorageCombinatorsTests.cs +++ b/src/UnitTests/LinearAlgebraTests/VectorStorageCombinatorsTests.cs @@ -136,5 +136,23 @@ namespace MathNet.Numerics.UnitTests.LinearAlgebraTests a.Map2To(result, b, (u, v) => u + v + 1.0, Zeros.AllowSkip); Assert.That(result.Equals(expected)); } + + [Theory] + public void Fold2SkipZeros(VectorStorageType aType, VectorStorageType bType) + { + var a = Build.VectorStorage(aType, new[] { 1.0, 2.0, 0.0, 4.0, 0.0, 6.0 }); + var b = Build.VectorStorage(bType, new[] { 11.0, 12.0, 13.0, 0.0, 0.0, 16.0 }); + var result = a.Fold2(b, (acc, u, v) => acc + u + v, 0.0, Zeros.AllowSkip); + Assert.That(result, Is.EqualTo(65)); + } + + [Theory] + public void Fold2ForceIncludeZeros(VectorStorageType aType, VectorStorageType bType) + { + var a = Build.VectorStorage(aType, new[] { 1.0, 2.0, 0.0, 4.0, 0.0, 6.0 }); + var b = Build.VectorStorage(bType, new[] { 11.0, 12.0, 13.0, 0.0, 0.0, 16.0 }); + var result = a.Fold2(b, (acc, u, v) => acc + u + v + 1.0, 0.0, Zeros.Include); + Assert.That(result, Is.EqualTo(71)); + } } }