Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ protected virtual void Dispose(bool disposing)
public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken);
public abstract RecordBatch ReadNextRecordBatch();

/// <summary>
/// Custom metadata from the most recently read RecordBatch Message, if any.
/// Assigned each time a RecordBatch message is converted into a <see cref="RecordBatch"/>;
/// <c>null</c> when that message carried no custom metadata entries.
/// </summary>
internal IReadOnlyDictionary<string, string> LastBatchCustomMetadata { get; private protected set; }

internal static T ReadMessage<T>(ByteBuffer bb)
where T : struct, IFlatbufferObject
{
Expand Down Expand Up @@ -142,6 +147,7 @@ protected RecordBatch CreateArrowObjectFromMessage(
case Flatbuf.MessageHeader.RecordBatch:
Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb);
LastBatchCustomMetadata = ReadMessageCustomMetadata(message);
return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
default:
// NOTE: Skip unsupported message type
Expand All @@ -152,6 +158,23 @@ protected RecordBatch CreateArrowObjectFromMessage(
return null;
}

/// <summary>
/// Copies the custom metadata key/value entries of a flatbuffer Message into a dictionary.
/// </summary>
/// <param name="message">The flatbuffer Message whose custom metadata is read.</param>
/// <returns>
/// <c>null</c> when the message has zero custom metadata entries; otherwise a dictionary
/// holding every entry that has a non-null key. A null value is stored as an empty string,
/// and a repeated key keeps the value of its last occurrence.
/// </returns>
private static IReadOnlyDictionary<string, string> ReadMessageCustomMetadata(Flatbuf.Message message)
{
    int entryCount = message.CustomMetadataLength;
    if (entryCount != 0)
    {
        var metadata = new Dictionary<string, string>(entryCount);
        for (int index = 0; index < entryCount; index++)
        {
            Flatbuf.KeyValue entry = message.CustomMetadata(index).GetValueOrDefault();
            string entryKey = entry.Key;
            if (entryKey == null)
            {
                // Entries without a key cannot be stored in the dictionary; skip them.
                continue;
            }

            string entryValue = entry.Value;
            metadata[entryKey] = entryValue ?? "";
        }

        return metadata;
    }

    return null;
}

internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer)
{
return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0);
Expand Down
8 changes: 8 additions & 0 deletions src/Apache.Arrow/Ipc/ArrowStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// limitations under the License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -136,5 +137,12 @@ public RecordBatch ReadNextRecordBatch()
{
return _implementation.ReadNextRecordBatch();
}

/// <summary>
/// Gets the custom metadata attached to the RecordBatch Message that was read most recently.
/// The value is refreshed after each call to ReadNextRecordBatch/ReadNextRecordBatchAsync
/// and is null if the last batch had no custom metadata.
/// </summary>
public IReadOnlyDictionary<string, string> LastBatchCustomMetadata
{
    get { return _implementation.LastBatchCustomMetadata; }
}
}
}
57 changes: 53 additions & 4 deletions src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,11 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp
}

// Convenience overload for a batch without custom metadata: forwards to the
// metadata-aware overload with a null metadata dictionary.
private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
    => WriteRecordBatchInternal(recordBatch, customMetadata: null);

private protected void WriteRecordBatchInternal(RecordBatch recordBatch, IReadOnlyDictionary<string, string> customMetadata)
{
// TODO: Truncate buffers with extraneous padding / unused capacity

Expand All @@ -714,6 +719,14 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)

VectorOffset buffersVectorOffset = Builder.EndVector();

// Build custom metadata for the Message if provided
VectorOffset customMetadataVectorOffset = default;
if (customMetadata != null && customMetadata.Count > 0)
{
Offset<Flatbuf.KeyValue>[] metadataOffsets = GetMetadataOffsets(customMetadata);
customMetadataVectorOffset = Flatbuf.Message.CreateCustomMetadataVector(Builder, metadataOffsets);
}

// Serialize record batch

StartingWritingRecordBatch();
Expand All @@ -725,14 +738,21 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
variadicCountsOffset);

