Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
Expand All @@ -33,19 +34,35 @@
def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

if unique_keys.num_rows == 0:
return AlwaysFalse()

if len(join_cols) == 1:
return In(join_cols[0], unique_keys[0].to_pylist())
else:
filters = [
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
]

if len(filters) == 0:
return AlwaysFalse()
elif len(filters) == 1:
return filters[0]
else:
return Or(*filters)
return In(join_cols[0], unique_keys.column(join_cols[0]).to_pylist())

# Fold the column that leaves the fewest distinct "prefix" combinations into
# an In(); this minimises the disjunct count regardless of column order.
in_col = min(
join_cols,
key=lambda cand: unique_keys.select([c for c in join_cols if c != cand])
.group_by([c for c in join_cols if c != cand])
.aggregate([])
.num_rows,
)
prefix_cols = [c for c in join_cols if c != in_col]

grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")])
in_values_col = f"{in_col}_list"

disjuncts: list[BooleanExpression] = []
for row in grouped.to_pylist():
eqs = [EqualTo(c, row[c]) for c in prefix_cols]
prefix_pred = functools.reduce(operator.and_, eqs) if len(eqs) > 1 else eqs[0]
disjuncts.append(And(prefix_pred, In(in_col, row[in_values_col])))

if len(disjuncts) == 1:
return disjuncts[0]
return Or(*disjuncts)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down
75 changes: 72 additions & 3 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
from pathlib import PosixPath

import pyarrow as pa
Expand All @@ -25,11 +26,13 @@
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.expressions.visitors import expression_evaluator
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import Table, UpsertResult
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.typedef import Record
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
from tests.catalog.test_base import InMemoryCatalog

Expand Down Expand Up @@ -437,10 +440,76 @@ def test_create_match_filter_single_condition() -> None:
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
table = pa.Table.from_pylist(data, schema=schema)
expr = create_match_filter(table, ["order_id", "order_line_id"])
assert expr == And(
EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
# Be insensitive to left/right operands
op1 = EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101))
op2 = EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1))
assert expr == And(op1, op2) or expr == And(op2, op1)


def _assert_match_filter_selects(data: list[dict[str, int]], join_cols: list[str], schema: Schema) -> None:
"""Assert the filter from ``create_match_filter`` matches exactly the unique source keys.

Rather than asserting a specific expression tree (which is implementation-specific),
this binds the filter and evaluates it against the full cross-product of the values
observed per column. The filter must accept exactly the unique keys present in
``data`` and reject every other combination, so any over- or under-matching
(e.g. a cross-product regression) is caught. This holds for any correct
implementation of ``create_match_filter``.
"""
arrow_schema = schema_to_pyarrow(schema)
table = pa.Table.from_pylist(data, schema=arrow_schema)
expr = create_match_filter(table, join_cols)

field_names = [field.name for field in schema.fields]
expected_keys = {tuple(row[name] for name in field_names) for row in data}
domains = [sorted({row[name] for row in data}) for name in field_names]

evaluate = expression_evaluator(schema, expr, case_sensitive=True)
for candidate in itertools.product(*domains):
key = dict(zip(field_names, candidate, strict=True))
should_match = candidate in expected_keys
verb = "rejected matching" if should_match else "matched non-matching"
assert evaluate(Record(*candidate)) is should_match, f"Filter {expr} {verb} key {key}"


def test_create_match_filter_single_prefix_group() -> None:
"""
Test create_match_filter with multiple key columns whose rows all share a single prefix combination.

The filter must match the (one order_id, many order_line_id) keys and nothing else.
"""
schema = Schema(
NestedField(1, "order_id", IntegerType(), required=True),
NestedField(2, "order_line_id", IntegerType(), required=True),
)
data = [
{"order_id": 101, "order_line_id": 1},
{"order_id": 101, "order_line_id": 2},
{"order_id": 101, "order_line_id": 3},
{"order_id": 101, "order_line_id": 3}, # duplicate
]
_assert_match_filter_selects(data, ["order_id", "order_line_id"], schema)


def test_create_match_filter_multiple_prefix_groups() -> None:
"""
Test create_match_filter with multiple key columns that yield several distinct prefix combinations.

The filter must match exactly the listed composite keys and must NOT match cross-product
combinations that never appear together (e.g. order_id 101 with order_line_id 2).
"""
schema = Schema(
NestedField(1, "order_id", IntegerType(), required=True),
NestedField(2, "order_line_id", IntegerType(), required=True),
)
data = [
{"order_id": 101, "order_line_id": 1},
{"order_id": 102, "order_line_id": 1},
{"order_id": 103, "order_line_id": 1},
{"order_id": 201, "order_line_id": 2},
{"order_id": 202, "order_line_id": 2},
]
_assert_match_filter_selects(data, ["order_id", "order_line_id"], schema)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
Expand Down