diff --git a/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs b/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs
index 83c6389348..94c4fb4d1a 100644
--- a/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs
+++ b/src/ImageSharp/Formats/Png/Filters/AverageFilter.cs
@@ -20,9 +20,9 @@ namespace SixLabors.ImageSharp.Formats.Png.Filters
internal static class AverageFilter
{
///
- /// Decodes the scanline
+ /// Decodes a scanline, which was filtered with the average filter.
///
- /// The scanline to decode
+ /// The scanline to decode.
/// The previous scanline.
/// The bytes per pixel.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -33,32 +33,126 @@ namespace SixLabors.ImageSharp.Formats.Png.Filters
ref byte scanBaseRef = ref MemoryMarshal.GetReference(scanline);
ref byte prevBaseRef = ref MemoryMarshal.GetReference(previousScanline);
+ // The Avg filter predicts each byte as the truncated average of a (the byte bytesPerPixel positions to the left) and b (the byte above it in the previous scanline):
// Average(x) + floor((Raw(x-bpp)+Prior(x))/2)
- int x = 1;
- for (; x <= bytesPerPixel /* Note the <= because x starts at 1 */; ++x)
+#if SUPPORTS_RUNTIME_INTRINSICS
+ if (Sse2.IsSupported && bytesPerPixel is 3 or 4)
{
- ref byte scan = ref Unsafe.Add(ref scanBaseRef, x);
- byte above = Unsafe.Add(ref prevBaseRef, x);
- scan = (byte)(scan + (above >> 1));
- }
+ if (bytesPerPixel is 3)
+ {
+ Vector128 a = Vector128.Zero;
+ Vector128 b = Vector128.Zero;
+ Vector128 d = Vector128.Zero;
+ var ones = Vector128.Create((byte)1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
+
+ Span scratch = stackalloc byte[4];
+ ref byte scratchRef = ref MemoryMarshal.GetReference(scratch);
+ int rb = scanline.Length;
+ int offset = 0;
+ while (rb >= 4)
+ {
+ a = d;
+ b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref prevBaseRef, offset))).AsByte();
+ d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref scanBaseRef, offset))).AsByte();
+
+ d = AverageSubtractAdd(a, b, d, ones);
+
+ // Store the result.
+ int result = Sse2.ConvertToInt32(d.AsInt32());
+ Unsafe.As(ref scratchRef) = result;
+ scratch.Slice(0, 3).CopyTo(scanline.Slice(offset, 3));
+
+ rb -= 3;
+ offset += 3;
+ }
+
+ if (rb is 3)
+ {
+ a = d;
+ scratch[3] = 0;
+ previousScanline.Slice(offset, 3).CopyTo(scratch);
+ b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref scratchRef)).AsByte();
+ scanline.Slice(offset, 3).CopyTo(scratch);
+ d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref scratchRef)).AsByte();
+
+ d = AverageSubtractAdd(a, b, d, ones);
+
+ // Store the result.
+ int result = Sse2.ConvertToInt32(d.AsInt32());
+ Unsafe.As(ref scratchRef) = result;
+ scratch.Slice(0, 3).CopyTo(scanline.Slice(offset, 3));
+ }
+ }
+ else
+ {
+ Vector128 a = Vector128.Zero;
+ Vector128 b = Vector128.Zero;
+ Vector128 d = Vector128.Zero;
+ var ones = Vector128.Create((byte)1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
+
+ Span scratch = stackalloc byte[4];
+ ref byte scratchRef = ref MemoryMarshal.GetReference(scratch);
+ int rb = scanline.Length;
+ int offset = 0;
+ while (rb >= 4)
+ {
+ a = d;
+ b = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref prevBaseRef, offset))).AsByte();
+ d = Sse2.ConvertScalarToVector128Int32(Unsafe.As(ref Unsafe.Add(ref scanBaseRef, offset))).AsByte();
- for (; x < scanline.Length; ++x)
+ d = AverageSubtractAdd(a, b, d, ones);
+
+ // Store the result.
+ int result = Sse2.ConvertToInt32(d.AsInt32());
+ Unsafe.As(ref scratchRef) = result;
+ scratch.CopyTo(scanline.Slice(offset, 4));
+
+ rb -= 4;
+ offset += 4;
+ }
+ }
+ }
+ else
+#endif
{
- ref byte scan = ref Unsafe.Add(ref scanBaseRef, x);
- byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel);
- byte above = Unsafe.Add(ref prevBaseRef, x);
- scan = (byte)(scan + Average(left, above));
+ int x = 1;
+ for (; x <= bytesPerPixel /* Note the <= because x starts at 1 */; ++x)
+ {
+ ref byte scan = ref Unsafe.Add(ref scanBaseRef, x);
+ byte above = Unsafe.Add(ref prevBaseRef, x);
+ scan = (byte)(scan + (above >> 1));
+ }
+
+ for (; x < scanline.Length; ++x)
+ {
+ ref byte scan = ref Unsafe.Add(ref scanBaseRef, x);
+ byte left = Unsafe.Add(ref scanBaseRef, x - bytesPerPixel);
+ byte above = Unsafe.Add(ref prevBaseRef, x);
+ scan = (byte)(scan + Average(left, above));
+ }
}
}
+#if SUPPORTS_RUNTIME_INTRINSICS
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 AverageSubtractAdd(Vector128 a, Vector128 b, Vector128 d, Vector128 ones) // Per byte lane: returns d + floor((a + b) / 2).
+ {
+ // PNG requires a truncating average, floor((a + b) / 2), but Sse2.Average (_mm_avg_epu8) rounds up: (a + b + 1) >> 1.
+ // The two results differ only in lanes where a + b is odd, so subtract 1 from the rounded average in exactly those lanes.
+ Vector128 avg = Sse2.Average(a, b);
+ avg = Sse2.Subtract(avg, Sse2.And(Sse2.Xor(a, b), ones)); // (a ^ b) & 1 is 1 exactly when a + b is odd; callers pass 'ones' with every byte set to 1.
+ return Sse2.Add(d, avg); // Add the predicted value back onto the filtered bytes held in d.
+ }
+#endif
+
///
- /// Encodes the scanline
+ /// Encodes a scanline with the average filter applied.
///
- /// The scanline to encode
+ /// The scanline to encode.
/// The previous scanline.
/// The filtered scanline result.
/// The bytes per pixel.
- /// The sum of the total variance of the filtered row
+ /// The sum of the total variance of the filtered row.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Encode(ReadOnlySpan scanline, ReadOnlySpan previousScanline, Span result, int bytesPerPixel, out int sum)
{