long metadataLength = WriteMessage(Flatbuf.MessageHeader.RecordBatch,
recordBatchOffset, recordBatchBuilder.TotalLength);
recordBatchOffset, recordBatchBuilder.TotalLength, customMetadataVectorOffset);

long bufferLength = WriteBufferData(recordBatchBuilder.Buffers);

FinishedWritingRecordBatch(bufferLength, metadataLength);
}

// Convenience overload for a batch without custom metadata: returns the task produced by
// the metadata-aware overload, invoked with a null metadata dictionary.
private protected Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
    CancellationToken cancellationToken = default)
    => WriteRecordBatchInternalAsync(recordBatch, customMetadata: null, cancellationToken: cancellationToken);

private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
IReadOnlyDictionary<string, string> customMetadata,
CancellationToken cancellationToken = default)
{
if (!HasWrittenSchema)
Expand All @@ -753,6 +773,14 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat

VectorOffset buffersVectorOffset = Builder.EndVector();

// Build custom metadata for the Message if provided
VectorOffset customMetadataVectorOffset = default;
if (customMetadata != null && customMetadata.Count > 0)
{
Offset<Flatbuf.KeyValue>[] metadataOffsets = GetMetadataOffsets(customMetadata);
customMetadataVectorOffset = Flatbuf.Message.CreateCustomMetadataVector(Builder, metadataOffsets);
}

// Serialize record batch

StartingWritingRecordBatch();
Expand All @@ -765,6 +793,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat

long metadataLength = await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch,
recordBatchOffset, recordBatchBuilder.TotalLength,
customMetadataVectorOffset,
cancellationToken).ConfigureAwait(false);

long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -1011,11 +1040,21 @@ public virtual void WriteRecordBatch(RecordBatch recordBatch)
WriteRecordBatchInternal(recordBatch);
}

/// <summary>
/// Writes a record batch together with custom metadata for its Message.
/// </summary>
/// <param name="recordBatch">The record batch to write.</param>
/// <param name="customMetadata">
/// Key/value pairs to attach to the batch's Message; handed unchanged to the internal writer.
/// </param>
public virtual void WriteRecordBatch(RecordBatch recordBatch, IReadOnlyDictionary<string, string> customMetadata)
    => WriteRecordBatchInternal(recordBatch, customMetadata);

/// <summary>
/// Asynchronously writes a record batch without custom metadata.
/// </summary>
/// <param name="recordBatch">The record batch to write.</param>
/// <param name="cancellationToken">Token forwarded to the internal asynchronous writer.</param>
/// <returns>The task returned by the internal asynchronous writer.</returns>
public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
    => WriteRecordBatchInternalAsync(recordBatch, cancellationToken);

/// <summary>
/// Asynchronously writes a record batch together with custom metadata for its Message.
/// </summary>
/// <param name="recordBatch">The record batch to write.</param>
/// <param name="customMetadata">
/// Key/value pairs to attach to the batch's Message; handed unchanged to the internal writer.
/// </param>
/// <param name="cancellationToken">Token forwarded to the internal asynchronous writer.</param>
/// <returns>The task returned by the internal asynchronous writer.</returns>
public virtual Task WriteRecordBatchAsync(
    RecordBatch recordBatch,
    IReadOnlyDictionary<string, string> customMetadata,
    CancellationToken cancellationToken = default)
    => WriteRecordBatchInternalAsync(recordBatch, customMetadata, cancellationToken);

public void WriteStart()
{
if (!HasWrittenStart)
Expand Down Expand Up @@ -1226,12 +1265,13 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat
/// The number of bytes written to the stream.
/// </returns>
private protected long WriteMessage<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength)
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
VectorOffset customMetadataOffset = default)
where T : struct
{
Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
bodyLength);
bodyLength, customMetadataOffset);

Builder.Finish(messageOffset.Value);

