Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,9 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
else:
# Does not exist (yet)
raise TypeError(f"Unsupported integer type: {primitive}")
elif pa.types.is_float16(primitive):
# Iceberg has no half-precision float; widen to single precision (lossless)
return FloatType()
elif pa.types.is_float32(primitive):
return FloatType()
elif pa.types.is_float64(primitive):
Expand Down Expand Up @@ -1978,6 +1981,15 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
target_width = target_type.bit_width
if source_width < target_width:
return values.cast(target_type)
elif isinstance(field.field_type, (FloatType, DoubleType)):
# Cast smaller float types to target type for cross-platform compatibility
# Only allow widening conversions (smaller bit width to larger), e.g. float16 -> float32
# Narrowing conversions fall through to promote() handling below
if pa.types.is_floating(values.type):
source_width = values.type.bit_width
target_width = target_type.bit_width
if source_width < target_width:
return values.cast(target_type)

if field.field_type != file_field.field_type:
target_schema = schema_to_pyarrow(
Expand Down
29 changes: 29 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3109,6 +3109,35 @@ def test__to_requested_schema_integer_promotion(
assert result.column(0).to_pylist() == [1, 2, 3, None]


@pytest.mark.parametrize(
"arrow_type,iceberg_type,expected_arrow_type",
[
(pa.float16(), FloatType(), pa.float32()),
(pa.float16(), DoubleType(), pa.float64()),
(pa.float32(), DoubleType(), pa.float64()),
],
)
def test__to_requested_schema_float_promotion(
arrow_type: pa.DataType,
iceberg_type: PrimitiveType,
expected_arrow_type: pa.DataType,
) -> None:
"""Test that smaller float types are cast to target Iceberg type during write."""
requested_schema = Schema(NestedField(1, "col", iceberg_type, required=False))
file_schema = requested_schema

arrow_schema = pa.schema([pa.field("col", arrow_type)])
data = pa.array([1.5, 2.25, 3.0, None], type=arrow_type)
batch = pa.RecordBatch.from_arrays([data], schema=arrow_schema)

result = _to_requested_schema(
requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False
)

assert result.schema[0].type == expected_arrow_type
assert result.column(0).to_pylist() == [1.5, 2.25, 3.0, None]


def test_pyarrow_file_io_fs_by_scheme_cache() -> None:
# It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument
# becomes available for `resolve_s3_region`
Expand Down
6 changes: 6 additions & 0 deletions tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def test_pyarrow_int64_to_iceberg() -> None:
assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type


def test_pyarrow_float16_to_iceberg() -> None:
pyarrow_type = pa.float16()
converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
assert converted_iceberg_type == FloatType()


def test_pyarrow_float32_to_iceberg() -> None:
pyarrow_type = pa.float32()
converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
Expand Down