Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2c0bd23
feat: add oom-guard feature and accounting helpers
andygrove Jun 3, 2026
d5f1355
refactor: feature-gate oom_guard module and harden should_trip
andygrove Jun 3, 2026
97d9124
feat: oom-guard process globals, arming, and enforcement
andygrove Jun 3, 2026
470c981
feat: wrap global allocator with AccountingAllocator under oom-guard
andygrove Jun 3, 2026
2439345
docs: document oom-guard allocator reentrancy and cast safety
andygrove Jun 3, 2026
32c47be
feat: stamp query-worker threads and arm oom-guard from config
andygrove Jun 3, 2026
1b06d94
refactor: name memoryGuard config keys and warn on zero limit
andygrove Jun 3, 2026
2ed112e
feat: map OomGuardPanic to ResourcesExhausted at JNI boundaries
andygrove Jun 3, 2026
f2be0c0
refactor: share oom-guard error message and harden panic recovery
andygrove Jun 3, 2026
6669c21
feat: register spark.comet.exec.memoryGuard configs
andygrove Jun 3, 2026
0a7d686
docs: clarify memoryGuard config descriptions
andygrove Jun 3, 2026
ffb52f2
test: oom-guard trips on a real over-budget allocation
andygrove Jun 3, 2026
adef2b2
fix: map OomGuardPanic on the spawned-path consumer thread
andygrove Jun 3, 2026
89c782d
refactor: consolidate oom-guard panic mapping and simplify allocator cfg
andygrove Jun 3, 2026
423cfce
fix: avoid realloc use-after-free and serialize guard panics in OomGu…
andygrove Jun 3, 2026
6b3411f
feat: add RealUsagePool decorator gating growth on real allocator usage
andygrove Jun 20, 2026
eb03dfe
feat: gate memory pool on real usage when memory guard is enabled
andygrove Jun 20, 2026
cba97f0
test: real-usage gate trips on a real over-budget allocation
andygrove Jun 20, 2026
0d6f225
docs: clarify cooperative gate vs OomGuard breaker layering and ordering
andygrove Jun 20, 2026
0ef4ec0
refactor: gate before delegating in RealUsagePool and align with pool…
andygrove Jun 20, 2026
6efc722
feat: add fair-share guard to RealUsagePool real-usage gate
andygrove Jun 20, 2026
ee115a8
feat: track active task count for fair-share divisor
andygrove Jun 20, 2026
12e4d58
refactor: drop redundant dead-code guard on record_task_started
andygrove Jun 20, 2026
3684851
docs: surface fair_share in Debug and note fallback divisor for non-s…
andygrove Jun 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions native/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ hdfs = ["datafusion-comet-objectstore-hdfs"]
hdfs-opendal = ["opendal", "object_store_opendal", "hdfs-sys"]
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]

# Allocator-level OOM circuit breaker. When enabled, the global allocator is
# wrapped to track real allocated bytes and panic an over-budget query-worker
# thread (caught at the task boundary). Off by default; zero overhead when off.
oom-guard = []

