Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a hybrid KV cache management system for Mamba-based models, featuring a new DetachedMambaCheckpointCache to preserve state after node eviction and enhanced tracking for GPU and CPU cache hits. The changes refactor memory management in Qwen3Next, integrate hybrid matching into the radix cache, and refine inference batch logic for improved resource reuse. Feedback focuses on performance optimizations, specifically reducing redundant token conversions and consolidating repeated computations within the checkpoint cache and multi-level cache loading processes.
| def _to_token_list(self, tokens: Iterable[int]) -> list[int]: | ||
| if isinstance(tokens, torch.Tensor): | ||
| return tokens.cpu().tolist() | ||
| if isinstance(tokens, np.ndarray): | ||
| return tokens.tolist() | ||
| return list(tokens) | ||
|
|
||
| def _build_key(self, tokens: Iterable[int]) -> Optional[Tuple[Tuple[int, int], int]]: | ||
| token_list = self._to_token_list(tokens) | ||
| prefix_len = len(token_list) | ||
| if prefix_len == 0 or prefix_len % self.token_page_size != 0: | ||
| return None | ||
|
|
||
| hsum = xxhash.xxh3_128() | ||
| for page_idx in range(prefix_len // self.token_page_size): | ||
| start = page_idx * self.token_page_size | ||
| end = start + self.token_page_size | ||
| chunk_np = np.array(token_list[start:end], dtype=np.uint64) | ||
| hsum.update(chunk_np.tobytes()) | ||
|
|
||
| return (prefix_len // self.token_page_size, hsum.intdigest()), prefix_len |
There was a problem hiding this comment.
The current implementation of _to_token_list and _build_key is inefficient because it repeatedly converts between Python lists, numpy arrays, and torch tensors. Converting a large tensor to a Python list via tolist() is particularly expensive. It is better to convert the input to a single numpy array once and work with slices.
| def _to_token_list(self, tokens: Iterable[int]) -> list[int]: | |
| if isinstance(tokens, torch.Tensor): | |
| return tokens.cpu().tolist() | |
| if isinstance(tokens, np.ndarray): | |
| return tokens.tolist() | |
| return list(tokens) | |
| def _build_key(self, tokens: Iterable[int]) -> Optional[Tuple[Tuple[int, int], int]]: | |
| token_list = self._to_token_list(tokens) | |
| prefix_len = len(token_list) | |
| if prefix_len == 0 or prefix_len % self.token_page_size != 0: | |
| return None | |
| hsum = xxhash.xxh3_128() | |
| for page_idx in range(prefix_len // self.token_page_size): | |
| start = page_idx * self.token_page_size | |
| end = start + self.token_page_size | |
| chunk_np = np.array(token_list[start:end], dtype=np.uint64) | |
| hsum.update(chunk_np.tobytes()) | |
| return (prefix_len // self.token_page_size, hsum.intdigest()), prefix_len | |
| def _to_token_array(self, tokens: Iterable[int]) -> np.ndarray: | |
| if isinstance(tokens, torch.Tensor): | |
| return tokens.detach().cpu().numpy().astype(np.uint64) | |
| if isinstance(tokens, np.ndarray): | |
| return tokens.astype(np.uint64) | |
| return np.array(list(tokens), dtype=np.uint64) | |
| def _build_key(self, tokens: Iterable[int]) -> Optional[Tuple[Tuple[int, int], int]]: | |
| token_arr = self._to_token_array(tokens) | |
| prefix_len = len(token_arr) | |
| if prefix_len == 0 or prefix_len % self.token_page_size != 0: | |
| return None | |
| hsum = xxhash.xxh3_128() | |
| for page_idx in range(prefix_len // self.token_page_size): | |
| start = page_idx * self.token_page_size | |
| end = start + self.token_page_size | |
| hsum.update(token_arr[start:end].tobytes()) | |
| return (prefix_len // self.token_page_size, hsum.intdigest()), prefix_len |
| def match_prompt_prefix( | ||
| self, prompt_tokens: Iterable[int], max_prefix_len: int | ||
| ) -> Optional[DetachedMambaCheckpoint]: | ||
| page_num_limit = max_prefix_len // self.token_page_size | ||
| if page_num_limit <= 0: | ||
| return None | ||
|
|
||
| token_list = self._to_token_list(prompt_tokens) | ||
| if len(token_list) < page_num_limit * self.token_page_size: | ||
| page_num_limit = len(token_list) // self.token_page_size | ||
| if page_num_limit <= 0: | ||
| return None | ||
|
|
||
| hsum = xxhash.xxh3_128() | ||
| best_checkpoint = None | ||
| for page_idx in range(page_num_limit): | ||
| start = page_idx * self.token_page_size | ||
| end = start + self.token_page_size | ||
| chunk_np = np.array(token_list[start:end], dtype=np.uint64) | ||
| hsum.update(chunk_np.tobytes()) | ||
| key = (page_idx + 1, hsum.intdigest()) | ||
| checkpoint = self._checkpoints.get(key) | ||
| if checkpoint is not None: | ||
| best_checkpoint = checkpoint | ||
|
|
||
| if best_checkpoint is None: | ||
| return None | ||
|
|
||
| self._evict_set.discard(best_checkpoint) | ||
| best_checkpoint.touch() | ||
| self._evict_set.add(best_checkpoint) | ||
| return best_checkpoint |
There was a problem hiding this comment.
Similar to _build_key, match_prompt_prefix can be optimized by using the _to_token_array helper to avoid redundant conversions and allocations.
def match_prompt_prefix(
    self, prompt_tokens: Iterable[int], max_prefix_len: int
) -> Optional[DetachedMambaCheckpoint]:
    """Return the checkpoint for the longest page-aligned prefix of *prompt_tokens*.

    Hashes the prompt one page at a time with a single rolling xxh3_128
    digest, probing ``self._checkpoints`` with the cumulative key
    ``(page_count, digest)`` after each page; the last hit is kept, so the
    longest matching prefix wins. Returns ``None`` when fewer than one
    whole page fits or no checkpoint matches.
    """
    # Cap the search at whole pages allowed by max_prefix_len.
    page_num_limit = max_prefix_len // self.token_page_size
    if page_num_limit <= 0:
        return None
    # Single up-front conversion (tensor/ndarray/iterable -> uint64 array);
    # page slices of this array are hashed below without re-allocation.
    token_arr = self._to_token_array(prompt_tokens)
    # Shrink the cap if the prompt itself is shorter than max_prefix_len.
    if len(token_arr) < page_num_limit * self.token_page_size:
        page_num_limit = len(token_arr) // self.token_page_size
        if page_num_limit <= 0:
            return None
    # Rolling digest: after page k it identifies the first k pages, matching
    # the key construction in _build_key (uint64 pages hashed in order).
    hsum = xxhash.xxh3_128()
    best_checkpoint = None
    for page_idx in range(page_num_limit):
        start = page_idx * self.token_page_size
        end = start + self.token_page_size
        hsum.update(token_arr[start:end].tobytes())
        key = (page_idx + 1, hsum.intdigest())
        checkpoint = self._checkpoints.get(key)
        if checkpoint is not None:
            # Keep overwriting: a later (longer-prefix) hit supersedes earlier ones.
            best_checkpoint = checkpoint
    if best_checkpoint is None:
        return None
    # Refresh eviction ordering: remove, bump the timestamp via touch(),
    # then re-insert — presumably _evict_set orders by last-touch time;
    # mutating in place would corrupt that ordering. TODO confirm.
    self._evict_set.discard(best_checkpoint)
    best_checkpoint.touch()
    self._evict_set.add(best_checkpoint)
    # NOTE(review): everything after "best_checkpoint" on this line is an
    # extraction artifact — the first row of the next diff table fused on.
    return best_checkpoint| prompt_key = torch.tensor(req.shm_req.get_prompt_ids(), dtype=torch.int64, device="cpu")[:-1] | ||
| _, gpu_kv_len, raw_gpu_value_tensor = self.backend.radix_cache.match_prefix_kv( | ||
| prompt_key, update_refs=False | ||
| ) | ||
| kv_upper_len = max(cpu_kv_len, gpu_kv_len) | ||
| if detached_mamba_manager is not None: | ||
| detached_checkpoint = detached_mamba_manager.match_prompt_prefix( | ||
| prompt_tokens=req.shm_req.get_prompt_ids(), | ||
| max_prefix_len=kv_upper_len, | ||
| ) | ||
| detached_len = 0 if detached_checkpoint is None else detached_checkpoint.prefix_len | ||
| final_match_len = max(gpu_buffer_len, detached_len) | ||
| else: | ||
| final_match_len = max(gpu_buffer_len, cpu_kv_len) | ||
|
|
||
| # 更新命中的 cpu kv cache 长度, 减去radix cache和disk cache的部分. | ||
| if is_master_in_dp: | ||
| req.shm_req.cpu_prompt_cache_len = max( | ||
| 0, match_tokens - req.cur_kv_len - req.shm_req.disk_prompt_cache_len | ||
| 0, final_match_len - gpu_buffer_len - req.shm_req.disk_prompt_cache_len | ||
| ) | ||
|
|
||
| raw_reuse_len = min(gpu_kv_len, final_match_len) | ||
| if raw_reuse_len > req.cur_kv_len and raw_gpu_value_tensor is not None: | ||
| prompt_key = torch.tensor(req.shm_req.get_prompt_ids(), dtype=torch.int64, device="cpu")[:-1] |
There was a problem hiding this comment.
The prompt_key and prompt_ids are being redundantly created and fetched multiple times within the loop. Specifically, req.shm_req.get_prompt_ids() is called at lines 103, 110, and 126, and torch.tensor() is called at 103 and 126. Consolidating these calls at the beginning of the loop will improve performance.
prompt_ids = req.shm_req.get_prompt_ids()
prompt_key = torch.tensor(prompt_ids, dtype=torch.int64, device="cpu")[:-1]
_, gpu_kv_len, raw_gpu_value_tensor = self.backend.radix_cache.match_prefix_kv(
prompt_key, update_refs=False
)
kv_upper_len = max(cpu_kv_len, gpu_kv_len)
if detached_mamba_manager is not None:
detached_checkpoint = detached_mamba_manager.match_prompt_prefix(
prompt_tokens=prompt_ids,
max_prefix_len=kv_upper_len,
)
detached_len = 0 if detached_checkpoint is None else detached_checkpoint.prefix_len
final_match_len = max(gpu_buffer_len, detached_len)
else:
final_match_len = max(gpu_buffer_len, cpu_kv_len)
# 更新命中的 cpu kv cache 长度, 减去radix cache和disk cache的部分.
if is_master_in_dp:
req.shm_req.cpu_prompt_cache_len = max(
0, final_match_len - gpu_buffer_len - req.shm_req.disk_prompt_cache_len
)
raw_reuse_len = min(gpu_kv_len, final_match_len)
if raw_reuse_len > req.cur_kv_len and raw_gpu_value_tensor is not None:f6dac47 to
07fb7ca
Compare
No description provided.