diff --git a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 1ed8d8a3..b1f074f5 100644 --- a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -81,6 +81,11 @@ protected virtual void Dispose(bool disposing) public abstract ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken); public abstract RecordBatch ReadNextRecordBatch(); + /// + /// Custom metadata from the most recently read RecordBatch Message, if any. + /// + internal IReadOnlyDictionary LastBatchCustomMetadata { get; private protected set; } + internal static T ReadMessage(ByteBuffer bb) where T : struct, IFlatbufferObject { @@ -142,6 +147,7 @@ protected RecordBatch CreateArrowObjectFromMessage( case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; List arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb); + LastBatchCustomMetadata = ReadMessageCustomMetadata(message); return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); default: // NOTE: Skip unsupported message type @@ -152,6 +158,23 @@ protected RecordBatch CreateArrowObjectFromMessage( return null; } + private static IReadOnlyDictionary ReadMessageCustomMetadata(Flatbuf.Message message) + { + int count = message.CustomMetadataLength; + if (count == 0) + return null; + + var result = new Dictionary(count); + for (int i = 0; i < count; i++) + { + Flatbuf.KeyValue kv = message.CustomMetadata(i).GetValueOrDefault(); + string key = kv.Key; + if (key != null) + result[key] = kv.Value ?? ""; + } + return result; + } + internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) { return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs index e5dade2b..a3c9d600 100644 --- a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs +++ b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs @@ -14,6 +14,7 @@ // limitations under the License. using System; +using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -136,5 +137,12 @@ public RecordBatch ReadNextRecordBatch() { return _implementation.ReadNextRecordBatch(); } + + /// + /// Custom metadata from the most recently read RecordBatch Message. + /// Updated after each call to ReadNextRecordBatch/ReadNextRecordBatchAsync. + /// Returns null if the last batch had no custom metadata. + /// + public IReadOnlyDictionary LastBatchCustomMetadata => _implementation.LastBatchCustomMetadata; } } diff --git a/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index c5eed3b3..6017f485 100644 --- a/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -693,6 +693,11 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp } private protected void WriteRecordBatchInternal(RecordBatch recordBatch) + { + WriteRecordBatchInternal(recordBatch, customMetadata: null); + } + + private protected void WriteRecordBatchInternal(RecordBatch recordBatch, IReadOnlyDictionary customMetadata) { // TODO: Truncate buffers with extraneous padding / unused capacity @@ -714,6 +719,14 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) VectorOffset buffersVectorOffset = Builder.EndVector(); + // Build custom metadata for the Message if provided + VectorOffset customMetadataVectorOffset = default; + if (customMetadata != null && customMetadata.Count > 0) + { + Offset[] metadataOffsets = GetMetadataOffsets(customMetadata); + customMetadataVectorOffset = Flatbuf.Message.CreateCustomMetadataVector(Builder, metadataOffsets); + } + // Serialize record batch StartingWritingRecordBatch(); @@ -725,14 +738,21 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) variadicCountsOffset); long metadataLength = WriteMessage(Flatbuf.MessageHeader.RecordBatch, - recordBatchOffset, recordBatchBuilder.TotalLength); + recordBatchOffset, recordBatchBuilder.TotalLength, customMetadataVectorOffset); long bufferLength = WriteBufferData(recordBatchBuilder.Buffers); FinishedWritingRecordBatch(bufferLength, metadataLength); } + private protected Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, + CancellationToken cancellationToken = default) + { + return WriteRecordBatchInternalAsync(recordBatch, customMetadata: null, cancellationToken); + } + private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, + IReadOnlyDictionary customMetadata, CancellationToken cancellationToken = default) { if (!HasWrittenSchema) @@ -753,6 +773,14 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat VectorOffset buffersVectorOffset = Builder.EndVector(); + // Build custom metadata for the Message if provided + VectorOffset customMetadataVectorOffset = default; + if (customMetadata != null && customMetadata.Count > 0) + { + Offset[] metadataOffsets = GetMetadataOffsets(customMetadata); + customMetadataVectorOffset = Flatbuf.Message.CreateCustomMetadataVector(Builder, metadataOffsets); + } + // Serialize record batch StartingWritingRecordBatch(); @@ -765,6 +793,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat long metadataLength = await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch, recordBatchOffset, recordBatchBuilder.TotalLength, + customMetadataVectorOffset, cancellationToken).ConfigureAwait(false); long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); @@ -1011,11 +1040,21 @@ public virtual void WriteRecordBatch(RecordBatch recordBatch) WriteRecordBatchInternal(recordBatch); } + public virtual void WriteRecordBatch(RecordBatch recordBatch, IReadOnlyDictionary customMetadata) + { + WriteRecordBatchInternal(recordBatch, customMetadata); + } + public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) { return WriteRecordBatchInternalAsync(recordBatch, cancellationToken); } + public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, IReadOnlyDictionary customMetadata, CancellationToken cancellationToken = default) + { + return WriteRecordBatchInternalAsync(recordBatch, customMetadata, cancellationToken); + } + public void WriteStart() { if (!HasWrittenStart) @@ -1226,12 +1265,13 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat /// The number of bytes written to the stream. /// private protected long WriteMessage( - Flatbuf.MessageHeader headerType, Offset headerOffset, int bodyLength) + Flatbuf.MessageHeader headerType, Offset headerOffset, int bodyLength, + VectorOffset customMetadataOffset = default) where T : struct { Offset messageOffset = Flatbuf.Message.CreateMessage( Builder, CurrentMetadataVersion, headerType, headerOffset.Value, - bodyLength); + bodyLength, customMetadataOffset); Builder.Finish(messageOffset.Value); @@ -1255,14 +1295,23 @@ private protected long WriteMessage( /// /// The number of bytes written to the stream. /// + private protected virtual ValueTask WriteMessageAsync( + Flatbuf.MessageHeader headerType, Offset headerOffset, int bodyLength, + CancellationToken cancellationToken) + where T : struct + { + return WriteMessageAsync(headerType, headerOffset, bodyLength, default, cancellationToken); + } + private protected virtual async ValueTask WriteMessageAsync( Flatbuf.MessageHeader headerType, Offset headerOffset, int bodyLength, + VectorOffset customMetadataOffset, CancellationToken cancellationToken) where T : struct { Offset messageOffset = Flatbuf.Message.CreateMessage( Builder, CurrentMetadataVersion, headerType, headerOffset.Value, - bodyLength); + bodyLength, customMetadataOffset); Builder.Finish(messageOffset.Value); diff --git a/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 42e69dcc..b672e94f 100644 --- a/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -736,5 +736,164 @@ public async Task MemoryOwnerDisposalSlicedArray(int sliceOffset, int sliceLengt Assert.True(allocator.Statistics.Allocations > 0); Assert.Equal(0, allocator.Rented); } + + [Fact] + public void WriteCustomMetadata_RoundTrips() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 10); + var customMetadata = new Dictionary + { + ["rpc.method"] = "add", + ["rpc.version"] = "1", + ["request_id"] = "abc-123", + }; + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(originalBatch, customMetadata); + writer.WriteEnd(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + RecordBatch readBatch = reader.ReadNextRecordBatch(); + Assert.NotNull(readBatch); + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + + var readMetadata = reader.LastBatchCustomMetadata; + Assert.NotNull(readMetadata); + Assert.Equal(3, readMetadata.Count); + Assert.Equal("add", readMetadata["rpc.method"]); + Assert.Equal("1", readMetadata["rpc.version"]); + Assert.Equal("abc-123", readMetadata["request_id"]); + } + + [Fact] + public async Task WriteCustomMetadataAsync_RoundTrips() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 10); + var customMetadata = new Dictionary + { + ["key1"] = "value1", + ["key2"] = "value2", + }; + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true)) + { + await writer.WriteRecordBatchAsync(originalBatch, customMetadata); + await writer.WriteEndAsync(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + RecordBatch readBatch = reader.ReadNextRecordBatch(); + Assert.NotNull(readBatch); + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + + Assert.NotNull(reader.LastBatchCustomMetadata); + Assert.Equal("value1", reader.LastBatchCustomMetadata["key1"]); + Assert.Equal("value2", reader.LastBatchCustomMetadata["key2"]); + } + + [Fact] + public void WriteCustomMetadata_MultipleBatches_EachHasOwnMetadata() + { + RecordBatch batch = TestData.CreateSampleRecordBatch(length: 5); + var meta1 = new Dictionary { ["batch"] = "first" }; + var meta2 = new Dictionary { ["batch"] = "second", ["extra"] = "data" }; + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch, meta1); + writer.WriteRecordBatch(batch, meta2); + writer.WriteEnd(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + + reader.ReadNextRecordBatch(); + Assert.NotNull(reader.LastBatchCustomMetadata); + Assert.Single(reader.LastBatchCustomMetadata); + Assert.Equal("first", reader.LastBatchCustomMetadata["batch"]); + + reader.ReadNextRecordBatch(); + Assert.NotNull(reader.LastBatchCustomMetadata); + Assert.Equal(2, reader.LastBatchCustomMetadata.Count); + Assert.Equal("second", reader.LastBatchCustomMetadata["batch"]); + Assert.Equal("data", reader.LastBatchCustomMetadata["extra"]); + } + + [Fact] + public void WriteWithoutCustomMetadata_LastBatchCustomMetadataIsNull() + { + RecordBatch batch = TestData.CreateSampleRecordBatch(length: 5); + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch); + writer.WriteEnd(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + reader.ReadNextRecordBatch(); + Assert.Null(reader.LastBatchCustomMetadata); + } + + [Fact] + public void WriteCustomMetadata_MixedBatches_WithAndWithoutMetadata() + { + RecordBatch batch = TestData.CreateSampleRecordBatch(length: 5); + var meta = new Dictionary { ["key"] = "value" }; + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch, meta); + writer.WriteRecordBatch(batch); // no metadata + writer.WriteEnd(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + + reader.ReadNextRecordBatch(); + Assert.NotNull(reader.LastBatchCustomMetadata); + Assert.Equal("value", reader.LastBatchCustomMetadata["key"]); + + reader.ReadNextRecordBatch(); + Assert.Null(reader.LastBatchCustomMetadata); + } + + [Fact] + public void WriteCustomMetadata_EmptyValues_RoundTrips() + { + RecordBatch batch = TestData.CreateSampleRecordBatch(length: 5); + var meta = new Dictionary { ["empty"] = "" }; + + using var stream = new MemoryStream(); + using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch, meta); + writer.WriteEnd(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + reader.ReadNextRecordBatch(); + Assert.NotNull(reader.LastBatchCustomMetadata); + Assert.Equal("", reader.LastBatchCustomMetadata["empty"]); + } } } diff --git a/test/Apache.Arrow.Tests/CustomMetadataPythonTests.cs b/test/Apache.Arrow.Tests/CustomMetadataPythonTests.cs new file mode 100644 index 00000000..56f68f32 --- /dev/null +++ b/test/Apache.Arrow.Tests/CustomMetadataPythonTests.cs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.IO; +using Apache.Arrow.Ipc; +using Python.Runtime; +using Xunit; + +namespace Apache.Arrow.Tests +{ + + // ------------------------------------------------------------------- + // Cross-language Python tests for custom_metadata + // ------------------------------------------------------------------- + + public class CustomMetadataPythonTests : IClassFixture + { + public class PythonNet : IDisposable + { + public bool Initialized { get; } + + public bool VersionMismatch { get; } + + public PythonNet() + { + bool pythonSet = Environment.GetEnvironmentVariable("PYTHONNET_PYDLL") != null; + if (!pythonSet) + { + Initialized = false; + return; + } + + try + { + PythonEngine.Initialize(); + } + catch (NotSupportedException e) when (e.Message.Contains("Python ABI ") && e.Message.Contains("not supported")) + { + Initialized = false; + VersionMismatch = true; + return; + } + + if (System.Runtime.InteropServices.RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows) && + PythonEngine.PythonPath.IndexOf("dlls", StringComparison.OrdinalIgnoreCase) < 0) + { + dynamic sys = Py.Import("sys"); + sys.path.append(Path.Combine(Path.GetDirectoryName(Environment.GetEnvironmentVariable("PYTHONNET_PYDLL")), "DLLs")); + } + + Initialized = true; + } + + public void Dispose() + { + PythonEngine.Shutdown(); + } + } + + public CustomMetadataPythonTests(PythonNet pythonNet) + { + if (!pythonNet.Initialized) + { + var errorReason = pythonNet.VersionMismatch ? "Python version is incompatible with PythonNet" : "PYTHONNET_PYDLL not set"; + + bool inCIJob = Environment.GetEnvironmentVariable("GITHUB_ACTIONS") == "true"; + bool inVerificationJob = Environment.GetEnvironmentVariable("TEST_CSHARP") == "1"; + + Skip.If(inVerificationJob || !inCIJob, $"{errorReason}; skipping custom metadata Python tests."); + + throw new Exception($"{errorReason}; cannot run custom metadata Python tests."); + } + } + + // ------------------------------------------------------------------- + // C# writes IPC with custom_metadata → Python reads + // ------------------------------------------------------------------- + + [SkippableFact] + public void ExportCustomMetadata_PythonReads() + { + RecordBatch batch = TestData.CreateSampleRecordBatch(length: 5); + var batchMetadata = new Dictionary + { + ["rpc.method"] = "greet", + ["request_id"] = "abc-123", + ["custom_key"] = "custom_value", + }; + + // Serialize to IPC stream with custom batch metadata + byte[] ipcBytes; + using (var ms = new MemoryStream()) + { + using (var writer = new ArrowStreamWriter(ms, batch.Schema, leaveOpen: true)) + { + writer.WriteRecordBatch(batch, batchMetadata); + writer.WriteEnd(); + } + ipcBytes = ms.ToArray(); + } + + // Python reads and verifies custom_metadata + using (Py.GIL()) + { + dynamic pa = Py.Import("pyarrow"); + dynamic reader = pa.ipc.open_stream(pa.BufferReader(ipcBytes.ToPython())); + + PyObject result = reader.read_next_batch_with_custom_metadata(); + dynamic pyBatch = result[0]; + dynamic customMeta = result[1]; + + // Verify batch data round-tripped + Assert.Equal(5, (int)pyBatch.num_rows); + + // Verify custom_metadata (pyarrow returns bytes — decode to str) + Assert.Equal("greet", (string)customMeta["rpc.method"].decode()); + Assert.Equal("abc-123", (string)customMeta["request_id"].decode()); + Assert.Equal("custom_value", (string)customMeta["custom_key"].decode()); + } + } + + // ------------------------------------------------------------------- + // Python writes IPC with custom_metadata → C# reads + // ------------------------------------------------------------------- + + [SkippableFact] + public void ImportCustomMetadata_PythonWrites() + { + byte[] ipcBytes; + + // Python creates a batch with custom_metadata and serializes to IPC + using (Py.GIL()) + { + dynamic pa = Py.Import("pyarrow"); + dynamic io = Py.Import("io"); + + dynamic pyBatch = pa.record_batch(new PyList(new PyObject[] + { + pa.array(new int[] { 1, 2, 3, 4, 5 }), + }), new[] { "x" }); + + dynamic buf = io.BytesIO(); + dynamic writer = pa.ipc.new_stream(buf, pyBatch.schema); + dynamic customMeta = pa.KeyValueMetadata(new PyDict + { + ["origin"] = "python".ToPython(), + ["version"] = "2".ToPython(), + }); + writer.write_batch(pyBatch, custom_metadata: customMeta); + writer.close(); + + ipcBytes = ((PyObject)buf.getvalue()).As(); + } + + // C# reads and verifies custom_metadata + using var ms = new MemoryStream(ipcBytes); + using var reader = new ArrowStreamReader(ms); + + RecordBatch batch = reader.ReadNextRecordBatch(); + Assert.NotNull(batch); + Assert.Equal(5, batch.Length); + + var metadata = reader.LastBatchCustomMetadata; + Assert.NotNull(metadata); + Assert.Equal("python", metadata["origin"]); + Assert.Equal("2", metadata["version"]); + } + } +}