Browse Source

Unit test for transform set indices

pull/2633/head
Ynse Hoornenborg 1 year ago
parent
commit
3eedbbbec0
  1. 56
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs
  2. 25
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
  3. 29
      src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs
  4. 38
      tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs

56
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs

@ -19,6 +19,15 @@ internal static class Av1SymbolContextHelper
[7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
]; ];
public static readonly int[][] ExtendedTransformIndicesInverse = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0],
[9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8],
];
public static readonly int[] EndOfBlockOffsetBits = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; public static readonly int[] EndOfBlockOffsetBits = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
public static readonly int[] EndOfBlockGroupStart = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513]; public static readonly int[] EndOfBlockGroupStart = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];
private static readonly int[] TransformCountInSet = [1, 2, 5, 7, 12, 16]; private static readonly int[] TransformCountInSet = [1, 2, 5, 7, 12, 16];
@ -216,6 +225,9 @@ internal static class Av1SymbolContextHelper
return Av1NzMap.GetNzMapContextFromStats(stats, levels, position, transformSize, transformClass); return Av1NzMap.GetNzMapContextFromStats(stats, levels, position, transformSize, transformClass);
} }
/// <summary>
/// SVT: get_ext_tx_set_type
/// </summary>
internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool useReducedSet) internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool useReducedSet)
{ {
Av1TransformSize squareUpSize = transformSize.GetSquareUpSize(); Av1TransformSize squareUpSize = transformSize.GetSquareUpSize();
@ -285,52 +297,14 @@ internal static class Av1SymbolContextHelper
} }
/// <summary> /// <summary>
/// SVT: get_ext_tx_set_type /// SVT: get_ext_tx_types
/// </summary> /// </summary>
internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool isInter, bool useReducedTransformSet) internal static int GetExtendedTransformTypeCount(Av1TransformSetType setType) => TransformCountInSet[(int)setType];
{
Av1TransformSize transformSizeSquareUp = transformSize.GetSquareUpSize();
if (transformSizeSquareUp > Av1TransformSize.Size32x32)
{
return Av1TransformSetType.DctOnly;
}
if (transformSizeSquareUp == Av1TransformSize.Size32x32)
{
return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.DctOnly;
}
if (useReducedTransformSet)
{
return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.Dtt4Identity;
}
Av1TransformSize transformSizeSquare = transformSize.GetSquareSize();
if (isInter)
{
return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt9Identity1dDct : Av1TransformSetType.All16;
}
else
{
return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt4Identity : Av1TransformSetType.Dtt4Identity1dDct;
}
}
internal static int GetExtendedTransformTypeCount(Av1TransformSize transformSize, bool useReducedTransformSet)
{
int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet);
return TransformCountInSet[setType];
}
/// <summary> /// <summary>
/// SVT: get_ext_tx_set /// SVT: get_ext_tx_set
/// </summary> /// </summary>
internal static int GetExtendedTransformSet(Av1TransformSize transformSize, bool useReducedTransformSet) internal static int GetExtendedTransformSet(Av1TransformSetType setType) => ExtendedTransformSetToIndex[(int)setType];
{
int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet);
return ExtendedTransformSetToIndex[setType];
}
/// <summary> /// <summary>
/// SVT: set_dc_sign /// SVT: set_dc_sign

25
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs

