diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs index d257c3a8b5..9d6b15e6f4 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs @@ -19,6 +19,15 @@ internal static class Av1SymbolContextHelper [7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6], ]; + public static readonly int[][] ExtendedTransformIndicesInverse = [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0], + [9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8], + ]; + public static readonly int[] EndOfBlockOffsetBits = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; public static readonly int[] EndOfBlockGroupStart = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513]; private static readonly int[] TransformCountInSet = [1, 2, 5, 7, 12, 16]; @@ -216,6 +225,9 @@ internal static class Av1SymbolContextHelper return Av1NzMap.GetNzMapContextFromStats(stats, levels, position, transformSize, transformClass); } + /// + /// SVT: get_ext_tx_set_type + /// internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool useReducedSet) { Av1TransformSize squareUpSize = transformSize.GetSquareUpSize(); @@ -285,52 +297,14 @@ internal static class Av1SymbolContextHelper } /// - /// SVT: get_ext_tx_set_type + /// SVT: get_ext_tx_types /// - internal static Av1TransformSetType GetExtendedTransformSetType(Av1TransformSize transformSize, bool isInter, bool useReducedTransformSet) - { - Av1TransformSize transformSizeSquareUp = transformSize.GetSquareUpSize(); - - if (transformSizeSquareUp > Av1TransformSize.Size32x32) - { - return Av1TransformSetType.DctOnly; - } - - if (transformSizeSquareUp == Av1TransformSize.Size32x32) - { - return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.DctOnly; - } - - if (useReducedTransformSet) - { - return isInter ? Av1TransformSetType.DctIdentity : Av1TransformSetType.Dtt4Identity; - } - - Av1TransformSize transformSizeSquare = transformSize.GetSquareSize(); - if (isInter) - { - return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt9Identity1dDct : Av1TransformSetType.All16; - } - else - { - return transformSizeSquare == Av1TransformSize.Size16x16 ? Av1TransformSetType.Dtt4Identity : Av1TransformSetType.Dtt4Identity1dDct; - } - } - - internal static int GetExtendedTransformTypeCount(Av1TransformSize transformSize, bool useReducedTransformSet) - { - int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet); - return TransformCountInSet[setType]; - } + internal static int GetExtendedTransformTypeCount(Av1TransformSetType setType) => TransformCountInSet[(int)setType]; /// /// SVT: get_ext_tx_set /// - internal static int GetExtendedTransformSet(Av1TransformSize transformSize, bool useReducedTransformSet) - { - int setType = (int)GetExtendedTransformSetType(transformSize, useReducedTransformSet); - return ExtendedTransformSetToIndex[setType]; - } + internal static int GetExtendedTransformSet(Av1TransformSetType setType) => ExtendedTransformSetToIndex[(int)setType]; /// /// SVT: set_dc_sign diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs index e74b905c2a..f265156a76 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs @@ -9,15 +9,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy; internal ref struct Av1SymbolDecoder { - private static readonly int[][] ExtendedTransformIndicesInverse = [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [9, 0, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [9, 0, 10, 11, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [9, 10, 11, 0, 1, 2, 4, 5, 3, 6, 7, 8, 0, 0, 0, 0], - [9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 4, 5, 3, 6, 7, 8], - ]; - private static readonly int[] IntraModeContext = [0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0]; private static readonly int[] AlphaVContexts = [-1, 0, 3, -1, 1, 4, -1, 2, 5]; @@ -199,6 +190,9 @@ internal ref struct Av1SymbolDecoder return transformSize; } + /// + /// SVT: parse_transform_type + /// public Av1TransformType ReadTransformType( Av1TransformSize transformSize, bool useReducedTransformSet, @@ -216,12 +210,17 @@ internal ref struct Av1SymbolDecoder return; */ + if (baseQIndex == 0) + { + return transformType; + } + // Ignoring INTER blocks here, as these should not end up here. // int inter_block = is_inter_block_dec(mbmi); Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet); - if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0) + if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0) { - int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet); + int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType); // eset == 0 should correspond to a set with only DCT_DCT and // there is no need to read the tx_type @@ -233,7 +232,7 @@ internal ref struct Av1SymbolDecoder : intraDirection; ref Av1SymbolReader r = ref this.reader; int symbol = r.ReadSymbol(this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]); - transformType = (Av1TransformType)ExtendedTransformIndicesInverse[(int)transformSetType][symbol]; + transformType = (Av1TransformType)Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][symbol]; } return transformType; @@ -311,7 +310,7 @@ internal ref struct Av1SymbolDecoder if (plane == (int)Av1Plane.Y) { - this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode); + transformInfo.Type = this.ReadTransformType(transformSize, useReducedTransformSet, modeInfo.FilterIntraModeInfo.UseFilterIntra, this.baseQIndex, modeInfo.FilterIntraModeInfo.Mode, modeInfo.YMode); } transformInfo.Type = ComputeTransformType(planeType, modeInfo, isLossless, transformSize, transformInfo, useReducedTransformSet); diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs index 72db8b02e7..1c9ef131c2 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolEncoder.cs @@ -11,15 +11,6 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Entropy; internal class Av1SymbolEncoder : IDisposable { - private static readonly int[][] ExtendedTransformIndices = [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 5, 6, 4, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0], - [3, 4, 5, 8, 6, 7, 9, 10, 11, 0, 1, 2, 0, 0, 0, 0], - [7, 8, 9, 12, 10, 11, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6], - ]; - private readonly Av1Distribution tileIntraBlockCopy = Av1DefaultDistributions.IntraBlockCopy; private readonly Av1Distribution[] tilePartitionTypes = Av1DefaultDistributions.PartitionTypes; private readonly Av1Distribution[][] keyFrameYMode = Av1DefaultDistributions.KeyFrameYMode; @@ -289,35 +280,35 @@ internal class Av1SymbolEncoder : IDisposable Av1PredictionMode intraDirection) { // bool isInter = mbmi->block_mi.use_intrabc || is_inter_mode(mbmi->block_mi.mode); - ref Av1SymbolWriter w = ref this.writer; - if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSize, useReducedTransformSet) > 1 && baseQIndex > 0) + Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet); + if (Av1SymbolContextHelper.GetExtendedTransformTypeCount(transformSetType) > 1 && baseQIndex > 0) { Av1TransformSize squareTransformSize = transformSize.GetSquareSize(); Guard.MustBeLessThanOrEqualTo((int)squareTransformSize, Av1Constants.ExtendedTransformCount, nameof(squareTransformSize)); - Av1TransformSetType transformSetType = Av1SymbolContextHelper.GetExtendedTransformSetType(transformSize, useReducedTransformSet); - int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSize, useReducedTransformSet); + int extendedSet = Av1SymbolContextHelper.GetExtendedTransformSet(transformSetType); // eset == 0 should correspond to a set with only DCT_DCT and there // is no need to send the tx_type Guard.MustBeGreaterThan(extendedSet, 0, nameof(extendedSet)); // assert(av1_ext_tx_used[tx_set_type][transformType]); - Av1PredictionMode intraMode; + Av1PredictionMode intraDirectionContext; if (filterIntraMode != Av1FilterIntraMode.AllFilterIntraModes) { - intraMode = filterIntraMode.ToIntraDirection(); + intraDirectionContext = filterIntraMode.ToIntraDirection(); } else { - intraMode = intraDirection; + intraDirectionContext = intraDirection; } - Guard.MustBeLessThan((int)intraMode, 13, nameof(intraMode)); + Guard.MustBeLessThan((int)intraDirectionContext, 13, nameof(intraDirectionContext)); Guard.MustBeLessThan((int)squareTransformSize, 4, nameof(squareTransformSize)); + ref Av1SymbolWriter w = ref this.writer; w.WriteSymbol( - ExtendedTransformIndices[(int)transformSetType][(int)transformType], - this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraMode]); + Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][(int)transformType], + this.intraExtendedTransform[extendedSet][(int)squareTransformSize][(int)intraDirectionContext]); } } diff --git a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs index 01c88913e9..f21aa57159 100644 --- a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs +++ b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1SymbolContextTests.cs @@ -1,9 +1,11 @@ // Copyright (c) Six Labors. // Licensed under the Six Labors Split License. +using Microsoft.Diagnostics.Symbols; using SixLabors.ImageSharp.Formats.Heif.Av1; using SixLabors.ImageSharp.Formats.Heif.Av1.Entropy; using SixLabors.ImageSharp.Formats.Heif.Av1.Tiling; +using SixLabors.ImageSharp.Formats.Heif.Av1.Transform; namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1; @@ -11,8 +13,8 @@ namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1; public class Av1SymbolContextTests { [Theory] - [MemberData(nameof(GetCombinations))] - public void TestAccuracy(int width, int height, int index) + [MemberData(nameof(GetLowLevelContextEndOfBlockData))] + public void TestLowLevelContextEndOfBlockAccuracy(int width, int height, int index) { // Arrange Size size = new(width, height); @@ -28,7 +30,22 @@ public class Av1SymbolContextTests Assert.Equal(expectedContext, actualContext); } - public static TheoryData GetCombinations() + [Theory] + [MemberData(nameof(GetExtendedTransformIndicesData))] + public void RoundTripExtendedTransformIndices(int setType, int index) + { + // Arrange + Av1TransformSetType transformSetType = (Av1TransformSetType)setType; + + // Act + int transformType = Av1SymbolContextHelper.ExtendedTransformIndicesInverse[(int)transformSetType][index]; + int actualIndex = Av1SymbolContextHelper.ExtendedTransformIndices[(int)transformSetType][transformType]; + + // Assert + Assert.Equal(actualIndex, index); + } + + public static TheoryData GetLowLevelContextEndOfBlockData() { TheoryData result = []; for (int y = 1; y < 6; y++) @@ -46,6 +63,21 @@ public class Av1SymbolContextTests return result; } + public static TheoryData GetExtendedTransformIndicesData() + { + TheoryData result = []; + for (Av1TransformSetType setType = Av1TransformSetType.DctOnly; setType <= Av1TransformSetType.All16; setType++) + { + int count = Av1SymbolContextHelper.GetExtendedTransformTypeCount(setType); + for (int type = 1; type < count; type++) + { + result.Add((int)setType, type); + } + } + + return result; + } + /// /// SVT: get_lower_levels_ctx_eob ///