From 747422cf6d0075f03b1d10b0d7b07a2063d7a0c9 Mon Sep 17 00:00:00 2001 From: Brian Popow Date: Tue, 22 Feb 2022 13:46:41 +0100 Subject: [PATCH] Add Sse2 version of Average png filter --- .../Formats/Png/Filters/AverageFilter.cs | 126 +++++++++++++++--- 1 file changed, 110 insertions(+), 16 deletions(-) diff --git a/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs b/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs index 83c6389348..94c4fb4d1a 100644 --- a/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs +++ b/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs @@ -20,9 +20,9 @@ namespace SixLabors.ImageSharp.Formats.Png.Filters internal static class AverageFilter { /// - /// Decodes the scanline + /// Decodes a scanline, which was filtered with the average filter. /// - /// The scanline to decode + /// The scanline to decode. /// The previous scanline. /// The bytes per pixel. [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -33,32 +33,126 @@ namespace SixLabors.ImageSharp.Formats.Png.Filters ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline); ref byte prevBaseRef = ref MemoryMarshal.GetReference(previousScanline); + // The Avg filter predicts each pixel as the (truncated) average of a and b: // Average(x) + floor((Raw(x-bpp)+Prior(x))/2) - int x = 1; - for (; x <= bytesPerPixel /* Note the <= because x starts at 1 */; ++x) +#if SUPPORTS_RUNTIME_INTRINSICS + if (Sse2.IsSupported && bytesPerPixel is 3 or 4) { - ref byte scan = ref Unsafe.Add(ref scanBaseRef, x); - byte above = Unsafe.Add(ref prevBaseRef, x); - scan = (byte)(scan + (above >> 1)); - } + if (bytesPerPixel is 3) + { + Vector128 a = Vector128.Zero; + Vector128 b = Vector128.Zero; + Vector128 d = Vector128.Zero; + var ones = Vector128.Create((byte)1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + + Span scratch = stackalloc byte[4]; + ref byte scratchRef = ref MemoryMarshal.GetReference(scratch); + int rb = scanline.Length; + int offset = 0; + while (rb >= 4) + { + a = d; + b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref prevBaseRef, offset))).AsByte(); + d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref scanBaseRef, offset))).AsByte(); + + d = AverageSubtractAdd(a, b, d, ones); + + // Store the result. + int result = Sse2.ConvertToInt32(d.AsInt32()); + Unsafe.As(ref scratchRef) = result; + scratch.Slice(0, 3).CopyTo(scanline.Slice(offset, 3)); + + rb -= 3; + offset += 3; + } + + if (rb is 3) + { + a = d; + scratch[3] = 0; + previousScanline.Slice(offset, 3).CopyTo(scratch); + b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref scratchRef)).AsByte(); + scanline.Slice(offset, 3).CopyTo(scratch); + d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref scratchRef)).AsByte(); + + d = AverageSubtractAdd(a, b, d, ones); + + // Store the result. + int result = Sse2.ConvertToInt32(d.AsInt32()); + Unsafe.As(ref scratchRef) = result; + scratch.Slice(0, 3).CopyTo(scanline.Slice(offset, 3)); + } + } + else + { + Vector128 a = Vector128.Zero; + Vector128 b = Vector128.Zero; + Vector128 d = Vector128.Zero; + var ones = Vector128.Create((byte)1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + + Span scratch = stackalloc byte[4]; + ref byte scratchRef = ref MemoryMarshal.GetReference(scratch); + int rb = scanline.Length; + int offset = 0; + while (rb >= 4) + { + a = d; + b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref prevBaseRef, offset))).AsByte(); + d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref scanBaseRef, offset))).AsByte(); - for (; x < scanline.Length; ++x) + d = AverageSubtractAdd(a, b, d, ones); + + // Store the result. + int result = Sse2.ConvertToInt32(d.AsInt32()); + Unsafe.As(ref scratchRef) = result; + scratch.CopyTo(scanline.Slice(offset, 4)); + + rb -= 4; + offset += 4; + } + } + } + else +#endif { - ref byte scan = ref Unsafe.Add(ref scanBaseRef, x); - byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel); - byte above = Unsafe.Add(ref prevBaseRef, x); - scan = (byte)(scan + Average(left, above)); + int x = 1; + for (; x <= bytesPerPixel /* Note the <= because x starts at 1 */; ++x) + { + ref byte scan = ref Unsafe.Add(ref scanBaseRef, x); + byte above = Unsafe.Add(ref prevBaseRef, x); + scan = (byte)(scan + (above >> 1)); + } + + for (; x < scanline.Length; ++x) + { + ref byte scan = ref Unsafe.Add(ref scanBaseRef, x); + byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel); + byte above = Unsafe.Add(ref prevBaseRef, x); + scan = (byte)(scan + Average(left, above)); + } } } +#if SUPPORTS_RUNTIME_INTRINSICS + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 AverageSubtractAdd(Vector128 a, Vector128 b, Vector128 d, Vector128 ones) + { + // PNG requires a truncating average, so we can't just use _mm_avg_epu8. + // ...but we can fix it up by subtracting off 1 if it rounded up. + Vector128 avg = Sse2.Average(a, b); + avg = Sse2.Subtract(avg, Sse2.And(Sse2.Xor(a, b), ones)); + return Sse2.Add(d, avg); + } +#endif + /// - /// Encodes the scanline + /// Encodes a scanline with the average filter applied. /// - /// The scanline to encode + /// The scanline to encode. /// The previous scanline. /// The filtered scanline result. /// The bytes per pixel. - /// The sum of the total variance of the filtered row + /// The sum of the total variance of the filtered row. [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void Encode(ReadOnlySpan scanline, ReadOnlySpan previousScanline, Span result, int bytesPerPixel, out int sum) {