@ -9,15 +9,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
internal ref struct Av1SymbolDecoder internal ref struct Av1SymbolDecoder
{ {
private static readonly int[][] ExtendedTransformIndicesInverse = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0],
[9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8],
];
private static readonly int[] IntraModeContext = [0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0]; private static readonly int[] IntraModeContext = [0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0];
private static readonly int[] AlphaVContexts = [-1, 0, 3, -1, 1, 4, -1, 2, 5]; private static readonly int[] AlphaVContexts = [-1, 0, 3, -1, 1, 4, -1, 2, 5];
@ -199,6 +190,9 @@ internal ref struct Av1SymbolDecoder
return transformSize; return transformSize;
} }
/// <summary>
/// SVT: parse_transform_type
/// </summary>
public Av1TransformType ReadTransformType( public Av1TransformType ReadTransformType(
Av1TransformSize transformSize, Av1TransformSize transformSize,
bool useReducedTransformSet, bool useReducedTransformSet,
@ -216,12 +210,17 @@ internal ref struct Av1SymbolDecoder
return; return;
*/ */
if (baseQIndex == 0)
{
return transformType;
}
// Ignoring INTER blocks here, as these should not end up here. // Ignoring INTER blocks here, as these should not end up here.
// int inter_block = is_inter_block_dec(mbmi); // int inter_block = is_inter_block_dec(mbmi);
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet); Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0) if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0)
{ {
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet); int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType);
// eset == 0 should correspond to a set with only DCT_DCT and // eset == 0 should correspond to a set with only DCT_DCT and
// there is no need to read the tx_type // there is no need to read the tx_type
@ -233,7 +232,7 @@ internal ref struct Av1SymbolDecoder
: intraDirection; : intraDirection;
ref Av1SymbolReader r = ref this.reader; ref Av1SymbolReader r = ref this.reader;
int symbol = r.ReadSymbol(this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]); int symbol = r.ReadSymbol(this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]);
transformType = (Av1TransformType)ExtendedTransformIndicesInverse[(int)transformSetType][symbol]; transformType = (Av1TransformType)Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][symbol];
} }
return transformType; return transformType;
@ -311,7 +310,7 @@ internal ref struct Av1SymbolDecoder
if (plane == (int)Av1Plane.Y) if (plane == (int)Av1Plane.Y)
{ {
this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode); transformInfo.Type = this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode);
} }
transformInfo.Type = ComputeTransformType(planeType, modeInfo, isLossless, transformSize, transformInfo, useReducedTransformSet); transformInfo.Type = ComputeTransformType(planeType, modeInfo, isLossless, transformSize, transformInfo, useReducedTransformSet);

29
src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs

@ -11,15 +11,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
internal class Av1SymbolEncoder : IDisposable internal class Av1SymbolEncoder : IDisposable
{ {
private static readonly int[][] ExtendedTransformIndices = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
[3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0],
[7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6],
];
private readonly Av1Distribution tileIntraBlockCopy = Av1DefaultDistributions.IntraBlockCopy; private readonly Av1Distribution tileIntraBlockCopy = Av1DefaultDistributions.IntraBlockCopy;
private readonly Av1Distribution[] tilePartitionTypes = Av1DefaultDistributions.PartitionTypes; private readonly Av1Distribution[] tilePartitionTypes = Av1DefaultDistributions.PartitionTypes;
private readonly Av1Distribution[][] keyFrameYMode = Av1DefaultDistributions.KeyFrameYMode; private readonly Av1Distribution[][] keyFrameYMode = Av1DefaultDistributions.KeyFrameYMode;
@ -289,35 +280,35 @@ internal class Av1SymbolEncoder : IDisposable
Av1PredictionMode intraDirection) Av1PredictionMode intraDirection)
{ {
// bool isInter = mbmi->block_mi.use_intrabc || is_inter_mode(mbmi->block_mi.mode); // bool isInter = mbmi->block_mi.use_intrabc || is_inter_mode(mbmi->block_mi.mode);
ref Av1SymbolWriter w = ref this.writer; Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet);
if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0) if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0)
{ {
Av1TransformSize squareTransformSize = transformSize.GetSquareSize(); Av1TransformSize squareTransformSize = transformSize.GetSquareSize();
Guard.MustBeLessThanOrEqualTo((int)squareTransformSize, Av1Constants.ExtendedTransformCount, nameof(squareTransformSize)); Guard.MustBeLessThanOrEqualTo((int)squareTransformSize, Av1Constants.ExtendedTransformCount, nameof(squareTransformSize));
Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet); int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType);
int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet);
// eset == 0 should correspond to a set with only DCT_DCT and there // eset == 0 should correspond to a set with only DCT_DCT and there
// is no need to send the tx_type // is no need to send the tx_type
Guard.MustBeGreaterThan(extendedSet, 0, nameof(extendedSet)); Guard.MustBeGreaterThan(extendedSet, 0, nameof(extendedSet));
// assert(av1_ext_tx_used[tx_set_type][transformType]); // assert(av1_ext_tx_used[tx_set_type][transformType]);
Av1PredictionMode intraMode; Av1PredictionMode intraDirectionContext;
if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes) if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes)
{ {
intraMode = filterIntraMode.ToIntraDirection(); intraDirectionContext = filterIntraMode.ToIntraDirection();
} }
else else
{ {
intraMode = intraDirection; intraDirectionContext = intraDirection;
} }
Guard.MustBeLessThan((int)intraMode, 13, nameof(intraMode)); Guard.MustBeLessThan((int)intraDirectionContext, 13, nameof(intraDirectionContext));
Guard.MustBeLessThan((int)squareTransformSize, 4, nameof(squareTransformSize)); Guard.MustBeLessThan((int)squareTransformSize, 4, nameof(squareTransformSize));
ref Av1SymbolWriter w = ref this.writer;
w.WriteSymbol( w.WriteSymbol(
ExtendedTransformIndices[(int)transformSetType][(int)transformType], Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][(int)transformType],
this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]); this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraDirectionContext]);
} }
} }

