diff --git a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h index 4b60207a036..a83a76efb2c 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h @@ -90,6 +90,8 @@ typedef struct { const idx_t INVALID_IDX = UINT64_MAX; +extern void vortex_wait_and_fetch(void *task_ptr); + typedef struct { idx_t partition_index; // Either INVALID_IDX or position of column in output for file_index column @@ -116,7 +118,10 @@ typedef struct { duckdb_vx_data (*init_local)(void *init_global_data); - void (*function)(void *init_global_data, + // If chunk was exported, return nullptr. + // On error set error_out and return nullptr. + // If we're blocked on input, return task ptr. + void* (*function)(void *init_global_data, void *init_local_data, duckdb_data_chunk data_chunk_out, duckdb_vx_error *error_out); diff --git a/vortex-duckdb/cpp/table_function.cpp b/vortex-duckdb/cpp/table_function.cpp index 92d989d30d4..34614f153bd 100644 --- a/vortex-duckdb/cpp/table_function.cpp +++ b/vortex-duckdb/cpp/table_function.cpp @@ -14,6 +14,7 @@ DUCKDB_INCLUDES_BEGIN #include "duckdb/function/table_function.hpp" #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/connection.hpp" +#include "duckdb/parallel/async_result.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" DUCKDB_INCLUDES_END @@ -249,6 +250,16 @@ init_local(ExecutionContext &, TableFunctionInitInput &input, GlobalTableFunctio return make_uniq(std::move(cdata)); } +struct VortexWaitTask final : AsyncTask { + explicit VortexWaitTask(void *task_ptr) : task_ptr(task_ptr) { + } + void Execute() override { + vortex_wait_and_fetch(task_ptr); + } + + void *task_ptr; +}; + void function(ClientContext &, TableFunctionInput &input, DataChunk &output) { const auto &bind = input.bind_data->Cast(); @@ -257,10 +268,19 @@ void function(ClientContext &, TableFunctionInput &input, DataChunk &output) { duckdb_data_chunk chunk = reinterpret_cast(&output); duckdb_vx_error error_out = nullptr; - bind.info.vtab.function(ffi_global, ffi_local, chunk, &error_out); + + void *const task_ptr = bind.info.vtab.function(ffi_global, ffi_local, chunk, &error_out); if (error_out) { throw InvalidInputException(IntoErrString(error_out)); } + if (!task_ptr) { // Chunk was exported + return; + } + + // We're blocked on IO + vector> tasks(1); + tasks[0] = make_uniq(task_ptr); + input.async_result = AsyncResult(std::move(tasks)); } void c_pushdown_complex_filter(ClientContext &, diff --git a/vortex-duckdb/include/vortex.h b/vortex-duckdb/include/vortex.h index 7148c4d9799..2eedae23ae8 100644 --- a/vortex-duckdb/include/vortex.h +++ b/vortex-duckdb/include/vortex.h @@ -38,6 +38,8 @@ const char *vortex_version_rust(void); */ const char *vortex_extension_version_rust(void); +void vortex_wait_and_fetch(void *self_ptr); + #ifdef __cplusplus } #endif diff --git a/vortex-duckdb/src/datasource.rs b/vortex-duckdb/src/datasource.rs index a1fa15a1b40..43dce45ca38 100644 --- a/vortex-duckdb/src/datasource.rs +++ b/vortex-duckdb/src/datasource.rs @@ -8,9 +8,11 @@ //! pushdown, cardinality, and partitioning. use std::cmp::max; +use std::ffi::c_void; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; +use std::sync::Mutex; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; @@ -47,6 +49,7 @@ use vortex::file::v2::FileStatsLayoutReader; use vortex::io::kanal_ext::KanalExt; use vortex::io::runtime::BlockingRuntime; use vortex::io::runtime::current::ThreadSafeIterator; +use vortex::io::runtime::current::TryRecv; use vortex::layout::layouts::row_idx::row_idx; use vortex::layout::scan::multi::MultiLayoutChild; use vortex::layout::scan::multi::MultiLayoutDataSource; @@ -78,6 +81,7 @@ use crate::duckdb::DuckdbStringMapRef; use crate::duckdb::ExpressionRef; use crate::duckdb::LogicalType; use crate::duckdb::PartitionData; +use crate::duckdb::ScanResult; use crate::duckdb::TableFilterClass; use crate::duckdb::TableFilterSetRef; use crate::duckdb::TableFunction; @@ -158,7 +162,8 @@ impl Debug for DataSourceBindData { } } -type DataSourceIterator = ThreadSafeIterator)>>; +type DataSourceItem = VortexResult<(ArrayRef, Arc)>; +type DataSourceIterator = ThreadSafeIterator; /// Global scan state for driving a `DataSource` scan through DuckDB. pub struct DataSourceGlobal { @@ -170,12 +175,28 @@ pub struct DataSourceGlobal { file_row_number_column_pos: Option, } +struct WaitCtx { + results_rx: kanal::AsyncReceiver, + pending: Arc>>, +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn vortex_wait_and_fetch(self_ptr: *mut c_void) { + let ctx = unsafe { Box::from_raw(self_ptr.cast::()) }; + RUNTIME.block_on(async { + if let Ok(item) = ctx.results_rx.recv().await { + *ctx.pending.lock().unwrap() = Some(item); + } + }); +} + /// Per-thread local scan state. pub struct DataSourceLocal { iterator: DataSourceIterator, exporter: Option, partition_index: u64, file_index: usize, + pending: Arc>>, } /// Returns scan progress as a percentage (0.0–100.0). @@ -456,6 +477,7 @@ impl TableFunction for T { exporter: None, partition_index: 0, file_index: 0, + pending: Arc::new(Mutex::new(None)), } } @@ -463,12 +485,24 @@ impl TableFunction for T { local_state: &mut Self::LocalState, global_state: &Self::GlobalState, chunk: &mut DataChunkRef, - ) -> VortexResult<()> { + ) -> VortexResult { loop { if local_state.exporter.is_none() { let mut ctx = SESSION.create_execution_ctx(); - let Some(result) = local_state.iterator.next() else { - return Ok(()); + let result = if let Some(item) = local_state.pending.lock().unwrap().take() { + item + } else { + match local_state.iterator.try_recv() { + TryRecv::Item(item) => item, + TryRecv::Empty => { + let wait = Box::new(WaitCtx { + results_rx: local_state.iterator.receiver(), + pending: Arc::clone(&local_state.pending), + }); + return Ok(ScanResult::Blocked(Box::into_raw(wait).cast())); + } + TryRecv::Closed => return Ok(ScanResult::Exported), + } }; let (array_result, conversion_cache) = result?; let array_result = array_result.optimize_recursive(ctx.session())?; @@ -530,7 +564,7 @@ impl TableFunction for T { .reference_value(&Value::from(local_state.file_index as u64)); } - Ok(()) + Ok(ScanResult::Exported) } fn table_scan_progress(global_state: &Self::GlobalState) -> f64 { diff --git a/vortex-duckdb/src/duckdb/table_function/mod.rs b/vortex-duckdb/src/duckdb/table_function/mod.rs index 986ac64d100..523afb9fe20 100644 --- a/vortex-duckdb/src/duckdb/table_function/mod.rs +++ b/vortex-duckdb/src/duckdb/table_function/mod.rs @@ -35,6 +35,12 @@ pub struct PartitionData { pub file_index: usize, } +pub enum ScanResult { + Exported, + Blocked(*mut c_void), +} +unsafe impl Send for ScanResult {} + #[derive(Debug, Default)] pub struct ColumnStatistics { pub min: Option, @@ -81,12 +87,12 @@ pub trait TableFunction: Sized + Debug { /// registered as a VIEW. fn statistics(bind_data: &Self::BindData, column_index: usize) -> Option; - /// The function is called during query execution and is responsible for producing the output + /// The function is called during query execution and is responsible for producing the output. fn scan( init_local: &mut Self::LocalState, init_global: &Self::GlobalState, chunk: &mut DataChunkRef, - ) -> VortexResult<()>; + ) -> VortexResult; /// Initialize the global operator state of the function. /// @@ -241,7 +247,7 @@ unsafe extern "C-unwind" fn function( local_init_data: *mut c_void, output: cpp::duckdb_data_chunk, error_out: *mut cpp::duckdb_vx_error, -) { +) -> *mut c_void { let global_init_data = unsafe { global_init_data.cast::().as_ref() } .vortex_expect("global_init_data null pointer"); let local_init_data = unsafe { local_init_data.cast::().as_mut() } @@ -249,15 +255,14 @@ unsafe extern "C-unwind" fn function( let data_chunk = unsafe { DataChunk::borrow_mut(output) }; match T::scan(local_init_data, global_init_data, data_chunk) { - Ok(()) => { - // The data chunk is already filled by the function. - // No need to do anything here. + Ok(ScanResult::Exported) => ptr::null_mut(), + Ok(ScanResult::Blocked(task_ptr)) => task_ptr, + Err(e) => { + let msg = e.to_string(); + unsafe { + error_out.write(cpp::duckdb_vx_error_create(msg.as_ptr().cast(), msg.len())); + } + ptr::null_mut() } - Err(e) => unsafe { - error_out.write(cpp::duckdb_vx_error_create( - e.to_string().as_ptr().cast(), - e.to_string().len(), - )); - }, } } diff --git a/vortex-io/src/runtime/current.rs b/vortex-io/src/runtime/current.rs index c8ff91b4aec..51b26e0c823 100644 --- a/vortex-io/src/runtime/current.rs +++ b/vortex-io/src/runtime/current.rs @@ -134,6 +134,26 @@ pub struct ThreadSafeIterator { results: kanal::AsyncReceiver, } +pub enum TryRecv { + Item(T), + Empty, + Closed, +} + +impl ThreadSafeIterator { + pub fn receiver(&self) -> kanal::AsyncReceiver { + self.results.clone() + } + + pub fn try_recv(&self) -> TryRecv { + match self.results.try_recv() { + Ok(Some(v)) => TryRecv::Item(v), + Ok(None) => TryRecv::Empty, + Err(_) => TryRecv::Closed, + } + } +} + // Manual clone implementation since `T` does not need to be `Clone`. impl Clone for ThreadSafeIterator { fn clone(&self) -> Self {