# exclude optional packages from cargo machete verifications
[package.metadata.cargo-machete]
ignored = ["hdfs-sys", "paste"]
Expand Down
230 changes: 171 additions & 59 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ use crate::execution::spark_config::{
SparkConfig, COMET_DEBUG_ENABLED, COMET_DEBUG_MEMORY, COMET_EXPLAIN_NATIVE_ENABLED,
COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED, SPARK_EXECUTOR_CORES,
};
#[cfg(feature = "oom-guard")]
use crate::execution::spark_config::{COMET_MEMORY_GUARD_ENABLED, COMET_MEMORY_GUARD_SIZE};
use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
use datafusion_comet_proto::spark_operator::operator::OpStruct;
use log::info;
#[cfg(feature = "oom-guard")]
use log::warn;
use std::sync::OnceLock;
#[cfg(feature = "jemalloc")]
use tikv_jemalloc_ctl::{epoch, stats};
Expand Down Expand Up @@ -192,6 +196,8 @@ fn parse_usize_env_var(name: &str) -> Option<usize> {

fn build_runtime(default_worker_threads: Option<usize>) -> Runtime {
let mut builder = tokio::runtime::Builder::new_multi_thread();
#[cfg(feature = "oom-guard")]
builder.on_thread_start(crate::execution::memory_pools::oom_guard::stamp_current_thread);
if let Some(n) = parse_usize_env_var("COMET_WORKER_THREADS") {
info!("Comet tokio runtime: using COMET_WORKER_THREADS={n}");
builder.worker_threads(n);
Expand Down Expand Up @@ -369,6 +375,31 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
spark_config.get_u64(COMET_MAX_TEMP_DIRECTORY_SIZE, 100 * 1024 * 1024 * 1024);
let logging_memory_pool = spark_config.get_bool(COMET_DEBUG_MEMORY);

#[cfg(feature = "oom-guard")]
{
if spark_config.get_bool(COMET_MEMORY_GUARD_ENABLED) {
// This is the hard, last-resort breaker: a panic on any over-budget
// allocation. It defaults to the same off-heap budget as the cooperative
// real-usage gate below, but the two layers still order correctly because
// the cooperative gate trips on *projected* usage (`balance + additional`)
// while this breaker trips on *actual* usage (`balance`), so cooperative
// spilling is attempted before the breaker fires. Set
// `spark.comet.exec.memoryGuard.size` explicitly above the off-heap budget
// to give the breaker additional headroom (e.g. up to the container RSS
// limit) for a wider spill-before-fail margin.
let default_limit = memory_limit.max(0) as u64;
let limit = spark_config.get_u64(COMET_MEMORY_GUARD_SIZE, default_limit);
if limit == 0 {
warn!(
"spark.comet.exec.memoryGuard.enabled is true but the effective limit \
is 0 (memory_limit={memory_limit}); the guard will not trip. Set \
spark.comet.exec.memoryGuard.size explicitly."
);
}
crate::execution::memory_pools::oom_guard::arm(limit as usize);
}
}

with_trace("createPlan", tracing_enabled, || {
// Init JVM classes
JVMClasses::init(env);
Expand Down Expand Up @@ -404,6 +435,30 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
let memory_pool =
create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);

// Cooperative real-usage gate: when the memory guard is enabled, wrap the
// pool so growth is rejected (triggering a spill) once real allocator usage
// plus the request would exceed the process-wide off-heap budget. This is the
// first line of defense and fires before the hard OomGuard breaker armed above
// (projected vs. actual usage; see that comment), so over-budget work spills
// and retries rather than failing the task.
#[cfg(feature = "oom-guard")]
let memory_pool = if spark_config.get_bool(COMET_MEMORY_GUARD_ENABLED) {
let ceiling = memory_limit.max(0) as usize;
// Enable the fair-share guard for pools whose `reserved()` is per-task;
// `executor_cores` is the fallback divisor when no task count is known.
let fair_share = memory_pool_config
.pool_type
.has_per_task_budget()
.then_some(executor_cores);
Arc::new(crate::execution::memory_pools::RealUsagePool::new(
memory_pool,
ceiling,
fair_share,
)) as Arc<dyn datafusion::execution::memory_pool::MemoryPool>
} else {
memory_pool
};

let memory_pool = if logging_memory_pool {
Arc::new(LoggingMemoryPool::new(task_attempt_id as u64, memory_pool))
} else {
Expand Down Expand Up @@ -715,6 +770,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
schema_addrs: JLongArray,
) -> jlong {
try_unwrap_or_throw(&e, |env| {
#[cfg(feature = "oom-guard")]
crate::execution::memory_pools::oom_guard::stamp_current_thread();
// Retrieve the query
let exec_context = get_execution_context(exec_context);

Expand Down Expand Up @@ -786,6 +843,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
.await;

if let Err(panic) = result {
#[cfg(feature = "oom-guard")]
if let Some(e) =
crate::execution::memory_pools::oom_guard::map_panic_to_error(
panic.as_ref(),
)
{
// Runs on the tokio worker thread that panicked, so this clears
// that worker's UNWINDING flag (not the blocked JNI caller thread's).
let _ = tx.send(Err(e)).await;
return;
}
let msg = match panic.downcast_ref::<&str>() {
Some(s) => s.to_string(),
None => match panic.downcast_ref::<String>() {
Expand All @@ -810,76 +878,120 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
pull_input_batches(exec_context)?;
}

if let Some(rx) = &mut exec_context.batch_receiver {
match rx.blocking_recv() {
Some(Ok(batch)) => {
update_metrics(env, exec_context)?;
return prepare_output(
env,
array_addrs,
schema_addrs,
batch,
exec_context.debug_native,
);
}
Some(Err(e)) => {
return Err(e.into());
}
None => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
if exec_context.batch_receiver.is_some() {
let recv_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(
|| -> CometResult<jlong> {
// Scope the rx borrow to just the blocking_recv call so that
// exec_context is free for update_metrics / prepare_output below.
let recv = exec_context
.batch_receiver
.as_mut()
.unwrap()
.blocking_recv();
match recv {
Some(Ok(batch)) => {
update_metrics(env, exec_context)?;
prepare_output(
env,
array_addrs,
schema_addrs,
batch,
exec_context.debug_native,
)
}
Some(Err(e)) => Err(e.into()),
None => {
log_plan_metrics(exec_context, stage_id, partition);
Ok(-1)
}
}
},
));

match recv_result {
Ok(r) => return r,
Err(_panic) => {
#[cfg(feature = "oom-guard")]
if let Some(e) =
crate::execution::memory_pools::oom_guard::map_panic_to_error(
_panic.as_ref(),
)
{
// Drop the receiver so any re-entry re-initializes.
exec_context.batch_receiver = None;
return Err(e.into());
}
std::panic::resume_unwind(_panic);
}
}
}

// ScanExec path: busy-poll to interleave JVM batch pulls with stream polling
get_runtime().block_on(async {
loop {
let next_item = exec_context.stream.as_mut().unwrap().next();
let poll_output = poll!(next_item);

// Only check time/tracing every 100 polls to reduce overhead
exec_context.poll_count_since_metrics_check += 1;
if exec_context.poll_count_since_metrics_check >= 100 {
exec_context.poll_count_since_metrics_check = 0;
if let Some(interval) = exec_context.metrics_update_interval {
let now = Instant::now();
if now - exec_context.metrics_last_update_time >= interval {
update_metrics(env, exec_context)?;
exec_context.metrics_last_update_time = now;
let poll_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
get_runtime().block_on(async {
loop {
let next_item = exec_context.stream.as_mut().unwrap().next();
let poll_output = poll!(next_item);

// Only check time/tracing every 100 polls to reduce overhead
exec_context.poll_count_since_metrics_check += 1;
if exec_context.poll_count_since_metrics_check >= 100 {
exec_context.poll_count_since_metrics_check = 0;
if let Some(interval) = exec_context.metrics_update_interval {
let now = Instant::now();
if now - exec_context.metrics_last_update_time >= interval {
update_metrics(env, exec_context)?;
exec_context.metrics_last_update_time = now;
}
}
if exec_context.tracing_enabled {
log_memory_usage(
&exec_context.tracing_memory_metric_name,
total_reserved_for_thread(exec_context.rust_thread_id) as u64,
);
}
}
if exec_context.tracing_enabled {
log_memory_usage(
&exec_context.tracing_memory_metric_name,
total_reserved_for_thread(exec_context.rust_thread_id) as u64,
);
}
}

match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(
env,
array_addrs,
schema_addrs,
output?,
exec_context.debug_native,
);
}
Poll::Ready(None) => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
}
Poll::Pending => {
// JNI call to pull batches from JVM into ScanExec operators.
// block_in_place lets tokio move other tasks off this worker
// while we wait for JVM data.
tokio::task::block_in_place(|| pull_input_batches(exec_context))?;
match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(
env,
array_addrs,
schema_addrs,
output?,
exec_context.debug_native,
);
}
Poll::Ready(None) => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
}
Poll::Pending => {
// JNI call to pull batches from JVM into ScanExec operators.
// block_in_place lets tokio move other tasks off this worker
// while we wait for JVM data.
tokio::task::block_in_place(|| pull_input_batches(exec_context))?;
}
}
}
})
}));

match poll_result {
Ok(r) => r,
Err(_panic) => {
#[cfg(feature = "oom-guard")]
if let Some(e) = crate::execution::memory_pools::oom_guard::map_panic_to_error(
_panic.as_ref(),
) {
// The block_on future was dropped mid-poll; null the stream so any
// inadvertent re-entry re-initializes rather than polling a half-consumed one.
exec_context.stream = None;
return Err(e.into());
}
std::panic::resume_unwind(_panic);
}
})
}
});

if exec_context.tracing_enabled {
Expand Down
16 changes: 16 additions & 0 deletions native/core/src/execution/memory_pools/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ impl MemoryPoolType {
| MemoryPoolType::GreedyUnified
)
}

/// Reports whether `reserved()` for this pool type measures a single task's
/// usage rather than an aggregate across tasks.
///
/// A per-task fair-share comparison is only meaningful when this returns
/// `true`. The process-wide pool types (`GreedyGlobal`, `FairSpillGlobal`,
/// `Unbounded`) return `false`; every other pool type returns `true`.
///
/// The non-shared per-task pools (`Greedy`/`FairSpill`) also return `true`
/// even though they keep no task registry, so for them the fair-share divisor
/// falls back to `executor_cores` instead of the dynamic active-task count.
// Silences the unused-item lint when the `oom-guard` feature is disabled.
#[cfg_attr(not(feature = "oom-guard"), allow(dead_code))]
pub(crate) fn has_per_task_budget(&self) -> bool {
    // Name the excluded set first, then negate it, so the intent reads directly.
    let aggregates_across_tasks = matches!(
        self,
        MemoryPoolType::GreedyGlobal
            | MemoryPoolType::FairSpillGlobal
            | MemoryPoolType::Unbounded
    );
    !aggregates_across_tasks
}
}

pub(crate) struct MemoryPoolConfig {
Expand Down
9 changes: 9 additions & 0 deletions native/core/src/execution/memory_pools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
mod config;
mod fair_pool;
pub mod logging_pool;
#[cfg(feature = "oom-guard")]
pub mod oom_guard;
#[cfg(feature = "oom-guard")]
mod real_usage_pool;
mod task_shared;
mod unified_pool;

Expand All @@ -32,6 +36,8 @@ use std::sync::Arc;
use unified_pool::CometUnifiedMemoryPool;

pub(crate) use config::*;
#[cfg(feature = "oom-guard")]
pub(crate) use real_usage_pool::RealUsagePool;
pub(crate) use task_shared::*;

pub(crate) fn create_memory_pool(
Expand All @@ -45,6 +51,7 @@ pub(crate) fn create_memory_pool(
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
let per_task_memory_pool =
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
record_task_started();
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
CometUnifiedMemoryPool::new(
Arc::clone(&comet_task_memory_manager),
Expand All @@ -61,6 +68,7 @@ pub(crate) fn create_memory_pool(
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
let per_task_memory_pool =
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
record_task_started();
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
CometFairMemoryPool::new(
Arc::clone(&comet_task_memory_manager),
Expand Down Expand Up @@ -105,6 +113,7 @@ pub(crate) fn create_memory_pool(
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
let per_task_memory_pool =
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
record_task_started();
let pool: Arc<dyn MemoryPool> =
if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
Arc::new(TrackConsumersPool::new(
Expand Down
Loading
Loading