Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,17 @@ def upsert(
# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)

# When ``when_matched_update_all=False`` the consumer loop below
# only ever reads ``join_cols`` off each destination batch.
selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)

# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

matched_iceberg_record_batches_scan = DataScan(
table_metadata=self.table_metadata,
io=self._table.io,
row_filter=matched_predicate,
selected_fields=selected_fields,
case_sensitive=case_sensitive,
)

Expand Down
61 changes: 61 additions & 0 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from pathlib import PosixPath
from typing import Any

import pyarrow as pa
import pytest
Expand Down Expand Up @@ -888,3 +889,63 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
for snapshot in snapshots[initial_snapshot_count:]:
assert snapshot.summary is not None
assert snapshot.summary.additional_properties.get("test_prop") == "test_value"


def test_upsert_narrows_destination_scan_projection_to_join_cols(
catalog: Catalog,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``Transaction.upsert`` narrows the destination scan's
``selected_fields`` to ``join_cols`` when
``when_matched_update_all=False``.

The insert-on-no-match branch only reads ``join_cols`` from each
destination batch (to feed ``create_match_filter``), so projection
at the scan boundary lets the parquet reader skip wide non-key
columns. The ``("*",)`` fallback on the ``=True`` branch is
exercised by the rest of this module — ``get_rows_to_update``'s
value-drift detection would silently break if it ever regressed.
"""
import functools

from pyiceberg.table import DataScan

identifier = "default.test_upsert_narrows_projection"
_drop_table(catalog, identifier)
table = catalog.create_table(
identifier,
schema=Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(2, "payload", StringType(), required=True),
),
)
arrow_schema = pa.schema([pa.field("id", pa.int32(), nullable=False), pa.field("payload", pa.string(), nullable=False)])
table.append(pa.Table.from_pylist([{"id": 1, "payload": "a"}], schema=arrow_schema))

# Spy on ``DataScan.__init__`` to capture each constructed scan's
# ``selected_fields``. ``functools.wraps`` preserves the original
# signature so ``DataScan.update()``'s reflective parameter lookup
# (used inside ``use_ref``) still resolves correctly.
captured: list[tuple[str, ...] | None] = []
original_init = DataScan.__init__

@functools.wraps(original_init)
def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
original_init(self, *args, **kwargs)
captured.append(kwargs.get("selected_fields"))

monkeypatch.setattr(DataScan, "__init__", _spy)

table.upsert(
df=pa.Table.from_pylist(
[{"id": 1, "payload": "a-new"}, {"id": 2, "payload": "b"}],
schema=arrow_schema,
),
join_cols=["id"],
when_matched_update_all=False,
)

assert captured, "upsert path constructed no DataScan — projection contract regression"
assert all(sf == ("id",) for sf in captured), (
f"expected every DataScan during upsert to use selected_fields=('id',); got {captured}"
)