Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1041,10 +1041,11 @@ void NgramMatch(const paddle::Tensor& token_ids_all,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& max_dec_len,
const int max_ngram_size,
const int max_draft_tokens);
const int max_draft_tokens,
const bool pad_to_max);

void HybridMtpNgram(const paddle::Tensor& input_ids,
const paddle::Tensor& input_ids_len,
void HybridMtpNgram(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& draft_token_num,
Expand All @@ -1054,7 +1055,8 @@ void HybridMtpNgram(const paddle::Tensor& input_ids,
const paddle::Tensor& max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
const int max_draft_tokens,
const bool pad_to_max);

// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
Expand Down
135 changes: 90 additions & 45 deletions custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
// Also copies tentative matched tokens to scratch buffers.
// ============================================================
__global__ void ngram_match_mixed_search_kernel(
const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
const int32_t *seq_lens_this_time,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -69,8 +69,9 @@ __global__ void ngram_match_mixed_search_kernel(
if (draft_budget <= 0 || remaining_dec <= 0) return;
int max_draft_tokens = static_cast<int>(min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
const int64_t cur_input_ids_len = prompt_len;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];

Expand Down Expand Up @@ -143,7 +144,9 @@ __global__ void ngram_match_mixed_gather_kernel(
int32_t *seq_lens_this_time,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int threshold) {
int threshold,
int max_draft_tokens_param,
bool pad_to_max) {
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
__shared__ typename BlockScanInt::TempStorage temp_storage1;
__shared__ typename BlockScanInt::TempStorage temp_storage2;
Expand Down Expand Up @@ -202,9 +205,8 @@ __global__ void ngram_match_mixed_gather_kernel(
}
actual = min(actual, tentative);

seq_lens_this_time[tid] = actual;

// Copy ngram draft tokens from scratch to output
// Copy ngram draft tokens from scratch to output FIRST
// (so subsequent padding doesn't overwrite real ngram hits)
int ngram_to_copy = actual - ori;
if (ngram_to_copy > 0) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
Expand All @@ -213,6 +215,38 @@ __global__ void ngram_match_mixed_gather_kernel(
dst[ori + k] = src[ori + k];
}
}

// === Pad seq_lens_this_time to num_speculative_tokens+1 for cudagraph
// stability ===
// Hybrid MTP-ngram produces a variable seq_lens_this_time (abbreviated "slt"
// below) depending on how many ngram positions hit (range:
// [num_model_steps+1, num_speculative_tokens+1]). cudagraph captures launch
// params (grid dim, kernel args) at capture time; if the captured slt differs
// from the replay-time slt, downstream kernels read past the valid ranges of
// cu_seqlens / slot_mapping etc., causing CUDA error 700 (illegal memory
// access).
//
// When pad_to_max=true (cudagraph enabled), force slt =
// num_speculative_tokens+1 = max_draft_tokens + 1: positions beyond the actual
// ngram hits get padded with a placeholder token. The target model will
// verify these placeholders and (almost always) reject them, but the verify
// cost is fixed per iteration => grid dim is now invariant. When
// pad_to_max=false (cudagraph disabled), keep the natural variable slt to
// avoid wasting verify compute on placeholders.
if (pad_to_max) {
int target_slt = max_draft_tokens_param + 1;
if (actual < target_slt) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
// Reuse the last valid draft token as the placeholder. It is a token the
// model could plausibly have produced, so attention math stays
// well-defined; rejection happens at the sampler level. If there is no
// valid token at all (actual == 0), token id 0 is used instead.
int64_t pad_token = (actual > 0) ? dst[actual - 1] : 0;
for (int k = actual; k < target_slt; k++) {
dst[k] = pad_token;
}
actual = target_slt;
}
}

seq_lens_this_time[tid] = actual;
}
}

Expand All @@ -228,16 +262,16 @@ static int sum_mixed_cpu(const int *value, int num) {
return sum_value;
}

static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
static void find_candidate_pred_tokens_mixed(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -268,11 +302,12 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
int max_draft_tokens_query =
static_cast<int>(std::min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t cur_input_ids_len = prompt_len;
unprocessed_batch_size--;

auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
Expand Down Expand Up @@ -363,8 +398,8 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
// threshold enforcement + final token copy.
// ============================================================

void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
void HybridMtpNgram(const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
Expand All @@ -374,9 +409,9 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
const int max_draft_tokens,
const bool pad_to_max) {
const int64_t max_model_len = token_ids_all.shape()[1];

auto pre_ids_shape = pre_ids.shape();
const int64_t pre_ids_stride = pre_ids_shape[1];
Expand All @@ -392,8 +427,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
threshold = std::stoi(env_var);
}

if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
if (token_ids_all.is_gpu()) {
auto stream = token_ids_all.stream();

// NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed
// variant uses ori_seq_len_this_time == 0 to skip inactive items. This
Expand All @@ -403,21 +438,28 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
// counts tentative > 0, which is equivalent under this invariant.

This comment was marked as outdated.

This comment was marked as outdated.


// Allocate scratch buffers for Phase 1 → Phase 2 communication
static paddle::Tensor s_draft_copy_mixed;

This comment was marked as outdated.

This comment was marked as outdated.

static paddle::Tensor s_seqlens_copy_mixed;
static paddle::Tensor s_seqlens_orig_mixed;
static int64_t s_scratch_batch_mixed = 0;
static int64_t s_scratch_stride_mixed = 0;

if (max_batch_size > s_scratch_batch_mixed ||
draft_tokens_stride > s_scratch_stride_mixed) {
s_draft_copy_mixed = paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
token_ids_all.place());
s_seqlens_copy_mixed = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
s_seqlens_orig_mixed = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
s_scratch_batch_mixed = max_batch_size;
s_scratch_stride_mixed = draft_tokens_stride;
}
auto &draft_tokens_copy = s_draft_copy_mixed;
auto &seq_lens_this_time_copy = s_seqlens_copy_mixed;
auto &seq_lens_this_time_orig = s_seqlens_orig_mixed;

// Scratch copy of draft_tokens (Phase 1 writes tentative tokens here)
auto draft_tokens_copy =
paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());

// Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts)
auto seq_lens_this_time_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());

// Save a copy of original seq_lens_this_time for Phase 2
// (Phase 1 reads from the original, Phase 2 needs ori values)
auto seq_lens_this_time_orig = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
cudaMemcpyAsync(seq_lens_this_time_orig.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
max_batch_size * sizeof(int32_t),
Expand All @@ -434,16 +476,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
seq_lens_this_time.data<int32_t>(),
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -461,19 +503,21 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
draft_tokens_stride,
max_batch_size,
threshold);
threshold,
max_draft_tokens,
pad_to_max);
} else {
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -484,8 +528,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
}

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
.Inputs({"input_ids",
"input_ids_len",
.Inputs({"token_ids_all",
"prompt_lens",
"pre_ids",
"step_idx",
"draft_token_num",
Expand All @@ -495,7 +539,8 @@ PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
"max_dec_len"})
.Attrs({"max_ngram_size: int",
"min_ngram_size: int",
"max_draft_tokens: int"})
"max_draft_tokens: int",
"pad_to_max: bool"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(HybridMtpNgram))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
Expand Down
43 changes: 36 additions & 7 deletions custom_ops/gpu_ops/speculate_decoding/ngram_match.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ __global__ void ngram_match_gather_kernel(
int32_t *seq_lens_this_time,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int threshold) {
int threshold,
int max_draft_tokens_param,
bool pad_to_max) {
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
__shared__ typename BlockScanInt::TempStorage temp_storage1;
__shared__ typename BlockScanInt::TempStorage temp_storage2;
Expand Down Expand Up @@ -203,16 +205,40 @@ __global__ void ngram_match_gather_kernel(
actual = min(tentative, budget);
}

seq_lens_this_time[tid] = actual;

// Copy draft tokens (slots 1..actual-1) from scratch to output
// Copy draft tokens (slots 1..actual-1) from scratch to output FIRST
// (so subsequent padding doesn't overwrite real ngram hits)
if (actual > 1) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride;
for (int k = 1; k < actual; k++) {
dst[k] = src[k];
}
}

// === Pad seq_lens_this_time to num_speculative_tokens+1 for cudagraph
// stability ===
// A variable seq_lens_this_time (abbreviated "slt" below; range
// [1, num_speculative_tokens+1]) clashes with cudagraph's fixed launch params
// captured at warm-up time; downstream kernels read past the valid cu_seqlens /
// slot_mapping when replay sees a smaller slt, leading to out-of-bounds
// accesses / CUDA error 700 (illegal memory access).
// When pad_to_max=true (cudagraph enabled), pad missing positions with a
// placeholder so slt is fixed at num_speculative_tokens+1. pad_to_max=false
// skips the padding cost when cudagraph is off.
if (pad_to_max) {
int target_slt = max_draft_tokens_param + 1;
if (actual < target_slt) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
// Reuse the last valid draft token as the placeholder. It is a token the
// model could plausibly have produced, so attention math stays
// well-defined; rejection happens at the sampler level. If there is no
// valid token at all (actual == 0), token id 0 is used instead.
int64_t pad_token = (actual > 0) ? dst[actual - 1] : 0;
for (int k = actual; k < target_slt; k++) {
dst[k] = pad_token;
}
actual = target_slt;
}
}

seq_lens_this_time[tid] = actual;
}
}

Expand Down Expand Up @@ -374,7 +400,8 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens) {
const int max_draft_tokens,
const bool pad_to_max) {
const int64_t max_model_len = token_ids_all.shape()[1];

auto draft_tokens_shape = draft_tokens.shape();
Expand Down Expand Up @@ -448,7 +475,9 @@ void NgramMatch(const paddle::Tensor &token_ids_all,
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
draft_tokens_stride,
max_batch_size,
threshold);
threshold,
max_draft_tokens,
pad_to_max);
} else {
find_candidate_pred_tokens(
token_ids_all.data<int64_t>(),
Expand Down Expand Up @@ -478,7 +507,7 @@ PD_BUILD_STATIC_OP(ngram_match)
"seq_lens_encoder",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
.Attrs({"max_ngram_size: int", "max_draft_tokens: int", "pad_to_max: bool"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
Expand Down
Loading
Loading