Qw35 cpu cache #1278

Open
blueswhen wants to merge 4 commits into qw35_stable from qw35_cpu_cache

Conversation

@blueswhen
Collaborator

No description provided.

Contributor

@gemini-code-assist Bot left a comment


Code Review

This pull request introduces a hybrid KV cache management system for Mamba-based models, featuring a new DetachedMambaCheckpointCache to preserve state after node eviction and enhanced tracking for GPU and CPU cache hits. The changes refactor memory management in Qwen3Next, integrate hybrid matching into the radix cache, and refine inference batch logic for improved resource reuse. Feedback focuses on performance optimizations, specifically reducing redundant token conversions and consolidating repeated computations within the checkpoint cache and multi-level cache loading processes.

Comment on lines +37 to +57
    def _to_token_list(self, tokens: Iterable[int]) -> list[int]:
        if isinstance(tokens, torch.Tensor):
            return tokens.cpu().tolist()
        if isinstance(tokens, np.ndarray):
            return tokens.tolist()
        return list(tokens)

    def _build_key(self, tokens: Iterable[int]) -> Optional[Tuple[Tuple[int, int], int]]:
        token_list = self._to_token_list(tokens)
        prefix_len = len(token_list)
        if prefix_len == 0 or prefix_len % self.token_page_size != 0:
            return None

        hsum = xxhash.xxh3_128()
        for page_idx in range(prefix_len // self.token_page_size):
            start = page_idx * self.token_page_size
            end = start + self.token_page_size
            chunk_np = np.array(token_list[start:end], dtype=np.uint64)
            hsum.update(chunk_np.tobytes())

        return (prefix_len // self.token_page_size, hsum.intdigest()), prefix_len
Contributor


medium

The current implementation of _to_token_list and _build_key is inefficient because it repeatedly converts between Python lists, numpy arrays, and torch tensors. Converting a large tensor to a Python list via tolist() is particularly expensive. It is better to convert the input to a single numpy array once and work with slices.

Suggested change
    def _to_token_array(self, tokens: Iterable[int]) -> np.ndarray:
        if isinstance(tokens, torch.Tensor):
            return tokens.detach().cpu().numpy().astype(np.uint64)
        if isinstance(tokens, np.ndarray):
            return tokens.astype(np.uint64)
        return np.array(list(tokens), dtype=np.uint64)

    def _build_key(self, tokens: Iterable[int]) -> Optional[Tuple[Tuple[int, int], int]]:
        token_arr = self._to_token_array(tokens)
        prefix_len = len(token_arr)
        if prefix_len == 0 or prefix_len % self.token_page_size != 0:
            return None

        hsum = xxhash.xxh3_128()
        for page_idx in range(prefix_len // self.token_page_size):
            start = page_idx * self.token_page_size
            end = start + self.token_page_size
            hsum.update(token_arr[start:end].tobytes())

        return (prefix_len // self.token_page_size, hsum.intdigest()), prefix_len
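The array-based rewrite is behavior-preserving: hashing a zero-copy slice of a uint64 array yields exactly the same bytes as re-wrapping each page of a Python list in a fresh numpy array. A minimal standalone check of this claim (not part of the PR; `hashlib.blake2b` stands in for the third-party `xxhash.xxh3_128`, plain lists stand in for torch tensors, and `PAGE` is a hypothetical `token_page_size`):

```python
import hashlib

import numpy as np

PAGE = 4  # stand-in for token_page_size


def digest_via_lists(tokens) -> str:
    # Old path: materialize a Python list, then re-wrap every page
    # in a fresh numpy array just to get its bytes.
    token_list = list(tokens)
    h = hashlib.blake2b()
    for i in range(len(token_list) // PAGE):
        chunk = np.array(token_list[i * PAGE:(i + 1) * PAGE], dtype=np.uint64)
        h.update(chunk.tobytes())
    return h.hexdigest()


def digest_via_array(tokens) -> str:
    # New path: convert to a uint64 array once, then hash slices,
    # avoiding the per-page list-to-array round trips.
    arr = np.asarray(tokens, dtype=np.uint64)
    h = hashlib.blake2b()
    for i in range(len(arr) // PAGE):
        h.update(arr[i * PAGE:(i + 1) * PAGE].tobytes())
    return h.hexdigest()


tokens = list(range(16))
assert digest_via_lists(tokens) == digest_via_array(tokens)
print("both paths produce identical page digests")
```

Since both paths serialize each page as contiguous little-endian uint64 bytes, the per-page keys (and therefore cache hits) are unchanged; only the conversion overhead goes away.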

Comment on lines +75 to +106
    def match_prompt_prefix(
        self, prompt_tokens: Iterable[int], max_prefix_len: int
    ) -> Optional[DetachedMambaCheckpoint]:
        page_num_limit = max_prefix_len // self.token_page_size
        if page_num_limit <= 0:
            return None

        token_list = self._to_token_list(prompt_tokens)
        if len(token_list) < page_num_limit * self.token_page_size:
            page_num_limit = len(token_list) // self.token_page_size
        if page_num_limit <= 0:
            return None

        hsum = xxhash.xxh3_128()
        best_checkpoint = None
        for page_idx in range(page_num_limit):
            start = page_idx * self.token_page_size
            end = start + self.token_page_size
            chunk_np = np.array(token_list[start:end], dtype=np.uint64)
            hsum.update(chunk_np.tobytes())
            key = (page_idx + 1, hsum.intdigest())
            checkpoint = self._checkpoints.get(key)
            if checkpoint is not None:
                best_checkpoint = checkpoint

        if best_checkpoint is None:
            return None

        self._evict_set.discard(best_checkpoint)
        best_checkpoint.touch()
        self._evict_set.add(best_checkpoint)
        return best_checkpoint
Contributor


medium

Similar to _build_key, match_prompt_prefix can be optimized by using the _to_token_array helper to avoid redundant conversions and allocations.

    def match_prompt_prefix(
        self, prompt_tokens: Iterable[int], max_prefix_len: int
    ) -> Optional[DetachedMambaCheckpoint]:
        page_num_limit = max_prefix_len // self.token_page_size
        if page_num_limit <= 0:
            return None

        token_arr = self._to_token_array(prompt_tokens)
        if len(token_arr) < page_num_limit * self.token_page_size:
            page_num_limit = len(token_arr) // self.token_page_size
        if page_num_limit <= 0:
            return None

        hsum = xxhash.xxh3_128()
        best_checkpoint = None
        for page_idx in range(page_num_limit):
            start = page_idx * self.token_page_size
            end = start + self.token_page_size
            hsum.update(token_arr[start:end].tobytes())
            key = (page_idx + 1, hsum.intdigest())
            checkpoint = self._checkpoints.get(key)
            if checkpoint is not None:
                best_checkpoint = checkpoint

        if best_checkpoint is None:
            return None

        self._evict_set.discard(best_checkpoint)
        best_checkpoint.touch()
        self._evict_set.add(best_checkpoint)
        return best_checkpoint
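A single running `hsum` suffices in this scan because the key for k pages is the hash of the byte stream of pages 1..k, so the per-page keys emitted during the pass are exactly the keys `_build_key` would compute for each page-aligned prefix. A minimal illustration of that invariant (again with `hashlib.blake2b` standing in for xxhash and a hypothetical page size):

```python
import hashlib

import numpy as np

PAGE = 4  # hypothetical token_page_size


def build_key(tokens: np.ndarray):
    # Full rebuild, mirroring _build_key: hash every page of the
    # page-aligned prefix from scratch.
    h = hashlib.blake2b()
    pages = len(tokens) // PAGE
    for i in range(pages):
        h.update(tokens[i * PAGE:(i + 1) * PAGE].tobytes())
    return pages, h.hexdigest()


def prefix_keys(tokens: np.ndarray):
    # Single pass, mirroring match_prompt_prefix: one running hash
    # yields a lookup key after every completed page.
    h = hashlib.blake2b()
    for i in range(len(tokens) // PAGE):
        h.update(tokens[i * PAGE:(i + 1) * PAGE].tobytes())
        yield i + 1, h.hexdigest()


toks = np.arange(12, dtype=np.uint64)
scanned = list(prefix_keys(toks))
rebuilt = [build_key(toks[:k * PAGE]) for k in (1, 2, 3)]
assert scanned == rebuilt
print("per-page scan keys match full prefix rebuilds")
```

This is why the loop can probe `self._checkpoints` after each page without restarting the hash, keeping the match O(n) in prompt length.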

Comment on lines +103 to +126
                prompt_key = torch.tensor(req.shm_req.get_prompt_ids(), dtype=torch.int64, device="cpu")[:-1]
                _, gpu_kv_len, raw_gpu_value_tensor = self.backend.radix_cache.match_prefix_kv(
                    prompt_key, update_refs=False
                )
                kv_upper_len = max(cpu_kv_len, gpu_kv_len)
                if detached_mamba_manager is not None:
                    detached_checkpoint = detached_mamba_manager.match_prompt_prefix(
                        prompt_tokens=req.shm_req.get_prompt_ids(),
                        max_prefix_len=kv_upper_len,
                    )
                detached_len = 0 if detached_checkpoint is None else detached_checkpoint.prefix_len
                final_match_len = max(gpu_buffer_len, detached_len)
            else:
                final_match_len = max(gpu_buffer_len, cpu_kv_len)

            # Update the matched cpu kv cache length, subtracting the radix cache and disk cache portions.
            if is_master_in_dp:
                req.shm_req.cpu_prompt_cache_len = max(
-                   0, match_tokens - req.cur_kv_len - req.shm_req.disk_prompt_cache_len
+                   0, final_match_len - gpu_buffer_len - req.shm_req.disk_prompt_cache_len
                )

            raw_reuse_len = min(gpu_kv_len, final_match_len)
            if raw_reuse_len > req.cur_kv_len and raw_gpu_value_tensor is not None:
                prompt_key = torch.tensor(req.shm_req.get_prompt_ids(), dtype=torch.int64, device="cpu")[:-1]
Contributor


medium

prompt_key and prompt_ids are redundantly created and fetched within the loop: req.shm_req.get_prompt_ids() is called at lines 103, 110, and 126, and torch.tensor() at lines 103 and 126. Consolidating these calls at the start of the loop will improve performance.

                prompt_ids = req.shm_req.get_prompt_ids()
                prompt_key = torch.tensor(prompt_ids, dtype=torch.int64, device="cpu")[:-1]
                _, gpu_kv_len, raw_gpu_value_tensor = self.backend.radix_cache.match_prefix_kv(
                    prompt_key, update_refs=False
                )
                kv_upper_len = max(cpu_kv_len, gpu_kv_len)
                if detached_mamba_manager is not None:
                    detached_checkpoint = detached_mamba_manager.match_prompt_prefix(
                        prompt_tokens=prompt_ids,
                        max_prefix_len=kv_upper_len,
                    )
                detached_len = 0 if detached_checkpoint is None else detached_checkpoint.prefix_len
                final_match_len = max(gpu_buffer_len, detached_len)
            else:
                final_match_len = max(gpu_buffer_len, cpu_kv_len)

            # Update the matched cpu kv cache length, subtracting the radix cache and disk cache portions.
            if is_master_in_dp:
                req.shm_req.cpu_prompt_cache_len = max(
                    0, final_match_len - gpu_buffer_len - req.shm_req.disk_prompt_cache_len
                )

            raw_reuse_len = min(gpu_kv_len, final_match_len)
            if raw_reuse_len > req.cur_kv_len and raw_gpu_value_tensor is not None:
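The effect of the suggested hoisting can be illustrated with a toy stand-in for the request object (`FakeShmReq` below is hypothetical, not part of the PR; it only counts how often the getter runs):

```python
class FakeShmReq:
    """Toy stand-in for req.shm_req that counts getter calls."""

    def __init__(self, ids):
        self._ids = ids
        self.calls = 0  # how often the (potentially expensive) getter ran

    def get_prompt_ids(self):
        self.calls += 1
        return list(self._ids)


req = FakeShmReq([1, 2, 3, 4])

# Before: three separate fetches, as at the reviewed lines 103/110/126.
a, b, c = req.get_prompt_ids(), req.get_prompt_ids(), req.get_prompt_ids()
assert req.calls == 3

# After: one fetch at the top of the loop body, reused everywhere.
req.calls = 0
prompt_ids = req.get_prompt_ids()
a, b, c = prompt_ids, prompt_ids, prompt_ids
assert req.calls == 1
print("getter called once after hoisting")
```

If the getter reads shared memory or the tensor construction copies data, this turns three copies per loop iteration into one; the reuse is safe only because the prompt ids do not change between the three call sites.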

@blueswhen force-pushed the qw35_cpu_cache branch 3 times, most recently from f6dac47 to 07fb7ca on April 21, 2026 at 10:04
