From f5865d4bb9549596147d1bf66149f1635bc4817a Mon Sep 17 00:00:00 2001 From: steven-winfield-quantohm Date: Mon, 15 Jun 2026 14:09:24 +0100 Subject: [PATCH 1/5] More efficient (and less segfaulty) create_match_filter --- pyiceberg/table/upsert_util.py | 40 ++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..f68d538c3b 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -33,19 +33,35 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + if unique_keys.num_rows == 0: + return AlwaysFalse() + if len(join_cols) == 1: - return In(join_cols[0], unique_keys[0].to_pylist()) - else: - filters = [ - functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist() - ] - - if len(filters) == 0: - return AlwaysFalse() - elif len(filters) == 1: - return filters[0] - else: - return Or(*filters) + return In(join_cols[0], unique_keys.column(join_cols[0]).to_pylist()) + + # Fold the column that leaves the fewest distinct "prefix" combinations into + # an In(); this minimises the disjunct count regardless of column order. + in_col = min( + join_cols, + key=lambda cand: unique_keys.select([c for c in join_cols if c != cand]) + .group_by([c for c in join_cols if c != cand]) + .aggregate([]) + .num_rows, + ) + prefix_cols = [c for c in join_cols if c != in_col] + + grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")]) + in_values_col = f"{in_col}_list" + + disjuncts: list[BooleanExpression] = [] + for row in grouped.to_pylist(): + eqs = [EqualTo(c, row[c]) for c in prefix_cols] + prefix_pred = functools.reduce(operator.and_, eqs) if len(eqs) > 1 else eqs[0] + disjuncts.append(And(prefix_pred, In(in_col, row[in_values_col]))) + + if len(disjuncts) == 1: + return disjuncts[0] + return Or(*disjuncts) def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: From 0362a91748fb36056c3d2ce0a2661acc678dead0 Mon Sep 17 00:00:00 2001 From: steven-winfield-quantohm Date: Mon, 15 Jun 2026 14:29:19 +0100 Subject: [PATCH 2/5] Missing import --- pyiceberg/table/upsert_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index f68d538c3b..c8b4225d37 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -23,6 +23,7 @@ from pyiceberg.expressions import ( AlwaysFalse, + And, BooleanExpression, EqualTo, In, From 269dcd2380de10fd18eda263039fe84fa490d283 Mon Sep 17 00:00:00 2001 From: steven-winfield-quantohm Date: Mon, 15 Jun 2026 14:38:52 +0100 Subject: [PATCH 3/5] Allow And(op1, op2) or And(op2, op1) in test_create_match_filter_single_condition --- tests/table/test_upsert.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 08f90c6600..b5a0c63b02 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -437,10 +437,10 @@ def test_create_match_filter_single_condition() -> None: schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())]) table = pa.Table.from_pylist(data, schema=schema) expr = create_match_filter(table, ["order_id", "order_line_id"]) - assert expr == And( - EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)), - EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)), - ) + # Be insensitive to left/right operands + op1 = EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)) + op2 = EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)) + assert expr == And(op1, op2) or expr == And(op2, op1) def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: From 691d13877483aac73e1a474d7d7fa1be4270b5d3 Mon Sep 17 00:00:00 2001 From: Steven Winfield Date: Wed, 17 Jun 2026 15:39:54 +0000 Subject: [PATCH 4/5] Add regression tests for multi-column create_match_filter Address review feedback on #3509: cover the "multiple columns, single prefix group" and "multiple columns, multiple prefix groups" cases. The tests verify matching semantics (the filter accepts exactly the unique source keys and rejects every other combination) rather than a specific expression tree, so they guard against regressions independent of the implementation. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/table/test_upsert.py | 69 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index b5a0c63b02..9c4828494d 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import itertools from pathlib import PosixPath import pyarrow as pa @@ -25,11 +26,13 @@ from pyiceberg.exceptions import NoSuchTableError from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral +from pyiceberg.expressions.visitors import expression_evaluator from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema from pyiceberg.table import Table, UpsertResult from pyiceberg.table.snapshots import Operation from pyiceberg.table.upsert_util import create_match_filter +from pyiceberg.typedef import Record from pyiceberg.types import IntegerType, NestedField, StringType, StructType from tests.catalog.test_base import InMemoryCatalog @@ -443,6 +446,72 @@ def test_create_match_filter_single_condition() -> None: assert expr == And(op1, op2) or expr == And(op2, op1) +def _assert_match_filter_selects(data: list[dict[str, int]], join_cols: list[str], schema: Schema) -> None: + """Assert the filter from ``create_match_filter`` matches exactly the unique source keys. + + Rather than asserting a specific expression tree (which is implementation-specific), + this binds the filter and evaluates it against the full cross-product of the values + observed per column. The filter must accept exactly the unique keys present in + ``data`` and reject every other combination, so any over- or under-matching + (e.g. a cross-product regression) is caught. This holds for any correct + implementation of ``create_match_filter``. + """ + arrow_schema = schema_to_pyarrow(schema) + table = pa.Table.from_pylist(data, schema=arrow_schema) + expr = create_match_filter(table, join_cols) + + field_names = [field.name for field in schema.fields] + expected_keys = {tuple(row[name] for name in field_names) for row in data} + domains = [sorted({row[name] for row in data}) for name in field_names] + + evaluate = expression_evaluator(schema, expr, case_sensitive=True) + for candidate in itertools.product(*domains): + key = dict(zip(field_names, candidate, strict=True)) + should_match = candidate in expected_keys + verb = "rejected matching" if should_match else "matched non-matching" + assert evaluate(Record(*candidate)) is should_match, f"Filter {expr} {verb} key {key}" + + +def test_create_match_filter_single_prefix_group() -> None: + """ + Test create_match_filter with multiple key columns whose rows all share a single prefix combination. + + The filter must match the (one order_id, many order_line_id) keys and nothing else. + """ + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_line_id", IntegerType(), required=True), + ) + data = [ + {"order_id": 101, "order_line_id": 1}, + {"order_id": 101, "order_line_id": 2}, + {"order_id": 101, "order_line_id": 3}, + {"order_id": 101, "order_line_id": 3}, # duplicate + ] + _assert_match_filter_selects(data, ["order_id", "order_line_id"], schema) + + +def test_create_match_filter_multiple_prefix_groups() -> None: + """ + Test create_match_filter with multiple key columns that yield several distinct prefix combinations. + + The filter must match exactly the listed composite keys and must NOT match cross-product + combinations that never appear together (e.g. order_id 101 with order_line_id 2). + """ + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_line_id", IntegerType(), required=True), + ) + data = [ + {"order_id": 101, "order_line_id": 1}, + {"order_id": 102, "order_line_id": 1}, + {"order_id": 103, "order_line_id": 1}, + {"order_id": 201, "order_line_id": 2}, + {"order_id": 202, "order_line_id": 2}, + ] + _assert_match_filter_selects(data, ["order_id", "order_line_id"], schema) + + def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: identifier = "default.test_upsert_with_duplicate_rows_in_table" From 25d08530ee9dc8b47caa790de78fa36bfcc4ea2f Mon Sep 17 00:00:00 2001 From: Steven Winfield Date: Wed, 17 Jun 2026 21:20:01 +0000 Subject: [PATCH 5/5] Fix in_col list-aggregate column collision and add edge-case tests When folding a key column into an In(), the list aggregation column was named f"{in_col}_list", which silently clobbered a join column of the same name and fed a Python list into EqualTo (TypeError). Rename the aggregate to a collision-free sentinel by position instead. Also add coverage for create_match_filter edge cases raised in the #3509 review: single column, single-value collapse to EqualTo, empty input, three key columns, the column-name collision regression, and a large multi-column upsert (#3508) that must not overflow PyArrow's expression canonicalizer when a key column is low-cardinality. Co-Authored-By: Claude Opus 4.8 (1M context) --- pyiceberg/table/upsert_util.py | 9 ++- tests/table/test_upsert.py | 105 ++++++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index c8b4225d37..f45823a808 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -51,8 +51,13 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre ) prefix_cols = [c for c in join_cols if c != in_col] - grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")]) - in_values_col = f"{in_col}_list" + # The group keys come first (in prefix_cols order) followed by the list aggregate. + # Rename the aggregate to a sentinel so it cannot collide with a join column that + # happens to be named f"{in_col}_list". + in_values_col = "__in_values" + while in_values_col in prefix_cols: + in_values_col += "_" + grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")]).rename_columns([*prefix_cols, in_values_col]) disjuncts: list[BooleanExpression] = [] for row in grouped.to_pylist(): diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 9c4828494d..0a7168f9ec 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -24,7 +24,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference +from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, EqualTo, In, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.expressions.visitors import expression_evaluator from pyiceberg.io.pyarrow import schema_to_pyarrow @@ -512,6 +512,109 @@ def test_create_match_filter_multiple_prefix_groups() -> None: _assert_match_filter_selects(data, ["order_id", "order_line_id"], schema) +def test_create_match_filter_single_column() -> None: + """A single join column collapses to a single In() over the unique values.""" + schema = pa.schema([pa.field("order_id", pa.int32())]) + table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 2}, {"order_id": 2}], schema=schema) + assert create_match_filter(table, ["order_id"]) == In("order_id", [1, 2]) + + +def test_create_match_filter_single_column_single_value() -> None: + """A single unique value collapses the In() down to an EqualTo().""" + schema = pa.schema([pa.field("order_id", pa.int32())]) + table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 1}], schema=schema) + assert create_match_filter(table, ["order_id"]) == EqualTo("order_id", 1) + + +def test_create_match_filter_empty_input() -> None: + """An empty source matches nothing (AlwaysFalse), for both single and composite keys.""" + schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32())]) + empty = pa.Table.from_pylist([], schema=schema) + assert create_match_filter(empty, ["order_id"]) == AlwaysFalse() + assert create_match_filter(empty, ["order_id", "order_line_id"]) == AlwaysFalse() + + +def test_create_match_filter_three_columns() -> None: + """ + Test create_match_filter with three key columns. + + Exercises the multi-column prefix branch where the prefix predicate is an And of two + EqualTo() conjuncts combined with an In() over the folded column. + """ + schema = Schema( + NestedField(1, "a", IntegerType(), required=True), + NestedField(2, "b", IntegerType(), required=True), + NestedField(3, "c", IntegerType(), required=True), + ) + data = [ + {"a": 1, "b": 1, "c": 1}, + {"a": 1, "b": 1, "c": 2}, + {"a": 1, "b": 1, "c": 3}, + {"a": 2, "b": 9, "c": 5}, + {"a": 2, "b": 9, "c": 6}, + ] + _assert_match_filter_selects(data, ["a", "b", "c"], schema) + + +def test_create_match_filter_column_named_like_aggregate() -> None: + """ + Regression test for #3509 review feedback. + + A join column named ``_list`` must not collide with the internal list-aggregation + column used to fold values into an In(). Before the fix this raised a TypeError. + """ + schema = Schema( + NestedField(1, "a", IntegerType(), required=True), + NestedField(2, "a_list", IntegerType(), required=True), + ) + data = [ + {"a": 1, "a_list": 7}, + {"a": 2, "a_list": 7}, + {"a": 3, "a_list": 8}, + ] + _assert_match_filter_selects(data, ["a", "a_list"], schema) + + +def test_upsert_large_composite_key_does_not_overflow(catalog: Catalog) -> None: + """ + Regression test for #3508: a large multi-column upsert must not overflow PyArrow's + expression canonicalizer when at least one key column is low-cardinality (see #3509). + """ + identifier = "default.test_upsert_large_composite_key" + _drop_table(catalog, identifier) + + n = 20_000 + schema = pa.schema( + [ + pa.field("order_id", pa.int64(), nullable=False), + pa.field("region", pa.string(), nullable=False), + pa.field("amount", pa.int64(), nullable=False), + ] + ) + + def make(order_ids: range, amount: int) -> pa.Table: + # region is intentionally low-cardinality (4 values) so the fix folds order_id into an In(). + return pa.Table.from_pylist( + [{"order_id": oid, "region": "ABCD"[oid % 4], "amount": amount} for oid in order_ids], + schema=schema, + ) + + tbl = catalog.create_table(identifier, schema) + tbl.append(make(range(1, n + 1), amount=1)) + + # Update the first half (amount changes) and insert a tenth of brand-new keys. + source = pa.concat_tables( + [ + make(range(1, n // 2 + 1), amount=2), + make(range(n + 1, n + n // 10 + 1), amount=2), + ] + ) + + res = tbl.upsert(source, join_cols=["order_id", "region"]) + assert res.rows_updated == n // 2 + assert res.rows_inserted == n // 10 + + def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: identifier = "default.test_upsert_with_duplicate_rows_in_table"