Browse Source

Add AVX2 version of CombinedShannonEntropy

pull/1848/head
Brian Popow 5 years ago
parent
commit
32b97f41fc
  1. 15
      src/ImageSharp/Common/Helpers/Numerics.cs
  2. 97
      src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs

15
src/ImageSharp/Common/Helpers/Numerics.cs

@ -820,6 +820,21 @@ namespace SixLabors.ImageSharp
} }
} }
/// <summary>
/// Adds together every <see cref="int"/> element of a 256-bit vector and returns the total.
/// </summary>
/// <param name="accumulator">The vector whose elements are summed.</param>
/// <returns>The sum of all elements.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int ReduceSum(Vector256<int> accumulator)
{
    // Split the 256-bit vector into its two 128-bit halves.
    Vector128<int> lowerHalf = Avx2.ExtractVector128(accumulator, 0);
    Vector128<int> upperHalf = Avx2.ExtractVector128(accumulator, 1);

    // Fold the halves element-wise, then hand the result to the Vector128<int> overload.
    return ReduceSum(Sse2.Add(lowerHalf, upperHalf));
}
/// <summary> /// <summary>
/// Reduces even elements of the vector into one sum. /// Reduces even elements of the vector into one sum.
/// </summary> /// </summary>

97
src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. // Licensed under the Apache License, Version 2.0.
using System; using System;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using SixLabors.ImageSharp.Memory; using SixLabors.ImageSharp.Memory;
@ -760,29 +761,30 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
public static float CombinedShannonEntropy(Span<int> x, Span<int> y) public static float CombinedShannonEntropy(Span<int> x, Span<int> y)
{ {
#if SUPPORTS_RUNTIME_INTRINSICS #if SUPPORTS_RUNTIME_INTRINSICS
if (Sse41.IsSupported) if (Avx2.IsSupported)
{ {
double retVal = 0.0d; double retVal = 0.0d;
Span<int> tmp = stackalloc int[4]; Span<int> tmp = stackalloc int[8];
ref int xRef = ref MemoryMarshal.GetReference(x); ref int xRef = ref MemoryMarshal.GetReference(x);
ref int yRef = ref MemoryMarshal.GetReference(y); ref int yRef = ref MemoryMarshal.GetReference(y);
Vector128<int> sumXY128 = Vector128<int>.Zero; Vector256<int> sumXY256 = Vector256<int>.Zero;
Vector128<int> sumX128 = Vector128<int>.Zero; Vector256<int> sumX256 = Vector256<int>.Zero;
ref int tmpRef = ref MemoryMarshal.GetReference(tmp); ref int tmpRef = ref MemoryMarshal.GetReference(tmp);
for (int i = 0; i < 256; i += 4) for (nint i = 0; i < 256; i += 8)
{ {
Vector128<int> xVec = Unsafe.As<int, Vector128<int>>(ref Unsafe.Add(ref xRef, i)); Vector256<int> xVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref xRef, i));
Vector128<int> yVec = Unsafe.As<int, Vector128<int>>(ref Unsafe.Add(ref yRef, i)); Vector256<int> yVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref yRef, i));
// Check if any X is non-zero: this actually provides a speedup as X is usually sparse. // Check if any X is non-zero: this actually provides a speedup as X is usually sparse.
if (Sse2.MoveMask(Sse2.CompareEqual(xVec, Vector128<int>.Zero).AsByte()) != 0xFFFF) int mask = Avx2.MoveMask(Avx2.CompareEqual(xVec, Vector256<int>.Zero).AsByte());
if (mask != -1)
{ {
Vector128<int> xy128 = Sse2.Add(xVec, yVec); Vector256<int> xy256 = Avx2.Add(xVec, yVec);
sumXY128 = Sse2.Add(sumXY128, xy128); sumXY256 = Avx2.Add(sumXY256, xy256);
sumX128 = Sse2.Add(sumX128, xVec); sumX256 = Avx2.Add(sumX256, xVec);
// Analyze the different X + Y. // Analyze the different X + Y.
Unsafe.As<int, Vector128<int>>(ref tmpRef) = xy128; Unsafe.As<int, Vector256<int>>(ref tmpRef) = xy256;
if (tmpRef != 0) if (tmpRef != 0)
{ {
retVal -= FastSLog2((uint)tmpRef); retVal -= FastSLog2((uint)tmpRef);
@ -818,11 +820,47 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 3)); retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 3));
} }
} }
if (Unsafe.Add(ref tmpRef, 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 4));
if (Unsafe.Add(ref xRef, i + 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 4));
}
}
if (Unsafe.Add(ref tmpRef, 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 5));
if (Unsafe.Add(ref xRef, i + 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 5));
}
}
if (Unsafe.Add(ref tmpRef, 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 6));
if (Unsafe.Add(ref xRef, i + 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 6));
}
}
if (Unsafe.Add(ref tmpRef, 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 7));
if (Unsafe.Add(ref xRef, i + 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 7));
}
}
} }
else else
{ {
// X is fully 0, so only deal with Y. // X is fully 0, so only deal with Y.
sumXY128 = Sse2.Add(sumXY128, yVec); sumXY256 = Avx2.Add(sumXY256, yVec);
if (Unsafe.Add(ref yRef, i) != 0) if (Unsafe.Add(ref yRef, i) != 0)
{ {
@ -843,19 +881,32 @@ namespace SixLabors.ImageSharp.Formats.Webp.Lossless
{ {
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 3)); retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 3));
} }
if (Unsafe.Add(ref yRef, i + 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 4));
}
if (Unsafe.Add(ref yRef, i + 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 5));
}
if (Unsafe.Add(ref yRef, i + 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 6));
}
if (Unsafe.Add(ref yRef, i + 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 7));
}
} }
} }
// Sum up sumX_128 to get sumX and sum up sumXY_128 to get sumXY. // Sum up sumX256 to get sumX and sum up sumXY256 to get sumXY.
// note: not using here Numerics.ReduceSum, because grouping the same methods together should be slightly faster. int sumX = Numerics.ReduceSum(sumX256);
Vector128<int> haddSumX = Ssse3.HorizontalAdd(sumX128, sumX128); int sumXY = Numerics.ReduceSum(sumXY256);
Vector128<int> haddSumXY = Ssse3.HorizontalAdd(sumXY128, sumXY128);
Vector128<int> swappedSumX = Sse2.Shuffle(haddSumX, 0x1);
Vector128<int> swappedSumXY = Sse2.Shuffle(haddSumXY, 0x1);
Vector128<int> tmpSumX = Sse2.Add(haddSumX, swappedSumX);
Vector128<int> tmpSumXY = Sse2.Add(haddSumXY, swappedSumXY);
int sumX = Sse2.ConvertToInt32(tmpSumX);
int sumXY = Sse2.ConvertToInt32(tmpSumXY);
retVal += FastSLog2((uint)sumX) + FastSLog2((uint)sumXY); retVal += FastSLog2((uint)sumX) + FastSLog2((uint)sumXY);

Loading…
Cancel
Save