Browse Source

Unit test for transform set indices

pull/2633/head
Ynse Hoornenborg 1 year ago
parent
commit
3eedbbbec0
  1. 56
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs
  2. 25
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
  3. 29
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs
  4. 38
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs

56
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs

@ -19,6 +19,15 @@ internal static class Av1SymbolContextHelper
[7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];
public static readonly int[][] ExtendedTransformIndicesInverse = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0],
[9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8],
];
public static readonly int[] EndOfBlockOffsetBits = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
public static readonly int[] EndOfBlockGroupStart = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];
private static readonly int[] TransformCountInSet = [1, 2, 5, 7, 12, 16];
@ -216,6 +225,9 @@ internal static class Av1SymbolContextHelper
return Av1NzMap.GetNzMapContextFromStats(stats, levels, position, transformSize, transformClass);
}
/// <summary>
/// SVT: get_ext_tx_set_type
/// </summary>
internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool useReducedSet)
{
Av1TransformSize squareUpSize = transformSize.GetSquareUpSize();
@ -285,52 +297,14 @@ internal static class Av1SymbolContextHelper
}
/// <summary>
/// SVT: get_ext_tx_set_type
/// SVT: get_ext_tx_types
/// </summary>
internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool isInter, bool useReducedTransformSet)
{
Av1TransformSize transformSizeSquareUp = transformSize.GetSquareUpSize();
if (transformSizeSquareUp > Av1TransformSize.Size32x32)
{
return Av1TransformSetType.DctOnly;
}
if (transformSizeSquareUp == Av1TransformSize.Size32x32)
{
return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.DctOnly;
}
if (useReducedTransformSet)
{
return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.Dtt4Identity;
}
Av1TransformSize transformSizeSquare = transformSize.GetSquareSize();
if (isInter)
{
return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt9Identity1dDct : Av1TransformSetType.All16;
}
else
{
return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt4Identity : Av1TransformSetType.Dtt4Identity1dDct;
}
}
internal static int GetExtendedTransformTypeCount(Av1TransformSize transformSize, bool useReducedTransformSet)
{
int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet);
return TransformCountInSet[setType];
}
internal static int GetExtendedTransformTypeCount(Av1TransformSetType setType) => TransformCountInSet[(int)setType];
/// <summary>
/// SVT: get_ext_tx_set
/// </summary>
internal static int GetExtendedTransformSet(Av1TransformSize transformSize, bool useReducedTransformSet)
{
int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet);
return ExtendedTransformSetToIndex[setType];
}
internal static int GetExtendedTransformSet(Av1TransformSetType setType) => ExtendedTransformSetToIndex[(int)setType];
/// <summary>
/// SVT: set_dc_sign

25
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs

@ -9,15 +9,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
internal ref struct Av1SymbolDecoder
{
private static readonly int[][] ExtendedTransformIndicesInverse = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0],
[9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8],
];
private static readonly int[] IntraModeContext = [0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0];
private static readonly int[] AlphaVContexts = [-1, 0, 3, -1, 1, 4, -1, 2, 5];
@ -199,6 +190,9 @@ internal ref struct Av1SymbolDecoder
return transformSize;
}
/// <summary>
/// SVT: parse_transform_type
/// </summary>
public Av1TransformType ReadTransformType(
Av1TransformSize transformSize,
bool useReducedTransformSet,
@ -216,12 +210,17 @@ internal ref struct Av1SymbolDecoder
return;
*/
if (baseQIndex == 0)
{
return transformType;
}
// Ignoring INTER blocks here, as these should not end up here.
// int inter_block = is_inter_block_dec(mbmi);
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0)
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0)
{
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType);
// eset == 0 should correspond to a set with only DCT_DCT and
// there is no need to read the tx_type
@ -233,7 +232,7 @@ internal ref struct Av1SymbolDecoder
: intraDirection;
ref Av1SymbolReader r = ref this.reader;
int symbol = r.ReadSymbol(this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]);
transformType = (Av1TransformType)ExtendedTransformIndicesInverse[(int)transformSetType][symbol];
transformType = (Av1TransformType)Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][symbol];
}
return transformType;
@ -311,7 +310,7 @@ internal ref struct Av1SymbolDecoder
if (plane == (int)Av1Plane.Y)
{
this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode);
transformInfo.Type = this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode);
}
transformInfo.Type = ComputeTransformType(planeType, modeInfo, isLossless, transformSize, transformInfo, useReducedTransformSet);

