diff --git a/pyiceberg/utils/decimal.py b/pyiceberg/utils/decimal.py index 5ef82640d9..83b9bda498 100644 --- a/pyiceberg/utils/decimal.py +++ b/pyiceberg/utils/decimal.py @@ -58,11 +58,16 @@ def bytes_required(value: int | Decimal) -> int: int: the minimum number of bytes needed to serialize the value. """ if isinstance(value, int): - return (value.bit_length() + 8) // 8 + unscaled = value elif isinstance(value, Decimal): - return (decimal_to_unscaled(value).bit_length() + 8) // 8 - - raise ValueError(f"Unsupported value: {value}") + unscaled = decimal_to_unscaled(value) + else: + raise ValueError(f"Unsupported value: {value}") + + # bit_length() overcounts negatives equal to -2**(8k-1) (e.g. -128, -32768) by one byte; + # using (unscaled + 1) for negatives yields the true minimum, matching the Iceberg spec. + n_bits = unscaled.bit_length() if unscaled >= 0 else (unscaled + 1).bit_length() + return (n_bits + 8) // 8 def decimal_to_bytes(value: Decimal, byte_length: int | None = None) -> bytes: diff --git a/tests/utils/test_decimal.py b/tests/utils/test_decimal.py index 3e67bf691a..c08b3e6414 100644 --- a/tests/utils/test_decimal.py +++ b/tests/utils/test_decimal.py @@ -18,7 +18,7 @@ import pytest -from pyiceberg.utils.decimal import decimal_required_bytes, decimal_to_bytes +from pyiceberg.utils.decimal import bytes_required, decimal_required_bytes, decimal_to_bytes def test_decimal_required_bytes() -> None: @@ -42,8 +42,29 @@ def test_decimal_required_bytes() -> None: assert "(0, 40]" in str(exc_info.value) +def test_bytes_required() -> None: + # Positive values and the negative values just past a byte boundary are unaffected. + assert bytes_required(0) == 1 + assert bytes_required(127) == 1 + assert bytes_required(128) == 2 + assert bytes_required(-127) == 1 + assert bytes_required(-129) == 2 + # The most-negative value that fits in N bytes (-2**(8N-1)) must require exactly N bytes, + # not N + 1. These are the cases the previous (value.bit_length() + 8) // 8 formula overcounted. + assert bytes_required(-128) == 1 + assert bytes_required(-32768) == 2 + assert bytes_required(-8388608) == 3 + # The same applies when the unscaled value comes from a Decimal. + assert bytes_required(Decimal("-1.28")) == 1 + assert bytes_required(Decimal("-327.68")) == 2 + + def test_decimal_to_bytes() -> None: # Check the boundary between 2 and 3 bytes. # 2 bytes has a minimum of -32,768 and a maximum value of 32,767 (inclusive). assert decimal_to_bytes(Decimal("32767.")) == b"\x7f\xff" assert decimal_to_bytes(Decimal("32768.")) == b"\x00\x80\x00" + # The most-negative value for a given width must serialize to the minimum number of bytes + # (matching the Iceberg spec / Java BigInteger.toByteArray), not one byte longer. + assert decimal_to_bytes(Decimal("-1.28")) == b"\x80" + assert decimal_to_bytes(Decimal("-327.68")) == b"\x80\x00"