Expand All @@ -1255,14 +1295,23 @@ private protected long WriteMessage<T>(
/// <returns>
/// The number of bytes written to the stream.
/// </returns>
private protected virtual ValueTask<long> WriteMessageAsync<T>(
    Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
    CancellationToken cancellationToken)
    where T : struct
{
    // This overload carries no custom metadata, so a default VectorOffset is forwarded
    // to the metadata-aware overload.
    VectorOffset noCustomMetadata = default;
    return WriteMessageAsync(headerType, headerOffset, bodyLength, noCustomMetadata, cancellationToken);
}

private protected virtual async ValueTask<long> WriteMessageAsync<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
VectorOffset customMetadataOffset,
CancellationToken cancellationToken)
where T : struct
{
Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
bodyLength);
bodyLength, customMetadataOffset);

Builder.Finish(messageOffset.Value);

Expand Down
159 changes: 159 additions & 0 deletions test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -736,5 +736,164 @@ public async Task MemoryOwnerDisposalSlicedArray(int sliceOffset, int sliceLengt
Assert.True(allocator.Statistics.Allocations > 0);
Assert.Equal(0, allocator.Rented);
}

[Fact]
public void WriteCustomMetadata_RoundTrips()
{
    // Arrange: a sample batch and three metadata entries to attach to its Message.
    RecordBatch expectedBatch = TestData.CreateSampleRecordBatch(length: 10);
    var expectedMetadata = new Dictionary<string, string>
    {
        { "rpc.method", "add" },
        { "rpc.version", "1" },
        { "request_id", "abc-123" },
    };

    using var buffer = new MemoryStream();

    // Act: write the batch with its metadata, rewind, and read it back.
    using (var writer = new ArrowStreamWriter(buffer, expectedBatch.Schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(expectedBatch, expectedMetadata);
        writer.WriteEnd();
    }

    buffer.Position = 0;

    using var reader = new ArrowStreamReader(buffer);
    RecordBatch actualBatch = reader.ReadNextRecordBatch();

    // Assert: the batch contents and every metadata entry survive the round trip.
    Assert.NotNull(actualBatch);
    ArrowReaderVerifier.CompareBatches(expectedBatch, actualBatch);

    IReadOnlyDictionary<string, string> actualMetadata = reader.LastBatchCustomMetadata;
    Assert.NotNull(actualMetadata);
    Assert.Equal(3, actualMetadata.Count);
    Assert.Equal("add", actualMetadata["rpc.method"]);
    Assert.Equal("1", actualMetadata["rpc.version"]);
    Assert.Equal("abc-123", actualMetadata["request_id"]);
}

[Fact]
public async Task WriteCustomMetadataAsync_RoundTrips()
{
    // Arrange: a sample batch and two metadata entries to attach to its Message.
    RecordBatch expectedBatch = TestData.CreateSampleRecordBatch(length: 10);
    var expectedMetadata = new Dictionary<string, string>
    {
        { "key1", "value1" },
        { "key2", "value2" },
    };

    using var buffer = new MemoryStream();

    // Act: write through the asynchronous API, then rewind and read the batch back.
    using (var writer = new ArrowStreamWriter(buffer, expectedBatch.Schema, leaveOpen: true))
    {
        await writer.WriteRecordBatchAsync(expectedBatch, expectedMetadata);
        await writer.WriteEndAsync();
    }

    buffer.Position = 0;

    using var reader = new ArrowStreamReader(buffer);
    RecordBatch actualBatch = reader.ReadNextRecordBatch();

    // Assert: the batch and both metadata entries survive the round trip.
    Assert.NotNull(actualBatch);
    ArrowReaderVerifier.CompareBatches(expectedBatch, actualBatch);

    Assert.NotNull(reader.LastBatchCustomMetadata);
    Assert.Equal("value1", reader.LastBatchCustomMetadata["key1"]);
    Assert.Equal("value2", reader.LastBatchCustomMetadata["key2"]);
}

[Fact]
public void WriteCustomMetadata_MultipleBatches_EachHasOwnMetadata()
{
    // Arrange: one batch written twice, each time with a different metadata dictionary.
    RecordBatch sampleBatch = TestData.CreateSampleRecordBatch(length: 5);
    var firstMetadata = new Dictionary<string, string> { { "batch", "first" } };
    var secondMetadata = new Dictionary<string, string>
    {
        { "batch", "second" },
        { "extra", "data" },
    };

    using var buffer = new MemoryStream();
    using (var writer = new ArrowStreamWriter(buffer, sampleBatch.Schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(sampleBatch, firstMetadata);
        writer.WriteRecordBatch(sampleBatch, secondMetadata);
        writer.WriteEnd();
    }

    buffer.Position = 0;

    using var reader = new ArrowStreamReader(buffer);

    // First read: only the first batch's metadata is exposed.
    reader.ReadNextRecordBatch();
    Assert.NotNull(reader.LastBatchCustomMetadata);
    Assert.Single(reader.LastBatchCustomMetadata);
    Assert.Equal("first", reader.LastBatchCustomMetadata["batch"]);

    // Second read: the property now reflects the second batch's metadata.
    reader.ReadNextRecordBatch();
    Assert.NotNull(reader.LastBatchCustomMetadata);
    Assert.Equal(2, reader.LastBatchCustomMetadata.Count);
    Assert.Equal("second", reader.LastBatchCustomMetadata["batch"]);
    Assert.Equal("data", reader.LastBatchCustomMetadata["extra"]);
}

[Fact]
public void WriteWithoutCustomMetadata_LastBatchCustomMetadataIsNull()
{
    // Arrange: write a single batch through the overload that takes no metadata.
    RecordBatch sampleBatch = TestData.CreateSampleRecordBatch(length: 5);

    using var buffer = new MemoryStream();
    using (var writer = new ArrowStreamWriter(buffer, sampleBatch.Schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(sampleBatch);
        writer.WriteEnd();
    }

    buffer.Position = 0;

    // Act
    using var reader = new ArrowStreamReader(buffer);
    reader.ReadNextRecordBatch();

    // Assert: with nothing written, the reader reports null metadata.
    Assert.Null(reader.LastBatchCustomMetadata);
}

[Fact]
public void WriteCustomMetadata_MixedBatches_WithAndWithoutMetadata()
{
    // Arrange: the first batch carries metadata, the second is written without any.
    RecordBatch sampleBatch = TestData.CreateSampleRecordBatch(length: 5);
    var metadata = new Dictionary<string, string> { { "key", "value" } };

    using var buffer = new MemoryStream();
    using (var writer = new ArrowStreamWriter(buffer, sampleBatch.Schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(sampleBatch, metadata);
        writer.WriteRecordBatch(sampleBatch);
        writer.WriteEnd();
    }

    buffer.Position = 0;

    using var reader = new ArrowStreamReader(buffer);

    // First read: the metadata written with the first batch is exposed.
    reader.ReadNextRecordBatch();
    Assert.NotNull(reader.LastBatchCustomMetadata);
    Assert.Equal("value", reader.LastBatchCustomMetadata["key"]);

    // Second read: the property goes back to null for the metadata-free batch.
    reader.ReadNextRecordBatch();
    Assert.Null(reader.LastBatchCustomMetadata);
}

[Fact]
public void WriteCustomMetadata_EmptyValues_RoundTrips()
{
    // Arrange: a single metadata entry whose value is the empty string.
    RecordBatch sampleBatch = TestData.CreateSampleRecordBatch(length: 5);
    var metadata = new Dictionary<string, string> { { "empty", "" } };

    using var buffer = new MemoryStream();
    using (var writer = new ArrowStreamWriter(buffer, sampleBatch.Schema, leaveOpen: true))
    {
        writer.WriteRecordBatch(sampleBatch, metadata);
        writer.WriteEnd();
    }

    buffer.Position = 0;

    // Act
    using var reader = new ArrowStreamReader(buffer);
    reader.ReadNextRecordBatch();

    // Assert: the key is present and its value is still the empty string.
    Assert.NotNull(reader.LastBatchCustomMetadata);
    Assert.Equal("", reader.LastBatchCustomMetadata["empty"]);
}
}
}
Loading
Loading