29
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs

@ -11,15 +11,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
internal class Av1SymbolEncoder : IDisposable
{
private static readonly int[][] ExtendedTransformIndices = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
[3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0],
[7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];
private readonly Av1Distribution tileIntraBlockCopy = Av1DefaultDistributions.IntraBlockCopy;
private readonly Av1Distribution[] tilePartitionTypes = Av1DefaultDistributions.PartitionTypes;
private readonly Av1Distribution[][] keyFrameYMode = Av1DefaultDistributions.KeyFrameYMode;
@ -289,35 +280,35 @@ internal class Av1SymbolEncoder : IDisposable
Av1PredictionMode intraDirection)
{
// bool isInter = mbmi->block_mi.use_intrabc || is_inter_mode(mbmi->block_mi.mode);
ref Av1SymbolWriter w = ref this.writer;
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0)
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0)
{
Av1TransformSize squareTransformSize = transformSize.GetSquareSize();
Guard.MustBeLessThanOrEqualTo((int)squareTransformSize, Av1Constants.ExtendedTransformCount, nameof(squareTransformSize));
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType);
// eset == 0 should correspond to a set with only DCT_DCT and there
// is no need to send the tx_type
Guard.MustBeGreaterThan(extendedSet, 0, nameof(extendedSet));
// assert(av1_ext_tx_used[tx_set_type][transformType]);
Av1PredictionMode intraMode;
Av1PredictionMode intraDirectionContext;
if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes)
{
intraMode = filterIntraMode.ToIntraDirection();
intraDirectionContext = filterIntraMode.ToIntraDirection();
}
else
{
intraMode = intraDirection;
intraDirectionContext = intraDirection;
}
Guard.MustBeLessThan((int)intraMode, 13, nameof(intraMode));
Guard.MustBeLessThan((int)intraDirectionContext, 13, nameof(intraDirectionContext));
Guard.MustBeLessThan((int)squareTransformSize, 4, nameof(squareTransformSize));
ref Av1SymbolWriter w = ref this.writer;
w.WriteSymbol(
ExtendedTransformIndices[(int)transformSetType][(int)transformType],
this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]);
Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][(int)transformType],
this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraDirectionContext]);
}
}

38
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs

@ -1,9 +1,11 @@
// Copyright (c) Six Labors.
// Licensed under the Six Labors Split License.
using Microsoft.Diagnostics.Symbols;
using SixLabors.ImageSharp.Formats.Heif.Av1;
using SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
using SixLabors.ImageSharp.Formats.Heif.Av1.Tiling;
using SixLabors.ImageSharp.Formats.Heif.Av1.Transform;
namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1;
@ -11,8 +13,8 @@ namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1;
public class Av1SymbolContextTests
{
[Theory]
[MemberData(nameof(GetCombinations))]
public void TestAccuracy(int width, int height, int index)
[MemberData(nameof(GetLowLevelContextEndOfBlockData))]
public void TestLowLevelContextEndOfBlockAccuracy(int width, int height, int index)
{
// Arrange
Size size = new(width, height);
@ -28,7 +30,22 @@ public class Av1SymbolContextTests
Assert.Equal(expectedContext, actualContext);
}
public static TheoryData<int, int, int> GetCombinations()
[Theory]
[MemberData(nameof(GetExtendedTransformIndicesData))]
public void RoundTripExtendedTransformIndices(int setType, int index)
{
// Arrange
Av1TransformSetType transformSetType = (Av1TransformSetType)setType;
// Act
int transformType = Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][index];
int actualIndex = Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][transformType];
// Assert
Assert.Equal(actualIndex, index);
}
public static TheoryData<int, int, int> GetLowLevelContextEndOfBlockData()
{
TheoryData<int, int, int> result = [];
for (int y = 1; y < 6; y++)
@ -46,6 +63,21 @@ public class Av1SymbolContextTests
return result;
}
public static TheoryData<int, int> GetExtendedTransformIndicesData()
{
TheoryData<int, int> result = [];
for (Av1TransformSetType setType = Av1TransformSetType.DctOnly; setType <= Av1TransformSetType.All16; setType++)
{
int count = Av1SymbolContextHelper.GetExtendedTransformTypeCount(setType);
for (int type = 1; type < count; type++)
{
result.Add((int)setType, type);
}
}
return result;
}
/// <summary>
/// SVT: get_lower_levels_ctx_eob
/// </summary>

Loading…
Cancel
Save