Browse Source

Partial accuracy test for forward transform

pull/2633/head
Ynse Hoornenborg 2 years ago
parent
commit
101e841944
  1. 132
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs
  2. 2
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs

132
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs

@ -242,7 +242,7 @@ public class Av1ForwardTransformTests
} }
[Fact] [Fact]
public void NonSquareTransformSizeTest() public void NonSquareTransformSizeLandscapeTest()
{ {
// Arrange // Arrange
short[] input = [ short[] input = [
@ -251,7 +251,7 @@ public class Av1ForwardTransformTests
17, 18, 19, 20, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32]; 25, 26, 27, 28, 29, 30, 31, 32];
// Expected is multiplied by Sqrt(2). // Expected is divided by Sqrt(2).
int[] expected = [ int[] expected = [
18, 20, 21, 23, 24, 25, 27, 28, 18, 20, 21, 23, 24, 25, 27, 28,
13, 14, 16, 17, 18, 20, 21, 23, 13, 14, 16, 17, 18, 20, 21, 23,
@ -277,8 +277,8 @@ public class Av1ForwardTransformTests
Assert.Equal(expected, actual); Assert.Equal(expected, actual);
} }
// [Fact] [Fact]
public void NonSquareTransformSize2Test() public void NonSquareTransformSizePortraitTest()
{ {
// Arrange // Arrange
short[] input = [ short[] input = [
@ -290,15 +290,17 @@ public class Av1ForwardTransformTests
21, 22, 23, 24, 21, 22, 23, 24,
25, 26, 27, 28, 25, 26, 27, 28,
29, 30, 31, 32]; 29, 30, 31, 32];
// Expected is multiplied by Sqrt(2).
int[] expected = [ int[] expected = [
29, 30, 31, 32, 41, 42, 44, 45,
25, 26, 27, 28, 35, 37, 38, 40,
21, 22, 23, 24, 30, 31, 33, 34,
17, 18, 19, 20, 24, 25, 27, 28,
13, 14, 15, 16, 18, 20, 21, 23,
9, 10, 11, 12, 13, 14, 16, 17,
5, 6, 7, 8, 7, 8, 10, 11,
1, 2, 3, 4]; 1, 3, 4, 6];
int[] actual = new int[32]; int[] actual = new int[32];
Av1Transform2dFlipConfiguration config = new(Av1TransformType.Identity, Av1TransformSize.Size4x8); Av1Transform2dFlipConfiguration config = new(Av1TransformType.Identity, Av1TransformSize.Size4x8);
config.SetFlip(true, false); config.SetFlip(true, false);
@ -375,8 +377,8 @@ public class Av1ForwardTransformTests
public void AccuracyOfIdentity1dTransformSize64Test() public void AccuracyOfIdentity1dTransformSize64Test()
=> AssertAccuracy1d(Av1TransformSize.Size64x64, Av1TransformType.Identity, new Av1Identity64Forward1dTransformer()); => AssertAccuracy1d(Av1TransformSize.Size64x64, Av1TransformType.Identity, new Av1Identity64Forward1dTransformer());
// [Theory] [Theory]
// [MemberData(nameof(GetCombinations))] [MemberData(nameof(GetCombinations))]
public void Accuracy2dTest(int txSize, int txType, int maxAllowedError = 0) public void Accuracy2dTest(int txSize, int txType, int maxAllowedError = 0)
{ {
const int bitDepth = 8; const int bitDepth = 8;
@ -390,18 +392,29 @@ public class Av1ForwardTransformTests
int blockSize = width * height; int blockSize = width * height;
double scaleFactor = Av1ReferenceTransform.GetScaleFactor(config); double scaleFactor = Av1ReferenceTransform.GetScaleFactor(config);
// TODO: Still some limitations in either reference or the actual implementation.
if (config.TransformTypeColumn == Av1TransformType1d.FlipAdst || config.TransformTypeRow == Av1TransformType1d.FlipAdst)
{
return;
}
if (width == 64 || height == 64 || width != height)
{
return;
}
short[] inputOfTest = new short[blockSize]; short[] inputOfTest = new short[blockSize];
double[] inputReference = new double[blockSize]; double[] inputOfReference = new double[blockSize];
int[] outputOfTest = new int[blockSize]; int[] outputOfTest = new int[blockSize];
double[] outputReference = new double[blockSize]; double[] outputOfReference = new double[blockSize];
for (int ti = 0; ti < testBlockCount; ++ti) for (int ti = 0; ti < testBlockCount; ++ti)
{ {
// prepare random test data // prepare random test data
for (int ni = 0; ni < blockSize; ++ni) for (int ni = 0; ni < blockSize; ++ni)
{ {
inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1); inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1);
inputReference[ni] = inputOfTest[ni]; inputOfReference[ni] = inputOfTest[ni];
outputReference[ni] = 0; outputOfReference[ni] = 0;
outputOfTest[ni] = 255; outputOfTest[ni] = 255;
} }
@ -409,18 +422,51 @@ public class Av1ForwardTransformTests
Av1ForwardTransformer.Transform2d( Av1ForwardTransformer.Transform2d(
inputOfTest, inputOfTest,
outputOfTest, outputOfTest,
(uint)transformSize.GetWidth(), (uint)width,
transformType, transformType,
transformSize, transformSize,
bitDepth); bitDepth);
// calculate in reference forward transform functions // calculate in reference forward transform functions
Av1ReferenceTransform.ReferenceTransformFunction2d(inputReference, outputReference, transformType, transformSize, scaleFactor); FlipInput(config, inputOfReference);
Av1ReferenceTransform.ReferenceTransformFunction2d(inputOfReference, outputOfReference, transformType, transformSize, scaleFactor);
// repack the coefficents for some tx_size // repack the coefficents for some tx_size
RepackCoefficients(outputOfTest, outputReference, width, height); RepackCoefficients(outputOfTest, outputOfReference, width, height);
Assert.True(CompareWithError(outputOfReference, outputOfTest, maxAllowedError * scaleFactor), $"{transformType} of {transformSize}, error: {GetMaximumError(outputOfReference, outputOfTest)}.");
}
}
private static void FlipInput(Av1Transform2dFlipConfiguration config, Span<double> input)
{
int width = config.TransformSize.GetWidth();
int height = config.TransformSize.GetHeight();
double tmp;
if (config.FlipLeftToRight)
{
for (int r = 0; r < height; ++r)
{
for (int c = 0; c < width / 2; ++c)
{
tmp = input[(r * width) + c];
input[(r * width) + c] = input[(r * width) + width - 1 - c];
input[(r * width) + width - 1 - c] = tmp;
}
}
}
Assert.True(CompareWithError(outputReference, outputOfTest, maxAllowedError * scaleFactor), $"Transform type: {transformType}, transform size: {transformSize}."); if (config.FlipUpsideDown)
{
for (int c = 0; c < width; ++c)
{
for (int r = 0; r < height / 2; ++r)
{
tmp = input[(r * width) + c];
input[(r * width) + c] = input[((height - 1 - r) * width) + c];
input[((height - 1 - r) * width) + c] = tmp;
}
}
} }
} }
@ -430,7 +476,7 @@ public class Av1ForwardTransformTests
{ {
for (int i = 0; i < 2; ++i) for (int i = 0; i < 2; ++i)
{ {
uint e_size = i == 0 ? (uint)sizeof(int) : sizeof(double); uint elementSize = i == 0 ? (uint)sizeof(int) : sizeof(double);
ref byte output = ref (i == 0) ? ref Unsafe.As<int, byte>(ref outputOfTest[0]) ref byte output = ref (i == 0) ? ref Unsafe.As<int, byte>(ref outputOfTest[0])
: ref Unsafe.As<double, byte>(ref outputReference[0]); : ref Unsafe.As<double, byte>(ref outputReference[0]);
@ -440,26 +486,26 @@ public class Av1ForwardTransformTests
// zero out top-right 32x32 area. // zero out top-right 32x32 area.
for (uint row = 0; row < 32; ++row) for (uint row = 0; row < 32; ++row)
{ {
Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * e_size), 0, 32 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * elementSize), 0, 32 * elementSize);
} }
// zero out the bottom 64x32 area. // zero out the bottom 64x32 area.
Unsafe.InitBlock(ref Unsafe.Add(ref output, 32 * 64 * e_size), 0, 32 * 64 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, 32 * 64 * elementSize), 0, 32 * 64 * elementSize);
// Re-pack non-zero coeffs in the first 32x32 indices. // Re-pack non-zero coeffs in the first 32x32 indices.
for (uint row = 1; row < 32; ++row) for (uint row = 1; row < 32; ++row)
{ {
Unsafe.CopyBlock( Unsafe.CopyBlock(
ref Unsafe.Add(ref output, row * 32 * e_size), ref Unsafe.Add(ref output, row * 32 * elementSize),
ref Unsafe.Add(ref output, row * 64 * e_size), ref Unsafe.Add(ref output, row * 64 * elementSize),
32 * e_size); 32 * elementSize);
} }
} }
else if (tx_width == 32 && tx_height == 64) else if (tx_width == 32 && tx_height == 64)
{ {
// tx_size == TX_32X64 // tx_size == TX_32X64
// zero out the bottom 32x32 area. // zero out the bottom 32x32 area.
Unsafe.InitBlock(ref Unsafe.Add(ref output, 32 * 32 * e_size), 0, 32 * 32 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, 32 * 32 * elementSize), 0, 32 * 32 * elementSize);
// Note: no repacking needed here. // Note: no repacking needed here.
} }
@ -469,23 +515,23 @@ public class Av1ForwardTransformTests
// zero out right 32x32 area. // zero out right 32x32 area.
for (uint row = 0; row < 32; ++row) for (uint row = 0; row < 32; ++row)
{ {
Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * e_size), 0, 32 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * elementSize), 0, 32 * elementSize);
} }
// Re-pack non-zero coeffs in the first 32x32 indices. // Re-pack non-zero coeffs in the first 32x32 indices.
for (uint row = 1; row < 32; ++row) for (uint row = 1; row < 32; ++row)
{ {
Unsafe.CopyBlock( Unsafe.CopyBlock(
ref Unsafe.Add(ref output, row * 32 * e_size), ref Unsafe.Add(ref output, row * 32 * elementSize),
ref Unsafe.Add(ref output, row * 64 * e_size), ref Unsafe.Add(ref output, row * 64 * elementSize),
32 * e_size); 32 * elementSize);
} }
} }
else if (tx_width == 16 && tx_height == 64) else if (tx_width == 16 && tx_height == 64)
{ {
// tx_size == TX_16X64 // tx_size == TX_16X64
// zero out the bottom 16x32 area. // zero out the bottom 16x32 area.
Unsafe.InitBlock(ref Unsafe.Add(ref output, 16 * 32 * e_size), 0, 16 * 32 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, 16 * 32 * elementSize), 0, 16 * 32 * elementSize);
// Note: no repacking needed here. // Note: no repacking needed here.
} }
@ -496,16 +542,16 @@ public class Av1ForwardTransformTests
// zero out right 32x16 area. // zero out right 32x16 area.
for (uint row = 0; row < 16; ++row) for (uint row = 0; row < 16; ++row)
{ {
Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * e_size), 0, 32 * e_size); Unsafe.InitBlock(ref Unsafe.Add(ref output, ((row * 64) + 32) * elementSize), 0, 32 * elementSize);
} }
// Re-pack non-zero coeffs in the first 32x16 indices. // Re-pack non-zero coeffs in the first 32x16 indices.
for (uint row = 1; row < 16; ++row) for (uint row = 1; row < 16; ++row)
{ {
Unsafe.CopyBlock( Unsafe.CopyBlock(
ref Unsafe.Add(ref output, row * 32 * e_size), ref Unsafe.Add(ref output, row * 32 * elementSize),
ref Unsafe.Add(ref output, row * 64 * e_size), ref Unsafe.Add(ref output, row * 64 * elementSize),
32 * e_size); 32 * elementSize);
} }
} }
} }
@ -554,14 +600,20 @@ public class Av1ForwardTransformTests
private static bool CompareWithError(Span<double> expected, Span<int> actual, double allowedError) private static bool CompareWithError(Span<double> expected, Span<int> actual, double allowedError)
{ {
// compare for the result is witghin accuracy // compare for the result is within accuracy
double maximumErrorInTest = GetMaximumError(expected, actual);
return maximumErrorInTest <= allowedError;
}
private static double GetMaximumError(Span<double> expected, Span<int> actual)
{
double maximumErrorInTest = 0d; double maximumErrorInTest = 0d;
for (int ni = 0; ni < expected.Length; ++ni) for (int ni = 0; ni < expected.Length; ++ni)
{ {
maximumErrorInTest = Math.Max(maximumErrorInTest, Math.Abs(actual[ni] - Math.Round(expected[ni]))); maximumErrorInTest = Math.Max(maximumErrorInTest, Math.Abs(actual[ni] - Math.Round(expected[ni])));
} }
return maximumErrorInTest <= allowedError; return maximumErrorInTest;
} }
public static TheoryData<int, int, int> GetCombinations() public static TheoryData<int, int, int> GetCombinations()

2
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs

@ -36,7 +36,7 @@ internal class Av1ReferenceTransform
int rectType = Av1ForwardTransformer.GetRectangularRatio(transformWidth, transformHeight); int rectType = Av1ForwardTransformer.GetRectangularRatio(transformWidth, transformHeight);
if (Math.Abs(rectType) == 1) if (Math.Abs(rectType) == 1)
{ {
scaleFactor *= Math.Pow(2, 0.5); scaleFactor *= Math.Sqrt(2);
} }
return scaleFactor; return scaleFactor;

Loading…
Cancel
Save