diff --git a/src/ImageSharp/Common/Helpers/Numerics.cs b/src/ImageSharp/Common/Helpers/Numerics.cs
index ba5c588ca..9dc13079d 100644
--- a/src/ImageSharp/Common/Helpers/Numerics.cs
+++ b/src/ImageSharp/Common/Helpers/Numerics.cs
@@ -820,6 +820,21 @@ namespace SixLabors.ImageSharp
}
}
+ ///
+ /// Reduces elements of the vector into one sum.
+ ///
+ /// The accumulator to reduce.
+ /// The sum of all elements.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int ReduceSum(Vector256 accumulator)
+ {
+ Vector128 vec0 = Avx2.ExtractVector128(accumulator, 0);
+ Vector128 vec1 = Avx2.ExtractVector128(accumulator, 1);
+ Vector128 sum128 = Sse2.Add(vec0, vec1);
+
+ return ReduceSum(sum128);
+ }
+
///
/// Reduces even elements of the vector into one sum.
///
diff --git a/src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs b/src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs
index 0f24e8e8f..314f26d64 100644
--- a/src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs
+++ b/src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0.
using System;
+using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using SixLabors.ImageSharp.Memory;
@@ -760,29 +761,30 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
public static float CombinedShannonEntropy(Span x, Span y)
{
#if SUPPORTS_RUNTIME_INTRINSICS
- if (Sse41.IsSupported)
+ if (Avx2.IsSupported)
{
double retVal = 0.0d;
- Span tmp = stackalloc int[4];
+ Span tmp = stackalloc int[8];
ref int xRef = ref MemoryMarshal.GetReference(x);
ref int yRef = ref MemoryMarshal.GetReference(y);
- Vector128 sumXY128 = Vector128.Zero;
- Vector128 sumX128 = Vector128.Zero;
+ Vector256 sumXY256 = Vector256.Zero;
+ Vector256 sumX256 = Vector256.Zero;
ref int tmpRef = ref MemoryMarshal.GetReference(tmp);
- for (int i = 0; i < 256; i += 4)
+ for (nint i = 0; i < 256; i += 8)
{
- Vector128 xVec = Unsafe.As>(ref Unsafe.Add(ref xRef, i));
- Vector128 yVec = Unsafe.As>(ref Unsafe.Add(ref yRef, i));
+ Vector256 xVec = Unsafe.As>(ref Unsafe.Add(ref xRef, i));
+ Vector256 yVec = Unsafe.As>(ref Unsafe.Add(ref yRef, i));
// Check if any X is non-zero: this actually provides a speedup as X is usually sparse.
- if (Sse2.MoveMask(Sse2.CompareEqual(xVec, Vector128.Zero).AsByte()) != 0xFFFF)
+ int mask = Avx2.MoveMask(Avx2.CompareEqual(xVec, Vector256.Zero).AsByte());
+ if (mask != -1)
{
- Vector128 xy128 = Sse2.Add(xVec, yVec);
- sumXY128 = Sse2.Add(sumXY128, xy128);
- sumX128 = Sse2.Add(sumX128, xVec);
+ Vector256 xy256 = Avx2.Add(xVec, yVec);
+ sumXY256 = Avx2.Add(sumXY256, xy256);
+ sumX256 = Avx2.Add(sumX256, xVec);
// Analyze the different X + Y.
- Unsafe.As>(ref tmpRef) = xy128;
+ Unsafe.As>(ref tmpRef) = xy256;
if (tmpRef != 0)
{
retVal -= FastSLog2((uint)tmpRef);
@@ -818,11 +820,47 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 3));
}
}
+
+ if (Unsafe.Add(ref tmpRef, 4) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 4));
+ if (Unsafe.Add(ref xRef, i + 4) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 4));
+ }
+ }
+
+ if (Unsafe.Add(ref tmpRef, 5) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 5));
+ if (Unsafe.Add(ref xRef, i + 5) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 5));
+ }
+ }
+
+ if (Unsafe.Add(ref tmpRef, 6) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 6));
+ if (Unsafe.Add(ref xRef, i + 6) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 6));
+ }
+ }
+
+ if (Unsafe.Add(ref tmpRef, 7) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 7));
+ if (Unsafe.Add(ref xRef, i + 7) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 7));
+ }
+ }
}
else
{
// X is fully 0, so only deal with Y.
- sumXY128 = Sse2.Add(sumXY128, yVec);
+ sumXY256 = Avx2.Add(sumXY256, yVec);
if (Unsafe.Add(ref yRef, i) != 0)
{
@@ -843,19 +881,32 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 3));
}
+
+ if (Unsafe.Add(ref yRef, i + 4) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 4));
+ }
+
+ if (Unsafe.Add(ref yRef, i + 5) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 5));
+ }
+
+ if (Unsafe.Add(ref yRef, i + 6) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 6));
+ }
+
+ if (Unsafe.Add(ref yRef, i + 7) != 0)
+ {
+ retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 7));
+ }
}
}
- // Sum up sumX_128 to get sumX and sum up sumXY_128 to get sumXY.
- // note: not using here Numerics.ReduceSum, because grouping the same methods together should be slightly faster.
- Vector128 haddSumX = Ssse3.HorizontalAdd(sumX128, sumX128);
- Vector128 haddSumXY = Ssse3.HorizontalAdd(sumXY128, sumXY128);
- Vector128 swappedSumX = Sse2.Shuffle(haddSumX, 0x1);
- Vector128 swappedSumXY = Sse2.Shuffle(haddSumXY, 0x1);
- Vector128 tmpSumX = Sse2.Add(haddSumX, swappedSumX);
- Vector128 tmpSumXY = Sse2.Add(haddSumXY, swappedSumXY);
- int sumX = Sse2.ConvertToInt32(tmpSumX);
- int sumXY = Sse2.ConvertToInt32(tmpSumXY);
+ // Sum up sumX256 to get sumX and sum up sumXY256 to get sumXY.
+ int sumX = Numerics.ReduceSum(sumX256);
+ int sumXY = Numerics.ReduceSum(sumXY256);
retVal += FastSLog2((uint)sumX) + FastSLog2((uint)sumXY);