Browse Source

Implement 4x4 forward DCT transform

pull/2633/head
Ynse Hoornenborg 2 years ago
parent
commit
97052eb431
  1. 2
      src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs
  2. 16
      src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs
  3. 109
      src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs
  4. 119
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs
  5. 2
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs

2
src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs

@ -1,8 +1,6 @@
// Copyright (c) Six Labors.
// Licensed under the Six Labors Split License.
using System;
namespace SixLabors.ImageSharp.Formats.Heif.Av1.Transform;
internal static class Av1SinusConstants

16
src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs

@ -142,8 +142,8 @@ internal class Av1Transform2dFlipConfiguration
this.TransformFunctionTypeRow = TransformFunctionTypeMap[txw_idx][(int)tx_type_1d_row];
this.StageNumberColumn = StageNumberList[(int)this.TransformFunctionTypeColumn];
this.StageNumberRow = StageNumberList[(int)this.TransformFunctionTypeRow];
this.StageRangeColumn = new int[12];
this.StageRangeRow = new int[12];
this.StageRangeColumn = new byte[12];
this.StageRangeRow = new byte[12];
this.NonScaleRange();
}
@ -169,9 +169,9 @@ internal class Av1Transform2dFlipConfiguration
public Span<int> Shift => this.shift;
public int[] StageRangeColumn { get; }
public byte[] StageRangeColumn { get; }
public int[] StageRangeRow { get; }
public byte[] StageRangeRow { get; }
/// <summary>
/// SVT: svt_av1_gen_fwd_stage_range
@ -184,13 +184,13 @@ internal class Av1Transform2dFlipConfiguration
// i < MAX_TXFM_STAGE_NUM will mute above array bounds warning
for (int i = 0; i < this.StageNumberColumn && i < MaxStageNumber; ++i)
{
this.StageRangeColumn[i] = this.StageRangeColumn[i] + shift[0] + bitDepth + 1;
this.StageRangeColumn[i] = (byte)(this.StageRangeColumn[i] + shift[0] + bitDepth + 1);
}
// i < MAX_TXFM_STAGE_NUM will mute above array bounds warning
for (int i = 0; i < this.StageNumberRow && i < MaxStageNumber; ++i)
{
this.StageRangeRow[i] = this.StageRangeRow[i] + shift[0] + shift[1] + bitDepth + 1;
this.StageRangeRow[i] = (byte)(this.StageRangeRow[i] + shift[0] + shift[1] + bitDepth + 1);
}
}
@ -296,7 +296,7 @@ internal class Av1Transform2dFlipConfiguration
int stage_num_col = this.StageNumberColumn;
for (int i = 0; i < stage_num_col; ++i)
{
this.StageRangeColumn[i] = (range_mult2_col[i] + 1) >> 1;
this.StageRangeColumn[i] = (byte)((range_mult2_col[i] + 1) >> 1);
}
}
@ -306,7 +306,7 @@ internal class Av1Transform2dFlipConfiguration
Span<int> range_mult2_row = RangeMulti2List[(int)this.TransformFunctionTypeRow];
for (int i = 0; i < stage_num_row; ++i)
{
this.StageRangeRow[i] = (range_mult2_col[this.StageNumberColumn - 1] + range_mult2_row[i] + 1) >> 1;
this.StageRangeRow[i] = (byte)((range_mult2_col[this.StageNumberColumn - 1] + range_mult2_row[i] + 1) >> 1);
}
}
}

109
src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs

