def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build a filter expression selecting the distinct join-key tuples found in ``df``.

    The distinct key rows are computed first. For a composite key, one join column is
    folded into an ``In()`` per group of the remaining ("prefix") columns, so the result
    has one disjunct per distinct prefix combination rather than one per key row. The
    folded column is the one that leaves the fewest distinct prefix combinations.

    Args:
        df: Table holding the source rows; only ``join_cols`` are read.
        join_cols: Names of the columns that form the (possibly composite) key.

    Returns:
        ``AlwaysFalse()`` when ``df`` has no rows; a single ``In()`` when there is one
        join column; otherwise ``And(<prefix equalities>, In(...))`` when there is a
        single prefix group, or an ``Or`` of one such conjunction per prefix group.
    """
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

    if unique_keys.num_rows == 0:
        return AlwaysFalse()

    if len(join_cols) == 1:
        return In(join_cols[0], unique_keys.column(join_cols[0]).to_pylist())

    def _all_but(col: str) -> list[str]:
        # Join columns other than ``col``, in their original order.
        return [c for c in join_cols if c != col]

    def _prefix_group_count(col: str) -> int:
        # Number of disjuncts that would result if ``col`` were folded into the In().
        rest = _all_but(col)
        return unique_keys.select(rest).group_by(rest).aggregate([]).num_rows

    # Fold the column that leaves the fewest distinct "prefix" combinations into
    # an In(); this minimises the disjunct count regardless of column order.
    in_col = min(join_cols, key=_prefix_group_count)
    prefix_cols = _all_but(in_col)

    # Rename the folded column to a sentinel *before* aggregating, then read the list
    # aggregate back by name. This avoids a clash with a join column that happens to be
    # named f"{in_col}_list" and does not depend on where the group-by places the key
    # columns relative to the aggregate in its output.
    values_col = "__in_values"
    while values_col in prefix_cols or f"{values_col}_list" in prefix_cols:
        values_col += "_"
    values_list_col = f"{values_col}_list"

    renamed = unique_keys.rename_columns(
        [values_col if name == in_col else name for name in unique_keys.column_names]
    )
    grouped = renamed.group_by(prefix_cols).aggregate([(values_col, "list")])

    # functools.reduce returns the sole element unchanged for a one-item list, so a
    # single prefix column needs no special case.
    disjuncts: list[BooleanExpression] = [
        And(
            functools.reduce(operator.and_, [EqualTo(c, row[c]) for c in prefix_cols]),
            In(in_col, row[values_list_col]),
        )
        for row in grouped.to_pylist()
    ]

    if len(disjuncts) == 1:
        return disjuncts[0]
    return Or(*disjuncts)
def _assert_match_filter_selects(data: list[dict[str, int]], join_cols: list[str], schema: Schema) -> None:
    """Check that ``create_match_filter`` accepts exactly the key tuples present in ``data``.

    The expression tree itself is not inspected. Instead the filter is bound to ``schema``
    and evaluated against every combination in the cross-product of the values observed
    for each schema field. A combination must be accepted if and only if it occurs as a
    row of ``data``, so both over-matching (e.g. accepting cross-product combinations that
    never occur together) and under-matching fail the assertion.
    """
    match_filter = create_match_filter(pa.Table.from_pylist(data, schema=schema_to_pyarrow(schema)), join_cols)

    columns = [field.name for field in schema.fields]
    present: set[tuple[int, ...]] = set()
    seen_per_column: list[set[int]] = [set() for _ in columns]
    for row in data:
        values = tuple(row[column] for column in columns)
        present.add(values)
        for seen, value in zip(seen_per_column, values, strict=True):
            seen.add(value)

    accepts = expression_evaluator(schema, match_filter, case_sensitive=True)
    for combo in itertools.product(*map(sorted, seen_per_column)):
        expected = combo in present
        if expected:
            verb = "rejected matching"
        else:
            verb = "matched non-matching"
        key = dict(zip(columns, combo, strict=True))
        assert accepts(Record(*combo)) is expected, f"Filter {match_filter} {verb} key {key}"


def test_create_match_filter_single_prefix_group() -> None:
    """
    Composite key whose rows all share one prefix combination.

    Every row has the same order_id with varying order_line_id (one of them duplicated);
    the filter must select those keys and nothing else.
    """
    order_schema = Schema(
        NestedField(1, "order_id", IntegerType(), required=True),
        NestedField(2, "order_line_id", IntegerType(), required=True),
    )
    # The trailing 3 repeats the previous key on purpose (duplicate source row).
    rows = [{"order_id": 101, "order_line_id": line_id} for line_id in (1, 2, 3, 3)]
    _assert_match_filter_selects(rows, ["order_id", "order_line_id"], order_schema)


def test_create_match_filter_multiple_prefix_groups() -> None:
    """
    Composite key that produces several distinct prefix combinations.

    Only the listed (order_id, order_line_id) pairs may match; pairs from the cross-product
    that never occur together (e.g. order_id 101 with order_line_id 2) must be rejected.
    """
    order_schema = Schema(
        NestedField(1, "order_id", IntegerType(), required=True),
        NestedField(2, "order_line_id", IntegerType(), required=True),
    )
    pairs = [(101, 1), (102, 1), (103, 1), (201, 2), (202, 2)]
    rows = [{"order_id": order_id, "order_line_id": line_id} for order_id, line_id in pairs]
    _assert_match_filter_selects(rows, ["order_id", "order_line_id"], order_schema)


def test_create_match_filter_single_column() -> None:
    """With one join column the filter is a single In() over the unique values."""
    arrow_schema = pa.schema([pa.field("order_id", pa.int32())])
    source = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 2}, {"order_id": 2}], schema=arrow_schema)
    result = create_match_filter(source, ["order_id"])
    assert result == In("order_id", [1, 2])


def test_create_match_filter_single_column_single_value() -> None:
    """With one join column and one unique value the In() compares equal to an EqualTo()."""
    arrow_schema = pa.schema([pa.field("order_id", pa.int32())])
    source = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 1}], schema=arrow_schema)
    result = create_match_filter(source, ["order_id"])
    assert result == EqualTo("order_id", 1)


def test_create_match_filter_empty_input() -> None:
    """A source without rows yields AlwaysFalse for single and composite keys alike."""
    arrow_schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32())])
    no_rows = pa.Table.from_pylist([], schema=arrow_schema)
    single_key_filter = create_match_filter(no_rows, ["order_id"])
    assert single_key_filter == AlwaysFalse()
    composite_key_filter = create_match_filter(no_rows, ["order_id", "order_line_id"])
    assert composite_key_filter == AlwaysFalse()


def test_create_match_filter_three_columns() -> None:
    """
    Composite key made of three columns.

    Covers the case where the prefix predicate is itself a conjunction of two EqualTo()
    terms combined with an In() over the folded column.
    """
    abc_schema = Schema(
        NestedField(1, "a", IntegerType(), required=True),
        NestedField(2, "b", IntegerType(), required=True),
        NestedField(3, "c", IntegerType(), required=True),
    )
    triples = [(1, 1, 1), (1, 1, 2), (1, 1, 3), (2, 9, 5), (2, 9, 6)]
    rows = [{"a": a, "b": b, "c": c} for a, b, c in triples]
    _assert_match_filter_selects(rows, ["a", "b", "c"], abc_schema)


def test_create_match_filter_column_named_like_aggregate() -> None:
    """
    Regression test for #3509 review feedback.

    A join column named ``<in_col>_list`` (here ``a_list`` next to ``a``) must not collide
    with the internal list-aggregation column used to fold values into an In(). Before the
    fix this raised a TypeError.
    """
    clash_schema = Schema(
        NestedField(1, "a", IntegerType(), required=True),
        NestedField(2, "a_list", IntegerType(), required=True),
    )
    rows = [{"a": a, "a_list": a_list} for a, a_list in ((1, 7), (2, 7), (3, 8))]
    _assert_match_filter_selects(rows, ["a", "a_list"], clash_schema)


def test_upsert_large_composite_key_does_not_overflow(catalog: Catalog) -> None:
    """
    Regression test for #3508: a large multi-column upsert must not overflow PyArrow's
    expression canonicalizer when at least one key column is low-cardinality (see #3509).
    """
    identifier = "default.test_upsert_large_composite_key"
    _drop_table(catalog, identifier)

    n = 20_000
    updated_count = n // 2
    inserted_count = n // 10

    arrow_schema = pa.schema(
        [
            pa.field("order_id", pa.int64(), nullable=False),
            pa.field("region", pa.string(), nullable=False),
            pa.field("amount", pa.int64(), nullable=False),
        ]
    )

    def make(order_ids: range, amount: int) -> pa.Table:
        # region cycles through only four values, keeping that key column low-cardinality.
        records = []
        for oid in order_ids:
            records.append({"order_id": oid, "region": "ABCD"[oid % 4], "amount": amount})
        return pa.Table.from_pylist(records, schema=arrow_schema)

    tbl = catalog.create_table(identifier, arrow_schema)
    tbl.append(make(range(1, n + 1), amount=1))

    # First half of the existing keys with a changed amount, plus a tenth of brand-new keys.
    changed_rows = make(range(1, updated_count + 1), amount=2)
    new_rows = make(range(n + 1, n + inserted_count + 1), amount=2)
    source = pa.concat_tables([changed_rows, new_rows])

    res = tbl.upsert(source, join_cols=["order_id", "region"])
    assert res.rows_updated == updated_count
    assert res.rows_inserted == inserted_count