Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions pyiceberg/utils/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ def bytes_required(value: int | Decimal) -> int:
int: the minimum number of bytes needed to serialize the value.
"""
if isinstance(value, int):
return (value.bit_length() + 8) // 8
unscaled = value
elif isinstance(value, Decimal):
return (decimal_to_unscaled(value).bit_length() + 8) // 8

raise ValueError(f"Unsupported value: {value}")
unscaled = decimal_to_unscaled(value)
else:
raise ValueError(f"Unsupported value: {value}")

# int.bit_length() measures only the magnitude, so (bit_length() + 8) // 8 overcounts negatives equal to
# -2**(8k-1) (e.g. -128, -32768) by one byte; using (unscaled + 1) for negatives yields the true
# two's-complement minimum, matching the Iceberg spec.
n_bits = unscaled.bit_length() if unscaled >= 0 else (unscaled + 1).bit_length()
return (n_bits + 8) // 8


def decimal_to_bytes(value: Decimal, byte_length: int | None = None) -> bytes:
Expand Down
23 changes: 22 additions & 1 deletion tests/utils/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

from pyiceberg.utils.decimal import decimal_required_bytes, decimal_to_bytes
from pyiceberg.utils.decimal import bytes_required, decimal_required_bytes, decimal_to_bytes


def test_decimal_required_bytes() -> None:
Expand All @@ -42,8 +42,29 @@ def test_decimal_required_bytes() -> None:
assert "(0, 40]" in str(exc_info.value)


def test_bytes_required() -> None:
    """Check bytes_required at byte-width boundaries for int and Decimal inputs."""
    expectations = [
        # Zero and the values on either side of the one-byte limits that the fix leaves unchanged.
        (0, 1),
        (127, 1),
        (128, 2),
        (-127, 1),
        (-129, 2),
        # -2**(8N-1) is the most-negative value that fits in N bytes, so it must need exactly N bytes,
        # not N + 1. The previous (value.bit_length() + 8) // 8 formula overcounted these cases.
        (-128, 1),
        (-32768, 2),
        (-8388608, 3),
        # The same boundary cases when the unscaled value comes from a Decimal.
        (Decimal("-1.28"), 1),
        (Decimal("-327.68"), 2),
    ]
    for value, expected in expectations:
        assert bytes_required(value) == expected


def test_decimal_to_bytes() -> None:
# Check the boundary between 2 and 3 bytes.
# Two bytes can hold a minimum of -32,768 and a maximum of 32,767 (inclusive).
assert decimal_to_bytes(Decimal("32767.")) == b"\x7f\xff"
assert decimal_to_bytes(Decimal("32768.")) == b"\x00\x80\x00"
# The most-negative value for a given width must serialize to the minimum number of bytes
# (matching the Iceberg spec / Java BigInteger.toByteArray), not one byte longer.
assert decimal_to_bytes(Decimal("-1.28")) == b"\x80"
assert decimal_to_bytes(Decimal("-327.68")) == b"\x80\x00"