@ -10,7 +10,39 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Transform.Forward;
internal class Av1Dct4ForwardTransformer : IAv1ForwardTransformer
{
/// <summary>
/// Computes the 4-point forward DCT of four consecutive integers.
/// </summary>
/// <param name="input">Reference to the first of four consecutive input samples.</param>
/// <param name="output">Reference to the first of four consecutive output coefficients.</param>
/// <param name="cosBit">
/// Selects the cosine table through <c>Av1SinusConstants.CosinusPi</c> and is passed to
/// <c>HalfBtf</c> as the rounding shift.
/// </param>
/// <param name="stageRange">Per-stage ranges; not consulted by this implementation.</param>
public void Transform(ref int input, ref int output, int cosBit, Span<byte> stageRange)
{
    Span<int> cospi = Av1SinusConstants.CosinusPi(cosBit);

    // Read every input before writing any output, so no input is read
    // after an output slot has been overwritten.
    int in0 = input;
    int in1 = Unsafe.Add(ref input, 1);
    int in2 = Unsafe.Add(ref input, 2);
    int in3 = Unsafe.Add(ref input, 3);

    // stage 1: butterfly of the mirrored sample pairs.
    int sum03 = in0 + in3;
    int sum12 = in1 + in2;
    int diff12 = -in2 + in1;
    int diff03 = -in3 + in0;

    // stage 2: rotations; plain locals replace the former heap-allocated scratch array.
    int step0 = HalfBtf(cospi[32], sum03, cospi[32], sum12, cosBit);
    int step1 = HalfBtf(-cospi[32], sum12, cospi[32], sum03, cosBit);
    int step2 = HalfBtf(cospi[48], diff12, cospi[16], diff03, cosBit);
    int step3 = HalfBtf(cospi[48], diff03, -cospi[16], diff12, cosBit);

    // stage 3: reorder the results into the output (step1 and step2 swap places).
    output = step0;
    Unsafe.Add(ref output, 1) = step2;
    Unsafe.Add(ref output, 2) = step1;
    Unsafe.Add(ref output, 3) = step3;
}
public void TransformAvx2(ref Vector256<int> input, ref Vector256<int> output, int cosBit, int columnNumber)
=> throw new NotImplementedException("Too small block for Vector implementation, use TransformSse() method instead.");
@ -20,7 +52,8 @@ internal class Av1Dct4ForwardTransformer : IAv1ForwardTransformer
/// </summary>
public static void TransformSse(ref Vector128<int> input, ref Vector128<int> output, byte cosBit, int columnNumber)
{
/*
#pragma warning disable CA1857 // A constant is expected for the parameter
// We only use stage-2 bit;
// shift[0] is used in load_buffer_4x4()
// shift[1] is used in txfm_func_col()
@ -35,51 +68,71 @@ internal class Av1Dct4ForwardTransformer : IAv1ForwardTransformer
Vector128<int> v0, v1, v2, v3;
int endidx = 3 * columnNumber;
s0 = Sse41.Add(input, Unsafe.Add(ref input, endidx));
s3 = Sse41.Subtract(input, Unsafe.Add(ref input, endidx));
s0 = Sse2.Add(input, Unsafe.Add(ref input, endidx));
s3 = Sse2.Subtract(input, Unsafe.Add(ref input, endidx));
endidx -= columnNumber;
s1 = Sse41.Add(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx));
s2 = Sse41.Subtract(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx));
s1 = Sse2.Add(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx));
s2 = Sse2.Subtract(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx));
// btf_32_sse4_1_type0(cospi32, cospi32, s[01], u[02], bit);
u0 = Sse41.MultiplyLow(s0, cospi32);
u1 = Sse41.MultiplyLow(s1, cospi32);
u2 = Sse41.Add(u0, u1);
v0 = Sse41.Subtract(u0, u1);
u2 = Sse2.Add(u0, u1);
v0 = Sse2.Subtract(u0, u1);
u3 = Sse41.Add(u2, rnding);
v1 = Sse41.Add(v0, rnding);
u3 = Sse2.Add(u2, rnding);
v1 = Sse2.Add(v0, rnding);
u0 = Sse41.ShiftRightArithmetic(u3, cosBit);
u2 = Sse41.ShiftRightArithmetic(v1, cosBit);
u0 = Sse2.ShiftRightArithmetic(u3, cosBit);
u2 = Sse2.ShiftRightArithmetic(v1, cosBit);
// btf_32_sse4_1_type1(cospi48, cospi16, s[23], u[13], bit);
v0 = Sse41.MultiplyLow(s2, cospi48);
v1 = Sse41.MultiplyLow(s3, cospi16);
v2 = Sse41.Add(v0, v1);
v2 = Sse2.Add(v0, v1);
v3 = Sse41.Add(v2, rnding);
u1 = Sse41.ShiftRightArithmetic(v3, cosBit);
v3 = Sse2.Add(v2, rnding);
u1 = Sse2.ShiftRightArithmetic(v3, cosBit);
v0 = Sse41.MultiplyLow(s2, cospi16);
v1 = Sse41.MultiplyLow(s3, cospi48);
v2 = Sse41.Subtract(v1, v0);
v2 = Sse2.Subtract(v1, v0);
v3 = Sse41.Add(v2, rnding);
u3 = Sse41.ShiftRightArithmetic(v3, cosBit);
v3 = Sse2.Add(v2, rnding);
u3 = Sse2.ShiftRightArithmetic(v3, cosBit);
// Note: shift[1] and shift[2] are zeros
// Transpose 4x4 32-bit
v0 = Sse41.UnpackLow(u0, u1);
v1 = Sse41.UnpackHigh(u0, u1);
v2 = Sse41.UnpackLow(u2, u3);
v3 = Sse41.UnpackHigh(u2, u3);
output = Sse41.UnpackLow(v0.AsInt64(), v2.AsInt64()).AsInt32();
Unsafe.Add(ref output, 1) = Sse41.UnpackHigh(v0.AsInt64(), v2.AsInt64()).AsInt32();
Unsafe.Add(ref output, 2) = Sse41.UnpackLow(v1.AsInt64(), v3.AsInt64()).AsInt32();
Unsafe.Add(ref output, 3) = Sse41.UnpackHigh(v1.AsInt64(), v3.AsInt64()).AsInt32();
*/
v0 = Sse2.UnpackLow(u0, u1);
v1 = Sse2.UnpackHigh(u0, u1);
v2 = Sse2.UnpackLow(u2, u3);
v3 = Sse2.UnpackHigh(u2, u3);
output = Sse2.UnpackLow(v0.AsInt64(), v2.AsInt64()).AsInt32();
Unsafe.Add(ref output, 1) = Sse2.UnpackHigh(v0.AsInt64(), v2.AsInt64()).AsInt32();
Unsafe.Add(ref output, 2) = Sse2.UnpackLow(v1.AsInt64(), v3.AsInt64()).AsInt32();
Unsafe.Add(ref output, 3) = Sse2.UnpackHigh(v1.AsInt64(), v3.AsInt64()).AsInt32();
#pragma warning restore CA1857 // A constant is expected for the parameter
}
/// <summary>
/// Half butterfly: computes <c>w0 * in0 + w1 * in1</c>, adds <c>2^(bit - 1)</c> for rounding
/// and shifts the sum right by <paramref name="bit"/>.
/// </summary>
/// <param name="w0">Weight applied to <paramref name="in0"/>.</param>
/// <param name="in0">First input value.</param>
/// <param name="w1">Weight applied to <paramref name="in1"/>.</param>
/// <param name="in1">Second input value.</param>
/// <param name="bit">Right-shift amount applied after rounding; expected to be at least 1.</param>
/// <returns>The rounded, shifted sum truncated to 32 bits.</returns>
private static int HalfBtf(int w0, int in0, int w1, int in1, int bit)
{
    // Widen BEFORE multiplying. Casting the finished 32-bit product to long
    // would sign-extend an already wrapped value and corrupt the 64-bit sum.
    long result64 = ((long)w0 * in0) + ((long)w1 * in1);
    long intermediate = result64 + (1L << (bit - 1));

    // NOTE(david.barker): The value 'result_64' may not necessarily fit
    // into 32 bits. However, the result of this function is nominally
    // ROUND_POWER_OF_TWO_64(result_64, bit)
    // and that is required to fit into stage_range[stage] many bits
    // (checked by range_check_buf()), so the narrowing cast below is
    // expected to be lossless for any conformant bitstream.
    return (int)(intermediate >> bit);
}
}