38
tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs

@ -1,9 +1,11 @@
// Copyright (c) Six Labors. // Copyright (c) Six Labors.
// Licensed under the Six Labors Split License. // Licensed under the Six Labors Split License.
using Microsoft.Diagnostics.Symbols;
using SixLabors.ImageSharp.Formats.Heif.Av1; using SixLabors.ImageSharp.Formats.Heif.Av1;
using SixLabors.ImageSharp.Formats.Heif.Av1.Entropy; using SixLabors.ImageSharp.Formats.Heif.Av1.Entropy;
using SixLabors.ImageSharp.Formats.Heif.Av1.Tiling; using SixLabors.ImageSharp.Formats.Heif.Av1.Tiling;
using SixLabors.ImageSharp.Formats.Heif.Av1.Transform;
namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1; namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1;
@ -11,8 +13,8 @@ namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1;
public class Av1SymbolContextTests public class Av1SymbolContextTests
{ {
[Theory] [Theory]
[MemberData(nameof(GetCombinations))] [MemberData(nameof(GetLowLevelContextEndOfBlockData))]
public void TestAccuracy(int width, int height, int index) public void TestLowLevelContextEndOfBlockAccuracy(int width, int height, int index)
{ {
// Arrange // Arrange
Size size = new(width, height); Size size = new(width, height);
@ -28,7 +30,22 @@ public class Av1SymbolContextTests
Assert.Equal(expectedContext, actualContext); Assert.Equal(expectedContext, actualContext);
} }
public static TheoryData<int, int, int> GetCombinations() [Theory]
[MemberData(nameof(GetExtendedTransformIndicesData))]
public void RoundTripExtendedTransformIndices(int setType, int index)
{
// Arrange
Av1TransformSetType transformSetType = (Av1TransformSetType)setType;
// Act
int transformType = Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][index];
int actualIndex = Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][transformType];
// Assert
Assert.Equal(actualIndex, index);
}
public static TheoryData<int, int, int> GetLowLevelContextEndOfBlockData()
{ {
TheoryData<int, int, int> result = []; TheoryData<int, int, int> result = [];
for (int y = 1; y < 6; y++) for (int y = 1; y < 6; y++)
@ -46,6 +63,21 @@ public class Av1SymbolContextTests
return result; return result;
} }
public static TheoryData<int, int> GetExtendedTransformIndicesData()
{
TheoryData<int, int> result = [];
for (Av1TransformSetType setType = Av1TransformSetType.DctOnly; setType <= Av1TransformSetType.All16; setType++)
{
int count = Av1SymbolContextHelper.GetExtendedTransformTypeCount(setType);
for (int type = 1; type < count; type++)
{
result.Add((int)setType, type);
}
}
return result;
}
/// <summary> /// <summary>
/// SVT: get_lower_levels_ctx_eob /// SVT: get_lower_levels_ctx_eob
/// </summary> /// </summary>

Loading…
Cancel
Save