diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs
index 07067974b7..9addf24e3e 100644
--- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs
+++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolContextHelper.cs
@@ -66,23 +66,6 @@ internal static class Av1SymbolContextHelper
return endOfBlock;
}
- internal static int GetBaseRangeContextEndOfBlock(Point pos, Av1TransformClass transformClass)
- {
- if (pos.X == 0 && pos.Y == 0)
- {
- return 0;
- }
-
- if ((transformClass == Av1TransformClass.Class2D && pos.Y < 2 && pos.X < 2) ||
- (transformClass == Av1TransformClass.ClassHorizontal && pos.X == 0) ||
- (transformClass == Av1TransformClass.ClassVertical && pos.Y == 0))
- {
- return 7;
- }
-
- return 14;
- }
-
///
/// SVT: get_lower_levels_ctx_eob
///
@@ -108,27 +91,6 @@ internal static class Av1SymbolContextHelper
return 3;
}
- ///
- /// SVT: get_br_ctx_2d
- ///
- internal static int GetBaseRangeContext2d(Av1LevelBuffer levels, Point position)
- {
- DebugGuard.MustBeGreaterThan(position.X + position.Y, 0, nameof(position));
- Span row0 = levels.GetRow(position.Y);
- Span row1 = levels.GetRow(position.Y + 1);
- int mag =
- Math.Min((int)row0[1], Av1Constants.MaxBaseRange) +
- Math.Min((int)row1[0], Av1Constants.MaxBaseRange) +
- Math.Min((int)row1[1], Av1Constants.MaxBaseRange);
- mag = Math.Min((mag + 1) >> 1, 6);
- if ((position.Y | position.X) < 2)
- {
- return mag + 7;
- }
-
- return mag + 14;
- }
-
///
/// SVT: get_lower_levels_ctx_2d
///
@@ -150,6 +112,23 @@ internal static class Av1SymbolContextHelper
return ctx + Av1NzMap.GetNzMapContext(transformSize, index);
}
+ internal static int GetBaseRangeContextEndOfBlock(Point pos, Av1TransformClass transformClass)
+ {
+ if (pos.X == 0 && pos.Y == 0)
+ {
+ return 0;
+ }
+
+ if ((transformClass == Av1TransformClass.Class2D && pos.Y < 2 && pos.X < 2) ||
+ (transformClass == Av1TransformClass.ClassHorizontal && pos.X == 0) ||
+ (transformClass == Av1TransformClass.ClassVertical && pos.Y == 0))
+ {
+ return 7;
+ }
+
+ return 14;
+ }
+
///
/// SVT: get_br_ctx
///
@@ -210,6 +189,27 @@ internal static class Av1SymbolContextHelper
return mag + 14;
}
+ ///
+ /// SVT: get_br_ctx_2d
+ ///
+ internal static int GetBaseRangeContext2d(Av1LevelBuffer levels, Point position)
+ {
+ DebugGuard.MustBeGreaterThan(position.X + position.Y, 0, nameof(position));
+ Span row0 = levels.GetRow(position.Y);
+ Span row1 = levels.GetRow(position.Y + 1);
+ int mag =
+ Math.Min((int)row0[1], Av1Constants.MaxBaseRange) +
+ Math.Min((int)row1[0], Av1Constants.MaxBaseRange) +
+ Math.Min((int)row1[1], Av1Constants.MaxBaseRange);
+ mag = Math.Min((mag + 1) >> 1, 6);
+ if (position.Y < 2 && position.X < 2)
+ {
+ return mag + 7;
+ }
+
+ return mag + 14;
+ }
+
internal static int GetLowerLevelsContext(Av1LevelBuffer levels, Point position, Av1TransformSize transformSize, Av1TransformClass transformClass)
{
int stats = Av1NzMap.GetNzMagnitude(levels, position, transformClass);
diff --git a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
index 982c089dc5..e1ed5c1c28 100644
--- a/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
+++ b/src/ImageSharp/Formats/Heif/Av1/Entropy/Av1SymbolDecoder.cs
@@ -332,8 +332,6 @@ internal ref struct Av1SymbolDecoder
}
}
- coefficientBuffer[0] = endOfBlock;
-
DebugGuard.MustBeGreaterThan(scan.Length, 0, nameof(scan));
culLevel = this.ReadCoefficientsSign(coefficientBuffer, endOfBlock, scan, levels, transformBlockContext.DcSignContext, planeType);
UpdateCoefficientContext(modeInfo, aboveContexts, leftContexts, blocksWide, blocksHigh, transformSize, blockPosition, aboveOffset, leftOffset, culLevel, modeBlocksToRightEdge, modeBlocksToBottomEdge);
@@ -355,6 +353,7 @@ internal ref struct Av1SymbolDecoder
{
Av1Math.SetBit(ref endOfBlockExtra, endOfBlockShift - 1);
}
+
for (int j = 1; j < endOfBlockShift; j++)
{
if (this.ReadLiteral(1) != 0)
@@ -373,12 +372,13 @@ internal ref struct Av1SymbolDecoder
Point position = levels.GetPosition(scan[i]);
int coefficientContext = Av1SymbolContextHelper.GetLowerLevelContextEndOfBlock(levels, position);
int level = this.ReadBaseEndOfBlock(transformSizeContext, planeType, coefficientContext);
+ Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32);
if (level > Av1Constants.BaseLevelsCount)
{
int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContextEndOfBlock(position, transformClass);
for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1)
{
- int coefficientBaseRange = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext);
+ int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext);
level += coefficientBaseRange;
if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1)
{
@@ -387,11 +387,12 @@ internal ref struct Av1SymbolDecoder
}
}
- levels.GetRow(position)[0] = (byte)level;
+ levels.GetRow(position)[position.X] = (byte)level;
}
public void ReadCoefficientsReverse2d(Av1TransformSize transformSize, int startScanIndex, int endScanIndex, ReadOnlySpan scan, Av1LevelBuffer levels, Av1TransformSize transformSizeContext, Av1PlaneType planeType)
{
+ Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32);
for (int c = endScanIndex; c >= startScanIndex; --c)
{
Point position = levels.GetPosition(scan[c]);
@@ -402,7 +403,7 @@ internal ref struct Av1SymbolDecoder
int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext2d(levels, position);
for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1)
{
- int coefficientBaseRange = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext);
+ int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext);
level += coefficientBaseRange;
if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1)
{
@@ -411,12 +412,13 @@ internal ref struct Av1SymbolDecoder
}
}
- levels.GetRow(position)[0] = (byte)level;
+ levels.GetRow(position)[position.X] = (byte)level;
}
}
public void ReadCoefficientsReverse(Av1TransformSize transformSize, Av1TransformClass transformClass, int startScanIndex, int endScanIndex, ReadOnlySpan scan, Av1LevelBuffer levels, Av1TransformSize transformSizeContext, Av1PlaneType planeType)
{
+ Av1TransformSize limitedTransformSizeContext = (Av1TransformSize)Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32);
for (int c = endScanIndex; c >= startScanIndex; --c)
{
int pos = scan[c];
@@ -428,16 +430,16 @@ internal ref struct Av1SymbolDecoder
int baseRangeContext = Av1SymbolContextHelper.GetBaseRangeContext(levels, position, transformClass);
for (int idx = 0; idx < Av1Constants.CoefficientBaseRange; idx += Av1Constants.BaseRangeSizeMinus1)
{
- int k = this.ReadCoefficientsBaseRange(transformSizeContext, planeType, baseRangeContext);
- level += k;
- if (k < Av1Constants.BaseRangeSizeMinus1)
+ int coefficientBaseRange = this.ReadCoefficientsBaseRange(limitedTransformSizeContext, planeType, baseRangeContext);
+ level += coefficientBaseRange;
+ if (coefficientBaseRange < Av1Constants.BaseRangeSizeMinus1)
{
break;
}
}
}
- levels.GetRow(position)[0] = (byte)level;
+ levels.GetRow(position)[position.X] = (byte)level;
}
}
@@ -451,7 +453,7 @@ internal ref struct Av1SymbolDecoder
{
int sign = 0;
Point position = levels.GetPosition(scan[c]);
- int level = levels.GetRow(position)[0];
+ int level = levels.GetRow(position)[position.X];
if (level != 0)
{
maxScanLine = Math.Max(maxScanLine, scan[c]);
@@ -504,8 +506,7 @@ internal ref struct Av1SymbolDecoder
private int ReadCoefficientsBaseRange(Av1TransformSize transformSizeContext, Av1PlaneType planeType, int baseRangeContext)
{
ref Av1SymbolReader r = ref this.reader;
- int transformContext = Math.Min((int)transformSizeContext, (int)Av1TransformSize.Size32x32);
- return r.ReadSymbol(this.coefficientsBaseRange[transformContext][(int)planeType][baseRangeContext]);
+ return r.ReadSymbol(this.coefficientsBaseRange[(int)transformSizeContext][(int)planeType][baseRangeContext]);
}
private int ReadDcSign(Av1PlaneType planeType, int dcSignContext)
diff --git a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs
index db4782e59e..08367e4e3d 100644
--- a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs
+++ b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1CoefficientsEntropyTests.cs
@@ -100,7 +100,41 @@ public class Av1CoefficientsEntropyTests
// Assert
Assert.Equal(endOfBlock, actuals[0]);
+ }
+
+ [Fact]
+ public void RoundTripFullCoefficients()
+ {
+ // Assign
+ const ushort endOfBlock = 16;
+ const Av1BlockSize blockSize = Av1BlockSize.Block4x4;
+ const Av1TransformSize transformSize = Av1TransformSize.Size4x4;
+ const Av1TransformType transformType = Av1TransformType.Identity;
+ const Av1PredictionMode intraDirection = Av1PredictionMode.DC;
+ const Av1ComponentType componentType = Av1ComponentType.Luminance;
+ const Av1FilterIntraMode filterIntraMode = Av1FilterIntraMode.DC;
+ Av1BlockModeInfo modeInfo = new(Av1Constants.MaxPlanes, blockSize, new Point(0, 0));
+ Av1TransformInfo transformInfo = new(transformSize, 0, 0);
+ int[] aboveContexts = new int[1];
+ int[] leftContexts = new int[1];
+ Av1TransformBlockContext transformBlockContext = new();
+ Configuration configuration = Configuration.Default;
+ Av1SymbolEncoder encoder = new(configuration, 100 / 8, BaseQIndex);
+ Span coefficientsBuffer = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
+ Span actuals = new int[16 + 1];
+
+ // Act
+ encoder.WriteCoefficients(transformSize, transformType, intraDirection, coefficientsBuffer, componentType, transformBlockContext, endOfBlock, true, filterIntraMode);
+
+ using IMemoryOwner encoded = encoder.Exit();
+
+ Av1SymbolDecoder decoder = new(Configuration.Default, encoded.GetSpan(), BaseQIndex);
+ Av1SymbolReader reader = new(encoded.GetSpan());
+ int plane = Math.Min((int)componentType, 1);
+ decoder.ReadCoefficients(modeInfo, new Point(0, 0), aboveContexts, leftContexts, 0, 0, plane, 1, 1, transformBlockContext, transformSize, false, true, transformInfo, 0, 0, actuals);
- // Assert.Equal(coefficientsBuffer[..endOfBlock], actuals[1..(endOfBlock + 1)]);
+ // Assert
+ Assert.Equal(endOfBlock, actuals[0]);
+ Assert.Equal(coefficientsBuffer[..endOfBlock], actuals[1..(endOfBlock + 1)]);
}
}