sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -26,6 +26,7 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -33,8 +34,9 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -59,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )
@@ -119,6 +122,9 @@ class KVCache(abc.ABC):
            enable=enable_memory_saver
        )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()
@@ -153,83 +159,11 @@ class KVCache(abc.ABC):
    def register_layer_transfer_counter(self, layer_transfer_counter):
        self.layer_transfer_counter = layer_transfer_counter
 
-
-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
    def get_cpu_copy(self, indices):
-        return self._kvcache.get_cpu_copy(indices)
+        raise NotImplementedError()
 
    def load_cpu_copy(self, kv_cache_cpu, indices):
-        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+        raise NotImplementedError()
 
 
 class MHATokenToKVPool(KVCache):
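
As an illustration (not part of the package diff): the allocator removed above now lives in sglang/srt/mem_cache/allocator.py (file 74 in the list). The condensed sketch below, with the hypothetical name MiniTokenAllocator, shows the free-list and grouped-free behavior that class implements.

import torch


class MiniTokenAllocator:
    """Condensed sketch of the free-list allocator shown removed above."""

    def __init__(self, size: int):
        self.size = size
        # Slot 0 is reserved for dummy writes from padded tokens.
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int64)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)  # deferred until free_group_end()

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))


if __name__ == "__main__":
    a = MiniTokenAllocator(8)
    x, y = a.alloc(3), a.alloc(2)
    a.free_group_begin()
    a.free(x)
    a.free(y)
    a.free_group_end()          # both frees merged into one torch.cat
    print(len(a.free_slots))    # 8 again
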
@@ -260,10 +194,22 @@ class MHATokenToKVPool(KVCache):
 
        self.head_num = head_num
        self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
        self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
        self.layer_transfer_counter = None
        self.device_module = torch.get_device_module(self.device)
        self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -274,25 +220,30 @@ class MHATokenToKVPool(KVCache):
        )
 
    def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
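
As an illustration (not part of the package diff), a minimal sketch of the env-gated custom memory pool pattern the hunk above introduces: torch.cuda.MemPool / torch.cuda.use_mem_pool and mooncake's NVLinkAllocator come from the diff itself; the helper names _make_custom_pool and alloc_kv_buffers are hypothetical, and the sketch assumes a PyTorch build that provides MemPool and, only when the env var is set, an installed mooncake package.

import os
from contextlib import nullcontext

import torch


def _make_custom_pool():
    # Only build a custom pool when the env var is set and mooncake is importable.
    if os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false").lower() != "true":
        return None
    from mooncake.allocator import NVLinkAllocator  # optional dependency

    allocator = NVLinkAllocator.get_allocator("cuda")
    return torch.cuda.MemPool(allocator.allocator())


def alloc_kv_buffers(layer_num, size, head_num, head_dim, dtype=torch.float16):
    pool = _make_custom_pool()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Buffers created inside the context come from the custom pool when it exists;
    # otherwise nullcontext() leaves the default caching allocator in place.
    ctx = torch.cuda.use_mem_pool(pool) if pool is not None else nullcontext()
    with ctx:
        return [
            torch.zeros((size, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]


if __name__ == "__main__":
    bufs = alloc_kv_buffers(layer_num=2, size=16, head_num=4, head_dim=8)
    print(len(bufs), bufs[0].shape)
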
@@ -349,13 +300,17 @@ class MHATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
@@ -368,12 +323,13 @@
 
    def load_cpu_copy(self, kv_cache_cpu, indices):
        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                )
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
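
As an illustration (not part of the package diff), a single-buffer sketch of the chunked offload loop above: rows selected by `indices` are copied to CPU in fixed-size chunks with non-blocking copies fenced by a synchronize, and later copied back. The function names are hypothetical; the real code iterates over per-layer K/V buffers.

import torch


def get_cpu_copy_chunked(buf, indices, chunk_size=8192):
    # Copy selected rows of a (possibly GPU) buffer to CPU in fixed-size chunks.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    chunks = []
    for i in range(0, len(indices), chunk_size):
        chunk_indices = indices[i : i + chunk_size]
        # non_blocking=True lets copies overlap; the synchronize below fences them.
        chunks.append(buf[chunk_indices].to("cpu", non_blocking=True))
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return chunks


def load_cpu_copy_chunked(buf, chunks, indices, chunk_size=8192):
    for i in range(0, len(indices), chunk_size):
        chunk_indices = indices[i : i + chunk_size]
        chunk_cpu = chunks[i // chunk_size]
        assert chunk_cpu.shape[0] == len(chunk_indices)
        buf[chunk_indices] = chunk_cpu.to(buf.device, non_blocking=True)
    if torch.cuda.is_available():
        torch.cuda.synchronize()


if __name__ == "__main__":
    buf = torch.arange(32, dtype=torch.float32).reshape(16, 2)
    idx = torch.tensor([1, 3, 5, 7])
    saved = get_cpu_copy_chunked(buf, idx, chunk_size=3)
    buf[idx] = 0.0
    load_cpu_copy_chunked(buf, saved, idx, chunk_size=3)
    print(torch.equal(buf[idx], torch.cat(saved)))  # True
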
@@ -569,16 +525,34 @@ class MLATokenToKVPool(KVCache):
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
 
-        with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
        self.layer_transfer_counter = None
 
@@ -604,6 +578,9 @@ class MLATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -677,6 +654,33 @@ class MLATokenToKVPool(KVCache):
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
@@ -704,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
            end_layer,
        )
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros(
sglang/srt/mem_cache/radix_cache.py

@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
    AllBlocksCleared,
    BlockRemoved,
    BlockStored,
-    KVCacheEvent,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:
 
    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
        self.lock_ref = 0
        self.last_access_time = time.monotonic()
 
@@ -57,7 +56,7 @@ class TreeNode:
        # indicating the node is loading KV cache from host
        self.loading = False
        # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None
 
        self.id = TreeNode.counter if id is None else id
        TreeNode.counter += 1
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
        enable_kv_cache_events: bool = False,
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
        self.protected_size_ = 0
        self._record_all_cleared_event()
 
-    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        """Find the matching prefix from the radix tree.
        Args:
            key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
                than the last node's value.
        """
        if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
            )
 
        if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )
 
    def insert(self, key: List, value=None):
        if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
        )
 
        # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
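
As an illustration (not part of the package diff), a sketch of the MatchResult shape implied by the hunks above: the fields device_indices, last_device_node, and last_host_node appear in the diff; the fourth field (named host_hit_length here) is hypothetical and only included to explain why callers unpack four values with `new_indices, new_last_node, _, _ = ...`.

from typing import Any, NamedTuple

import torch


class MatchResult(NamedTuple):
    # Sketch only; the real class in base_prefix_cache.py may differ.
    device_indices: torch.Tensor
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0  # hypothetical fourth field with a default


# Callers can keep tuple-style unpacking, as radix_cache.py now does,
# or access the named fields directly.
result = MatchResult(
    device_indices=torch.empty((0,), dtype=torch.int64),
    last_device_node=None,
    last_host_node=None,
)
indices, last_device_node, _, _ = result
print(indices.numel(), last_device_node)
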
sglang/srt/model_executor/cuda_graph_runner.py

@@ -46,6 +46,10 @@ from sglang.srt.utils import (
    get_available_gpu_memory,
    get_device_memory_capacity,
    rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )
 
 logger = logging.getLogger(__name__)
@@ -207,8 +211,10 @@ class CudaGraphRunner:
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
@@ -242,13 +248,13 @@ class CudaGraphRunner:
        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )
+
        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
@@ -299,18 +305,30 @@ class CudaGraphRunner:
        else:
            self.encoder_lens = None
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
            self.gathered_buffer = torch.zeros(
                (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                    self.max_num_token,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,
            )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
        # Capture
        try:
@@ -322,20 +340,23 @@
        )
 
    def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
-
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
-                if self.disable_padding
-                else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
            )
        else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -456,11 +477,11 @@ class CudaGraphRunner:
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
-                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                        for i in range(self.dp_size)
                    ],
                    dtype=torch.int32,
@@ -469,6 +490,16 @@ class CudaGraphRunner:
            )
            global_num_tokens = self.global_num_tokens_gpu
            gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
        else:
            global_num_tokens = None
            gathered_buffer = None
@@ -604,15 +635,18 @@ class CudaGraphRunner:
        raw_num_token = raw_bs * self.num_tokens_per_bs
 
        # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()
 
        # Common inputs
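
As an illustration (not part of the package diff), a toy sketch of the padded graph selection in the hunk above: with EAGLE, global_num_tokens counts draft tokens, so it is divided by num_tokens_per_bs before bisecting into the captured batch sizes. The capture_bs values and num_tokens_per_bs below are illustrative numbers, not sglang defaults.

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # example captured batch sizes
num_tokens_per_bs = 4               # e.g. draft tokens per request under EAGLE


def padded_capture_bs(global_num_tokens, is_eagle):
    total = sum(global_num_tokens)
    # Convert token counts to a per-request batch size before the lookup.
    total_batch_size = total / num_tokens_per_bs if is_eagle else total
    index = bisect.bisect_left(capture_bs, total_batch_size)
    return capture_bs[index]


print(padded_capture_bs([3, 2], is_eagle=False))   # 5 requests  -> padded to 8
print(padded_capture_bs([12, 8], is_eagle=True))   # 20 tokens/4 -> 5 -> padded to 8
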
@@ -624,7 +658,7 @@ class CudaGraphRunner:
 
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
        if pp_proxy_tensors:
@@ -636,27 +670,28 @@ class CudaGraphRunner:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
        if enable_num_token_non_padded(self.model_runner.server_args):
            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
            )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
-            forward_batch.forward_mode,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
            forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
        )
 
        # Store fields
@@ -704,11 +739,7 @@ class CudaGraphRunner:
        else:
            spec_info = EagleVerifyInput(
                draft_token=None,
-                custom_mask=torch.ones(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                positions=None,
                retrive_index=None,
                retrive_next_token=None,