sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/mem_cache/hiradix_cache.py
+++ b/sglang/srt/mem_cache/hiradix_cache.py
@@ -7,11 +7,12 @@ from typing import List, Optional
 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
@@ -27,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
@@ -283,39 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length  # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node

             while last_node.evicted:
                 last_node = last_node.parent

-        return last_node, prefix_indices
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )

-    def ready_to_load_cache(self):
+    def ready_to_load_host_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index

-    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
+
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-            if include_evicted:
-                return empty_value, self.root_node, self.root_node
-            else:
-                return empty_value, self.root_node
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )

         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
@@ -327,14 +333,18 @@ class HiRadixCache(RadixCache):
         else:
             value = empty_value

-        last_node_global = last_node
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent

-        if include_evicted:
-            return value, last_node, last_node_global
-        else:
-            return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )

     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
@@ -372,6 +382,7 @@ class HiRadixCache(RadixCache):
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.loading = child.loading
+        new_node.hit_count = child.hit_count

         # split value and host value if exists
         if child.evicted:
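A note on the match_prefix change above: HiRadixCache (and RadixCache further below) no longer return a bare tuple guarded by the old include_evicted flag; they return the MatchResult container imported from sglang.srt.mem_cache.base_prefix_cache. The definition itself is not part of this diff, but a minimal sketch consistent with the fields constructed above, and with the positional unpacking used later in radix_cache.py, would be a NamedTuple along these lines (the host_hit_length default is an assumption):

from typing import Any, NamedTuple

import torch

class MatchResult(NamedTuple):
    device_indices: torch.Tensor  # prefix KV indices already resident on the GPU
    last_device_node: Any  # deepest matched node whose KV is still on device
    last_host_node: Any  # deepest matched node when host-only (evicted) KV is counted
    host_hit_length: int = 0  # extra tokens that become hittable after a host load-back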
--- a/sglang/srt/mem_cache/memory_pool.py
+++ b/sglang/srt/mem_cache/memory_pool.py
@@ -26,6 +26,7 @@ KVCache actually holds the physical kv cache.

 import abc
 import logging
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -33,8 +34,9 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

@@ -52,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -59,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
@@ -119,6 +122,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )

+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -153,83 +159,11 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter

-
-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
     def get_cpu_copy(self, indices):
-        return self._kvcache.get_cpu_copy(indices)
+        raise NotImplementedError()

     def load_cpu_copy(self, kv_cache_cpu, indices):
-        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+        raise NotImplementedError()


 class MHATokenToKVPool(KVCache):
@@ -260,10 +194,22 @@ class MHATokenToKVPool(KVCache):

         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()

-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -274,25 +220,30 @@ class MHATokenToKVPool(KVCache):
         )

     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]

         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
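The allocation hunks above route KV buffer creation through an optional torch.cuda.MemPool so that, when SGLANG_MOONCAKE_CUSTOM_MEM_POOL is enabled, the tensors come from an NVLink-shareable allocator. The conditional context-manager idiom they rely on is plain Python; here is a standalone sketch of the same pattern (illustrative names only, and it assumes a PyTorch build that ships the MemPool API plus an available CUDA device):

from contextlib import nullcontext

import torch

def make_kv_buffer(shape, pool=None):
    # Enter the custom CUDA memory pool only when one is provided;
    # nullcontext() turns the "with" into a no-op otherwise.
    ctx = torch.cuda.use_mem_pool(pool) if pool is not None else nullcontext()
    with ctx:
        return torch.zeros(shape, device="cuda")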
@@ -349,13 +300,17 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
@@ -368,12 +323,13 @@

     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
@@ -569,16 +525,34 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim

-        with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]

         self.layer_transfer_counter = None

@@ -604,6 +578,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -677,6 +654,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data

+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+

 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -704,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
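With the hunks above, chunked CPU offloading is implemented symmetrically by MHATokenToKVPool and MLATokenToKVPool: get_cpu_copy snapshots the selected KV slots layer by layer in cpu_offloading_chunk_size (8192-token) chunks, load_cpu_copy writes them back, and the KVCache base class only declares the interface (raising NotImplementedError). A hedged round-trip sketch of how a caller might use the pair follows; pool and indices are placeholder names, not identifiers from this diff:

# indices: 1-D tensor of KV slot ids to stash in host memory and restore later
cpu_snapshot = pool.get_cpu_copy(indices)  # list[layer][chunk] of CPU tensors
# ... the device slots could be reused for other work here ...
pool.load_cpu_copy(cpu_snapshot, indices)  # copy each chunk back into the same slots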
--- a/sglang/srt/mem_cache/radix_cache.py
+++ b/sglang/srt/mem_cache/radix_cache.py
@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional

 import torch

@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
-    KVCacheEvent,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:

     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()

@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
         enable_kv_cache_events: bool = False,
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()

-    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )

         if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )

     def insert(self, key: List, value=None):
         if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )

         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],
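Because MatchResult unpacks positionally like a plain tuple (consistent with the NamedTuple sketch given earlier), the last hunk can keep tuple-style unpacking while new call sites read the named fields. Under that assumption, both forms below are equivalent (cache and token_ids are placeholder names):

res = cache.match_prefix(token_ids)
new_indices, new_last_node, _, _ = res  # positional unpacking, as in the hunk above
new_indices, new_last_node = res.device_indices, res.last_device_node  # named access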