119
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs

@ -3,6 +3,7 @@
using System.Runtime.CompilerServices;
using SixLabors.ImageSharp.Formats.Heif.Av1.Transform;
using SixLabors.ImageSharp.Formats.Heif.Av1.Transform.Forward;
namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1;
@ -35,22 +36,49 @@ public class Av1ForwardTransformTests
36, // 64x16 transform
];
private readonly short[] inputOfTest;
private readonly int[] outputOfTest;
private readonly double[] inputReference;
private readonly double[] outputReference;
public Av1ForwardTransformTests()
[Theory]
[MemberData(nameof(GetSizes))]
public void AccuracyDct1dTest(int txSize)
{
this.inputOfTest = new short[64 * 64];
this.outputOfTest = new int[64 * 64];
this.inputReference = new double[64 * 64];
this.outputReference = new double[64 * 64];
Random rnd = new(0);
const int testBlockCount = 1; // Originally set to: 1000
Av1TransformSize transformSize = (Av1TransformSize)txSize;
Av1Transform2dFlipConfiguration config = new(Av1TransformType.DctDct, transformSize);
int width = config.TransformSize.GetWidth();
int[] inputOfTest = new int[width];
double[] inputReference = new double[width];
int[] outputOfTest = new int[width];
double[] outputReference = new double[width];
for (int ti = 0; ti < testBlockCount; ++ti)
{
// prepare random test data
for (int ni = 0; ni < width; ++ni)
{
inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1);
inputReference[ni] = inputOfTest[ni];
outputReference[ni] = 0;
outputOfTest[ni] = 255;
}
// calculate in forward transform functions
new Av1Dct4ForwardTransformer().Transform(
ref inputOfTest[0],
ref outputOfTest[0],
config.CosBitColumn,
config.StageRangeColumn);
// calculate in reference forward transform functions
Av1ReferenceTransform.ReferenceDct1d(inputReference, outputReference, width);
// Assert
Assert.True(CompareWithError(outputReference, outputOfTest, 1));
}
}
// [Theory]
// [MemberData(nameof(GetCombinations))]
public void Accuracy2dTest(int txSize, int txType, int maxAllowedError)
[Theory]
[MemberData(nameof(GetCombinations))]
public void Accuracy2dTest(int txSize, int txType, int maxAllowedError = 0)
{
const int bitDepth = 8;
Random rnd = new(0);
@ -63,53 +91,49 @@ public class Av1ForwardTransformTests
int blockSize = width * height;
double scaleFactor = Av1ReferenceTransform.GetScaleFactor(config, width, height);
short[] inputOfTest = new short[blockSize];
double[] inputReference = new double[blockSize];
int[] outputOfTest = new int[blockSize];
double[] outputReference = new double[blockSize];
for (int ti = 0; ti < testBlockCount; ++ti)
{
// prepare random test data
for (int ni = 0; ni < blockSize; ++ni)
{
this.inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1);
this.inputReference[ni] = this.inputOfTest[ni];
this.outputReference[ni] = 0;
this.outputOfTest[ni] = 255;
inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1);
inputReference[ni] = inputOfTest[ni];
outputReference[ni] = 0;
outputOfTest[ni] = 255;
}
// calculate in forward transform functions
Av1ForwardTransformer.Transform2d(
this.inputOfTest,
this.outputOfTest,
inputOfTest,
outputOfTest,
(uint)transformSize.GetWidth(),
transformType,
transformSize,
bitDepth);
// calculate in reference forward transform functions
Av1ReferenceTransform.ReferenceTransformFunction2d(this.inputReference, this.outputReference, transformType, transformSize, scaleFactor);
Av1ReferenceTransform.ReferenceTransformFunction2d(inputReference, outputReference, transformType, transformSize, scaleFactor);
// repack the coefficients for some tx_size
this.RepackCoefficients(width, height);
RepackCoefficients(outputOfTest, outputReference, width, height);
// compare for the result is in accuracy
double maximumErrorInTest = 0;
for (int ni = 0; ni < blockSize; ++ni)
{
maximumErrorInTest = Math.Max(maximumErrorInTest, Math.Abs(this.outputOfTest[ni] - Math.Round(this.outputReference[ni])));
}
maximumErrorInTest /= scaleFactor;
Assert.True(maxAllowedError >= maximumErrorInTest, $"Forward transform 2d test with transform type: {transformType}, transform size: {transformSize} and loop: {ti}");
Assert.True(CompareWithError(outputReference, outputOfTest, maxAllowedError * scaleFactor), $"Forward transform 2d test with transform type: {transformType}, transform size: {transformSize} and loop: {ti}");
}
}
// The max txb_width or txb_height is 32, as specified in spec 7.12.3.
// Clear the high frequency coefficients and repack them in linear layout.
private void RepackCoefficients(int tx_width, int tx_height)
private static void RepackCoefficients(Span<int> outputOfTest, Span<double> outputReference, int tx_width, int tx_height)
{
for (int i = 0; i < 2; ++i)
{
uint e_size = i == 0 ? (uint)sizeof(int) : sizeof(double);
ref byte output = ref (i == 0) ? ref Unsafe.As<int, byte>(ref this.outputOfTest[0])
: ref Unsafe.As<double, byte>(ref this.outputReference[0]);
ref byte output = ref (i == 0) ? ref Unsafe.As<int, byte>(ref outputOfTest[0])
: ref Unsafe.As<double, byte>(ref outputReference[0]);
if (tx_width == 64 && tx_height == 64)
{
@ -188,13 +212,34 @@ public class Av1ForwardTransformTests
}
}
/// <summary>
/// Checks that every actual value lies within <paramref name="allowedError"/> of the
/// corresponding expected value after that expected value has been rounded.
/// </summary>
/// <param name="expected">Reference values; each is rounded with <see cref="Math.Round(double)"/> before comparison.</param>
/// <param name="actual">Values under test, indexed in parallel with <paramref name="expected"/>.</param>
/// <param name="allowedError">Largest absolute difference that is still accepted.</param>
/// <returns><see langword="true"/> when the largest absolute difference does not exceed <paramref name="allowedError"/>.</returns>
private static bool CompareWithError(Span<double> expected, Span<int> actual, double allowedError)
{
    // Track the worst deviation over the whole block, then compare it once at the end.
    double worstDeviation = 0;
    int position = 0;
    foreach (double reference in expected)
    {
        double deviation = Math.Abs(actual[position] - Math.Round(reference));
        worstDeviation = Math.Max(worstDeviation, deviation);
        position++;
    }

    return worstDeviation <= allowedError;
}
/// <summary>
/// Supplies the transform-size values used as theory data by the 1D accuracy test.
/// </summary>
/// <returns>Theory data containing the single value 0.</returns>
public static TheoryData<int> GetSizes()
{
    // For now test only 4x4.
    return new TheoryData<int> { 0 };
}
public static TheoryData<int, int, int> GetCombinations()
{
TheoryData<int, int, int> combinations = [];
for (int s = 0; s < (int)Av1TransformSize.AllSizes; s++)
{
double maxError = MaximumAllowedError[s];
for (int t = 0; t < (int)Av1TransformType.AllTransformTypes; ++t)
for (int t = 0; t < (int)Av1TransformType.AllTransformTypes; t++)
{
Av1TransformType transformType = (Av1TransformType)t;
Av1TransformSize transformSize = (Av1TransformSize)s;
@ -203,7 +248,13 @@ public class Av1ForwardTransformTests
{
combinations.Add(s, t, (int)maxError);
}
// For now only DCT.
break;
}
// For now only 4x4.
break;
}
return combinations;

2
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs

@ -174,7 +174,7 @@ internal class Av1ReferenceTransform
}
}
private static void ReferenceDct1d(Span<double> input, Span<double> output, int size)
internal static void ReferenceDct1d(Span<double> input, Span<double> output, int size)
{
const double kInvSqrt2 = 0.707106781186547524400844362104f;
for (int k = 0; k < size; ++k)

Loading…
Cancel
Save