diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs index 07067974b7..9addf24e3e 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs @@ -66,23 +66,6 @@ internal static class Av1SymbolContextHelper return endOfBlock; } - internal static int GetBaseRangeContextEndOfBlock(Point pos, Av1TransformClass transformClass) - { - if (pos.X == 0 && pos.Y == 0) - { - return 0; - } - - if ((transformClass == Av1TransformClass.Class2D && pos.Y < 2 && pos.X < 2) || - (transformClass == Av1TransformClass.ClassHorizontal && pos.X == 0) || - (transformClass == Av1TransformClass.ClassVertical && pos.Y == 0)) - { - return 7; - } - - return 14; - } - /// /// SVT: get_lower_levels_ctx_eob /// @@ -108,27 +91,6 @@ internal static class Av1SymbolContextHelper return 3; } - /// - /// SVT: get_br_ctx_2d - /// - internal static int GetBaseRangeContext2d(Av1LevelBuffer levels, Point position) - { - DebugGuard.MustBeGreaterThan(position.X + position.Y, 0, nameof(position)); - Span row0 = levels.GetRow(position.Y); - Span row1 = levels.GetRow(position.Y + 1); - int mag = - Math.Min((int)row0[1], Av1Constants.MaxBaseRange) + - Math.Min((int)row1[0], Av1Constants.MaxBaseRange) + - Math.Min((int)row1[1], Av1Constants.MaxBaseRange); - mag = Math.Min((mag + 1) >> 1, 6); - if ((position.Y | position.X) < 2) - { - return mag + 7; - } - - return mag + 14; - } - /// /// SVT: get_lower_levels_ctx_2d /// @@ -150,6 +112,23 @@ internal static class Av1SymbolContextHelper return ctx + Av1NzMap.GetNzMapContext(transformSize, index); } + internal static int GetBaseRangeContextEndOfBlock(Point pos, Av1TransformClass transformClass) + { + if (pos.X == 0 && pos.Y == 0) + { + return 0; + } + + if ((transformClass == Av1TransformClass.Class2D && pos.Y < 2 && pos.X < 2) || + (transformClass == Av1TransformClass.ClassHorizontal && pos.X == 0) || + (transformClass == Av1TransformClass.ClassVertical && pos.Y == 0)) + { + return 7; + } + + return 14; + } + /// /// SVT: get_br_ctx /// @@ -210,6 +189,27 @@ internal static class Av1SymbolContextHelper return mag + 14; } + /// + /// SVT: get_br_ctx_2d + /// + internal static int GetBaseRangeContext2d(Av1LevelBuffer levels, Point position) + { + DebugGuard.MustBeGreaterThan(position.X + position.Y, 0, nameof(position)); + Span row0 = levels.GetRow(position.Y); + Span row1 = levels.GetRow(position.Y + 1); + int mag = + Math.Min((int)row0[1], Av1Constants.MaxBaseRange) + + Math.Min((int)row1[0], Av1Constants.MaxBaseRange) + + Math.Min((int)row1[1], Av1Constants.MaxBaseRange); + mag = Math.Min((mag + 1) >> 1, 6); + if (position.Y < 2 && position.X < 2) + { + return mag + 7; + } + + return mag + 14; + } + internal static int GetLowerLevelsContext(Av1LevelBuffer levels, Point position, Av1TransformSize transformSize, Av1TransformClass transformClass) { int stats = Av1NzMap.GetNzMagnitude(levels, position, transformClass); diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs index 982c089dc5..e1ed5c1c28 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs @@ -332,8 +332,6 @@ internal ref struct Av1SymbolDecoder } } - coefficientBuffer[0] = endOfBlock; - DebugGuard.MustBeGreaterThan(scan.Length, 0, nameof(scan)); culLevel = this.ReadCoefficientsSign(coefficientBuffer, endOfBlock, scan, levels, transformBlockContext.DcSignContext, planeType); UpdateCoefficientContext(modeInfo, aboveContexts, leftContexts, blocksWide, blocksHigh, transformSize, blockPosition, aboveOffset, leftOffset, culLevel, modeBlocksToRightEdge, modeBlocksToBottomEdge); @@ -355,6 +353,7 @@ internal ref struct Av1SymbolDecoder { Av1Math.SetBit(ref endOfBlockExtra, endOfBlockShift - 1); } + for (int j = 1; j < endOfBlockShift; j++) { if (this.ReadLiteral(1) != 0) @@ -373,12 +372,13 @@ internal ref struct Av1SymbolDecoder Point position = levels.GetPosition(scan[i]); int coefficientContext = Av1SymbolContextHelper.GetLowerLevelContextEndOfBlock(levels, position); int level = this.ReadBaseEndOfBlock(transformSizeContext, planeType, coefficientContext); + Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32); if (level > Av1Constants.BaseLevelsCount) { int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContextEndOfBlock(position, transformClass); for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1) { - int coefficientBaseRange = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext); + int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext); level += coefficientBaseRange; if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1) { @@ -387,11 +387,12 @@ internal ref struct Av1SymbolDecoder } } - levels.GetRow(position)[0] = (byte)level; + levels.GetRow(position)[position.X] = (byte)level; } public void ReadCoefficientsReverse2d(Av1TransformSize transformSize, int startScanIndex, int endScanIndex, ReadOnlySpan scan, Av1LevelBuffer levels, Av1TransformSize transformSizeContext, Av1PlaneType planeType) { + Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32); for (int c = endScanIndex; c >= startScanIndex; --c) { Point position = levels.GetPosition(scan[c]); @@ -402,7 +403,7 @@ internal ref struct Av1SymbolDecoder int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext2d(levels, position); for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1) { - int coefficientBaseRange = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext); + int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext); level += coefficientBaseRange; if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1) { @@ -411,12 +412,13 @@ internal ref struct Av1SymbolDecoder } } - levels.GetRow(position)[0] = (byte)level; + levels.GetRow(position)[position.X] = (byte)level; } } public void ReadCoefficientsReverse(Av1TransformSize transformSize, Av1TransformClass transformClass, int startScanIndex, int endScanIndex, ReadOnlySpan scan, Av1LevelBuffer levels, Av1TransformSize transformSizeContext, Av1PlaneType planeType) { + Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32); for (int c = endScanIndex; c >= startScanIndex; --c) { int pos = scan[c]; @@ -428,16 +430,16 @@ internal ref struct Av1SymbolDecoder int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext(levels, position, transformClass); for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1) { - int k = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext); - level += k; - if (k < Av1Constants.BaseRangeSizeMinus1) + int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext); + level += coefficientBaseRange; + if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1) { break; } } } - levels.GetRow(position)[0] = (byte)level; + levels.GetRow(position)[position.X] = (byte)level; } } @@ -451,7 +453,7 @@ internal ref struct Av1SymbolDecoder { int sign = 0; Point position = levels.GetPosition(scan[c]); - int level = levels.GetRow(position)[0]; + int level = levels.GetRow(position)[position.X]; if (level != 0) { maxScanLine = Math.Max(maxScanLine, scan[c]); @@ -504,8 +506,7 @@ internal ref struct Av1SymbolDecoder private int ReadCoefficientsBaseRange(Av1TransformSize transformSizeContext, Av1PlaneType planeType, int baseRangeContext) { ref Av1SymbolReader r = ref this.reader; - int transformContext = Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32); - return r.ReadSymbol(this.coefficientsBaseRange[transformContext][(int)planeType][baseRangeContext]); + return r.ReadSymbol(this.coefficientsBaseRange[(int)transformSizeContext][(int)planeType][baseRangeContext]); } private int ReadDcSign(Av1PlaneType planeType, int dcSignContext) diff --git a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs index db4782e59e..08367e4e3d 100644 --- a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs +++ b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs @@ -100,7 +100,41 @@ public class Av1CoefficientsEntropyTests // Assert Assert.Equal(endOfBlock, actuals[0]); + } + + [Fact] + public void RoundTripFullCoefficients() + { + // Assign + const ushort endOfBlock = 16; + const Av1BlockSize blockSize = Av1BlockSize.Block4x4; + const Av1TransformSize transformSize = Av1TransformSize.Size4x4; + const Av1TransformType transformType = Av1TransformType.Identity; + const Av1PredictionMode intraDirection = Av1PredictionMode.DC; + const Av1ComponentType componentType = Av1ComponentType.Luminance; + const Av1FilterIntraMode filterIntraMode = Av1FilterIntraMode.DC; + Av1BlockModeInfo modeInfo = new(Av1Constants.MaxPlanes, blockSize, new Point(0, 0)); + Av1TransformInfo transformInfo = new(transformSize, 0, 0); + int[] aboveContexts = new int[1]; + int[] leftContexts = new int[1]; + Av1TransformBlockContext transformBlockContext = new(); + Configuration configuration = Configuration.Default; + Av1SymbolEncoder encoder = new(configuration, 100 / 8, BaseQIndex); + Span coefficientsBuffer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + Span actuals = new int[16 + 1]; + + // Act + encoder.WriteCoefficients(transformSize, transformType, intraDirection, coefficientsBuffer, componentType, transformBlockContext, endOfBlock, true, filterIntraMode); + + using IMemoryOwner encoded = encoder.Exit(); + + Av1SymbolDecoder decoder = new(Configuration.Default, encoded.GetSpan(), BaseQIndex); + Av1SymbolReader reader = new(encoded.GetSpan()); + int plane = Math.Min((int)componentType, 1); + decoder.ReadCoefficients(modeInfo, new Point(0, 0), aboveContexts, leftContexts, 0, 0, plane, 1, 1, transformBlockContext, transformSize, false, true, transformInfo, 0, 0, actuals); - // Assert.Equal(coefficientsBuffer[..endOfBlock], actuals[1..(endOfBlock + 1)]); + // Assert + Assert.Equal(endOfBlock, actuals[0]); + Assert.Equal(coefficientsBuffer[..endOfBlock], actuals[1..(endOfBlock + 1)]); } }