Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 228 additions & 1 deletion pyiceberg/io/pyiceberg_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from __future__ import annotations

import importlib
import os
import threading
import weakref
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import TYPE_CHECKING, Any

from pyiceberg.expressions import (
Expand Down Expand Up @@ -58,7 +62,7 @@
from pyiceberg.typedef import Record

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Iterator


def _core_module(name: str) -> Any:
Expand Down Expand Up @@ -146,6 +150,16 @@ def _read_field_ids(
return ids


def can_read_projected_schema_with_pyiceberg_core(
    schema: Schema,
    projected_schema: Schema,
    row_filter: BooleanExpression,
    case_sensitive: bool,
) -> bool:
    """Tell whether the row filter only touches fields that are part of the projection.

    The field ids referenced by ``row_filter`` are resolved against ``schema`` (honouring
    ``case_sensitive``) via ``_expression_field_ids``; the result is ``True`` only when every one
    of those ids also appears in ``projected_schema.field_ids``.
    """
    filter_field_ids = _expression_field_ids(row_filter, schema, case_sensitive)
    return filter_field_ids.issubset(projected_schema.field_ids)


_UNARY_METHODS: dict[type[BooleanExpression], str] = {
IsNull: "is_null",
NotNull: "is_not_null",
Expand Down Expand Up @@ -291,3 +305,216 @@ def file_scan_task_to_pyiceberg_core(
name_mapping=_model_json(name_mapping) if name_mapping is not None else None,
case_sensitive=case_sensitive,
)


# Default rows per Arrow batch handed back from the native reader (262144 == 2**18).
#
# Rationale, as recorded by the author: the native default emits very small batches, and the
# sharded fan-in marshals every batch back through the GIL-holding consumer, so a tiny batch makes
# per-batch Python orchestration dominate the decode (measured ~3x slower than a whole-shard
# drain). A larger batch amortizes that handoff while keeping in-flight memory bounded to
# shards x batch (so the streaming contract still holds).
#
# `_reader_kwargs` uses this value whenever PYICEBERG_RUST_ARROW_BATCH_SIZE is unset or empty, and
# a scan `limit` can lower the effective batch size further in
# `arrow_batch_reader_from_pyiceberg_core`.
_DEFAULT_ARROW_BATCH_SIZE = 262144


def _reader_kwargs() -> dict[str, int]:
    """Build the keyword arguments passed to the native ``ArrowReader`` constructor.

    ``batch_size`` comes from ``PYICEBERG_RUST_ARROW_BATCH_SIZE`` when that variable is set to a
    non-empty value, otherwise from ``_DEFAULT_ARROW_BATCH_SIZE``.
    ``data_file_concurrency_limit`` is only included when
    ``PYICEBERG_RUST_ARROW_FILE_CONCURRENCY`` is set to a non-empty value.

    Both overrides are clamped to at least 1, matching how ``PYICEBERG_RUST_ARROW_SHARDS`` is
    treated in ``_shard_count``: a zero or negative batch size / concurrency limit is never a
    meaningful request, so it is not forwarded to the native reader.

    A value that ``int()`` cannot parse raises ``ValueError``.
    """
    batch_override = os.environ.get("PYICEBERG_RUST_ARROW_BATCH_SIZE")
    # The default is only looked up when no override is present.
    batch_size = max(1, int(batch_override)) if batch_override else _DEFAULT_ARROW_BATCH_SIZE
    kwargs: dict[str, int] = {"batch_size": batch_size}

    concurrency = os.environ.get("PYICEBERG_RUST_ARROW_FILE_CONCURRENCY")
    if concurrency:
        kwargs["data_file_concurrency_limit"] = max(1, int(concurrency))
    return kwargs


def _shard_count(n_tasks: int) -> int:
    """Pick the number of decode shards (threads) for a scan of ``n_tasks`` file tasks.

    A non-empty ``PYICEBERG_RUST_ARROW_SHARDS`` always wins and is clamped to at least 1 (so a
    value of 1 disables sharding). Without an override, scans of zero or one task use a single
    shard and larger scans use one shard per task up to the CPU count reported by
    ``os.cpu_count()`` (treated as 1 when unknown).

    Design rationale recorded by the author: the native reader is understood to decode a single
    stream on one core, so splitting file tasks across several readers is what lets decode use
    more than one core.
    """
    forced = os.environ.get("PYICEBERG_RUST_ARROW_SHARDS")
    if forced:
        requested = int(forced)
        return requested if requested >= 1 else 1

    if n_tasks > 1:
        # n_tasks >= 2 and the core count is >= 1, so the minimum can never drop below 1.
        return min(n_tasks, os.cpu_count() or 1)
    return 1


class _ShardedBatchStream:
    """Iterator that fans in batches from several shard readers with backpressure.

    Every reader in ``readers`` is polled through ``read_next_batch()`` on a
    ``ThreadPoolExecutor`` with one worker per reader. A shard never has more than one read in
    flight: its index leaves ``_idle`` when a read is submitted and is only put back once the
    resulting batch has been handed to the consumer. A slow consumer therefore stalls the shards
    instead of letting them buffer ahead, and at most one undelivered batch per shard is held.

    Batches are returned in completion order (``FIRST_COMPLETED``), so ordering across shards is
    not preserved. The intent, per the author's design notes, is that native decode overlaps
    across shards because the native reader releases the GIL while draining; that property is
    assumed here, not enforced.

    Worker exceptions are re-raised from ``__next__``. The pool is shut down when the stream is
    exhausted, when iteration raises, on an explicit :meth:`close`, or through the ``weakref``
    finalizer if the stream is dropped without being closed.
    """

    def __init__(self, readers: list[Any]) -> None:
        # NOTE: the list is kept by reference and emptied on shutdown.
        self._readers = readers
        # ThreadPoolExecutor rejects max_workers=0; keep one (never used) worker so an empty
        # reader list behaves as an already-exhausted stream instead of raising here.
        self._pool = ThreadPoolExecutor(max_workers=max(1, len(readers)))
        # Shard indices whose reader is idle and not yet known to be exhausted.
        self._idle: list[int] = list(range(len(readers)))
        # Outstanding reads, mapped back to the shard index that owns them.
        self._in_flight: dict[Future[Any], int] = {}
        self._closed = False
        self._lock = threading.Lock()
        # Shut the pool down if the consumer abandons the iterator without closing it. The
        # callback is the plain static function, so the finalizer holds no reference to ``self``.
        self._finalizer = weakref.finalize(self, self._shutdown, self._pool, self._readers)

    @staticmethod
    def _next_batch(reader: Any) -> Any | None:
        """Pull one batch from a shard reader, returning ``None`` when the shard is exhausted."""
        try:
            return reader.read_next_batch()
        except StopIteration:
            return None

    def _submit_idle(self) -> None:
        """Submit the next read for every idle, non-exhausted shard (one read per shard)."""
        while self._idle:
            shard = self._idle.pop()
            future = self._pool.submit(self._next_batch, self._readers[shard])
            self._in_flight[future] = shard

    def __iter__(self) -> _ShardedBatchStream:
        return self

    def __next__(self) -> Any:
        """Return the next completed batch from any shard, or stop once all are exhausted."""
        if self._closed:
            raise StopIteration
        try:
            while True:
                self._submit_idle()
                if not self._in_flight:
                    # Every shard is exhausted: the union is complete.
                    self.close()
                    raise StopIteration
                done, _ = wait(self._in_flight, return_when=FIRST_COMPLETED)
                for future in done:
                    shard = self._in_flight.pop(future, None)
                    if shard is None:
                        # close() dropped the bookkeeping while we were waiting.
                        raise StopIteration
                    batch = future.result()  # re-raises any worker exception here
                    if batch is None:
                        continue  # shard exhausted; do not return it to the idle set
                    # Shard produced a batch: it may have more, so mark it idle again. Other
                    # futures in ``done`` stay in ``_in_flight`` and are picked up on the next
                    # call, so no completed batch is lost by returning here.
                    self._idle.append(shard)
                    return batch
        except StopIteration:
            raise
        except BaseException:
            # Worker exception, GeneratorExit, or KeyboardInterrupt: tear the workers down so an
            # aborted scan never leaves decode threads running.
            self.close()
            raise

    def close(self) -> None:
        """Stop the workers and release shard readers and undelivered batches; idempotent."""
        with self._lock:
            if self._closed:
                return
            self._closed = True
        self._finalizer.detach()
        self._shutdown(self._pool, self._readers)
        # Completed futures keep their decoded batch alive; drop them so a closed stream that is
        # still referenced does not pin up to one batch per shard.
        self._in_flight.clear()
        self._idle.clear()

    @staticmethod
    def _shutdown(pool: ThreadPoolExecutor, readers: list[Any]) -> None:
        # cancel_futures drops queued reads; in-flight reads are joined so no thread outlives us.
        pool.shutdown(wait=True, cancel_futures=True)
        readers.clear()


def _limited_batches(source: Any, limit: int) -> Iterator[Any]:
    """Yield batches from ``source`` until ``limit`` rows have been emitted, then stop.

    A batch that would exceed the limit is sliced to the rows still needed; a batch that meets
    the limit exactly is yielded as-is. In both cases iteration ends immediately, without asking
    ``source`` for another batch, so no extra batch is decoded (or waited for) once the limit is
    satisfied. A ``limit`` of zero or less yields nothing and pulls nothing.

    Whenever the generator finishes or is closed after it has started, ``source.close()`` is
    called if ``source`` has a ``close`` attribute, so a sharded scan stops decoding early. The
    rows returned are simply the first ones ``source`` produces; no ordering is imposed here.
    """
    remaining = limit
    try:
        if remaining <= 0:
            return
        for batch in source:
            rows = batch.num_rows
            if rows >= remaining:
                # Final batch: trim only when it actually overshoots the limit.
                yield batch if rows == remaining else batch.slice(0, remaining)
                return
            remaining -= rows
            yield batch
    finally:
        close = getattr(source, "close", None)
        if close is not None:
            close()


def arrow_batch_reader_from_pyiceberg_core(
    file_io: FileIO,
    tasks: Iterable[FileScanTask],
    schema: Schema,
    projected_schema: Schema,
    partition_specs: dict[int, PartitionSpec],
    name_mapping: NameMapping | None,
    case_sensitive: bool = True,
    limit: int | None = None,
) -> Any:
    """Stream PyIceberg scan tasks through pyiceberg-core's ``ArrowReader``.

    Every task is converted with ``file_scan_task_to_pyiceberg_core`` and the projection with
    ``schema_to_pyiceberg_core``. ``_shard_count`` then decides how the work is split:

    * one shard, or at most one task: a single native reader is opened over all tasks and
      returned directly (wrapped in ``_limited_batches`` when ``limit`` is given);
    * otherwise: tasks are dealt round-robin into shards, one native reader is opened per
      non-empty shard, and the readers are combined by ``_ShardedBatchStream`` into a
      ``pyarrow.RecordBatchReader`` typed with the first reader's schema.

    When ``limit`` is given, the native batch size is capped at ``limit`` (but never below 1) so
    a shard does not decode a full default-sized batch only for it to be truncated, and the
    resulting stream is cut off by ``_limited_batches``.
    """
    native_tasks = [
        file_scan_task_to_pyiceberg_core(
            scan_task,
            schema,
            projected_schema,
            partition_spec=partition_specs.get(scan_task.file.spec_id),
            name_mapping=name_mapping,
            case_sensitive=case_sensitive,
            project_field_ids=list(projected_schema.field_ids),
        )
        for scan_task in tasks
    ]
    native_projection = schema_to_pyiceberg_core(projected_schema)

    native_options = _reader_kwargs()
    if limit is not None:
        # Never ask for more rows per batch than the limit, and never fewer than one.
        capped = min(native_options["batch_size"], limit)
        native_options["batch_size"] = max(1, capped)

    def _open(task_group: list[Any]) -> Any:
        # Each call builds a fresh native ArrowReader so every shard owns its own reader.
        arrow_reader = _core_module("scan").ArrowReader(file_io_to_pyiceberg_core(file_io), **native_options)
        return arrow_reader.read(native_projection, task_group)

    import pyarrow as pa

    task_total = len(native_tasks)
    n_shards = _shard_count(task_total)

    if not (n_shards > 1 and task_total > 1):
        # Nothing to fan out: hand back the native reader itself where possible.
        single = _open(native_tasks)
        if limit is None:
            return single
        return pa.RecordBatchReader.from_batches(single.schema, _limited_batches(single, limit))

    # Deal tasks round-robin; a shard that receives no task gets no reader.
    shard_readers = []
    for offset in range(n_shards):
        task_group = native_tasks[offset::n_shards]
        if task_group:
            shard_readers.append(_open(task_group))

    # All readers were opened with the same projection, so the first one's schema types the stream.
    stream_schema = shard_readers[0].schema
    batches: Any = _ShardedBatchStream(shard_readers)
    if limit is not None:
        batches = _limited_batches(batches, limit)
    return pa.RecordBatchReader.from_batches(stream_schema, batches)
33 changes: 31 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@

ALWAYS_TRUE = AlwaysTrue()
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
PYICEBERG_RUST_ARROW_SCAN = "PYICEBERG_RUST_ARROW_SCAN"


@dataclass()
Expand Down Expand Up @@ -2242,9 +2243,37 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:

from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow

target_schema = schema_to_pyarrow(self.projection())
projected_schema = self.projection()
if os.environ.get(PYICEBERG_RUST_ARROW_SCAN, "").lower() in {"1", "true", "yes"}:
from pyiceberg.io.pyiceberg_core import (
arrow_batch_reader_from_pyiceberg_core,
can_read_projected_schema_with_pyiceberg_core,
)

if can_read_projected_schema_with_pyiceberg_core(
self.table_metadata.schema(), projected_schema, self.row_filter, self.case_sensitive
):
try:
return arrow_batch_reader_from_pyiceberg_core(
self.io,
self.plan_files(),
self.table_metadata.schema(),
projected_schema,
self.table_metadata.specs(),
self.table_metadata.name_mapping(),
self.case_sensitive,
limit=self.limit,
)
except (ModuleNotFoundError, NotImplementedError, ValueError) as exc:
warnings.warn(
f"Falling back to PyArrow scan because pyiceberg-core cannot handle this scan: {exc}",
RuntimeWarning,
stacklevel=2,
)

target_schema = schema_to_pyarrow(projected_schema)
batches = ArrowScan(
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
self.table_metadata, self.io, projected_schema, self.row_filter, self.case_sensitive, self.limit
).to_record_batches(self.plan_files())

return pa.RecordBatchReader.from_batches(
Expand Down
Loading