sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -92,7 +92,7 @@ def build_tree_kernel_efficient(
92
92
  sgl_build_tree_kernel_efficient(
93
93
  parent_list,
94
94
  top_scores_index,
95
- seq_lens.to(torch.int32),
95
+ seq_lens,
96
96
  tree_mask,
97
97
  positions,
98
98
  retrive_index,
@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
20
20
  ForwardMode,
21
21
  )
22
22
  from sglang.srt.speculative.eagle_utils import EagleDraftInput
23
+ from sglang.srt.utils import (
24
+ require_attn_tp_gather,
25
+ require_gathered_buffer,
26
+ require_mlp_sync,
27
+ require_mlp_tp_gather,
28
+ )
23
29
 
24
30
  if TYPE_CHECKING:
25
31
  from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -38,9 +44,18 @@ class EAGLEDraftCudaGraphRunner:
38
44
  self.output_buffers = {}
39
45
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
40
46
  self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
47
+ self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
48
+ self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
49
+ self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
50
+ self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
51
+ self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
52
+ self.dp_size = self.model_runner.dp_size
41
53
  self.tp_size = self.model_runner.tp_size
42
54
  self.topk = model_runner.server_args.speculative_eagle_topk
43
55
  self.speculative_num_steps = model_runner.server_args.speculative_num_steps
56
+ self.enable_profile_cuda_graph = (
57
+ model_runner.server_args.enable_profile_cuda_graph
58
+ )
44
59
  server_args = model_runner.server_args
45
60
 
46
61
  # Batch sizes to capture
@@ -50,7 +65,9 @@ class EAGLEDraftCudaGraphRunner:
50
65
  # Attention backend
51
66
  self.max_bs = max(self.capture_bs)
52
67
  self.max_num_token = self.max_bs * self.num_tokens_per_bs
53
- self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
68
+ self.model_runner.draft_attn_backend.init_cuda_graph_state(
69
+ self.max_bs, self.max_num_token
70
+ )
54
71
  self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
55
72
  0
56
73
  ].get_cuda_graph_seq_len_fill_value()
@@ -75,10 +92,32 @@ class EAGLEDraftCudaGraphRunner:
75
92
  self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
76
93
  self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
77
94
  self.hidden_states = torch.zeros(
78
- (self.max_num_token, self.model_runner.model_config.hidden_size),
95
+ (self.max_bs, self.model_runner.model_config.hidden_size),
79
96
  dtype=self.model_runner.dtype,
80
97
  )
81
98
 
99
+ if self.require_gathered_buffer:
100
+ self.gathered_buffer = torch.zeros(
101
+ (
102
+ self.max_num_token,
103
+ self.model_runner.model_config.hidden_size,
104
+ ),
105
+ dtype=self.model_runner.dtype,
106
+ )
107
+ if self.require_mlp_tp_gather:
108
+ self.global_num_tokens_gpu = torch.zeros(
109
+ (self.dp_size,), dtype=torch.int32
110
+ )
111
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
112
+ (self.dp_size,), dtype=torch.int32
113
+ )
114
+ else:
115
+ assert self.require_attn_tp_gather
116
+ self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
117
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
118
+ (1,), dtype=torch.int32
119
+ )
120
+
82
121
  # Capture
83
122
  try:
84
123
  with model_capture_mode():
@@ -89,11 +128,24 @@ class EAGLEDraftCudaGraphRunner:
89
128
  )
90
129
 
91
130
  def can_run(self, forward_batch: ForwardBatch):
131
+ if self.require_mlp_tp_gather:
132
+ cuda_graph_bs = (
133
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
134
+ if self.model_runner.spec_algorithm.is_eagle()
135
+ else sum(forward_batch.global_num_tokens_cpu)
136
+ )
137
+ else:
138
+ cuda_graph_bs = forward_batch.batch_size
139
+
92
140
  is_bs_supported = (
93
- forward_batch.batch_size in self.graphs
141
+ cuda_graph_bs in self.graphs
94
142
  if self.disable_padding
95
- else forward_batch.batch_size <= self.max_bs
143
+ else cuda_graph_bs <= self.max_bs
96
144
  )
145
+
146
+ if self.require_mlp_sync:
147
+ is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
148
+
97
149
  return is_bs_supported
98
150
 
99
151
  def capture(self):
@@ -113,10 +165,58 @@ class EAGLEDraftCudaGraphRunner:
113
165
  topk_index = self.topk_index[:num_seqs]
114
166
  hidden_states = self.hidden_states[:num_seqs]
115
167
 
168
+ if self.require_mlp_tp_gather:
169
+ self.global_num_tokens_gpu.copy_(
170
+ torch.tensor(
171
+ [
172
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
173
+ for i in range(self.dp_size)
174
+ ],
175
+ dtype=torch.int32,
176
+ device=self.input_ids.device,
177
+ )
178
+ )
179
+ self.global_num_tokens_for_logprob_gpu.copy_(
180
+ torch.tensor(
181
+ [
182
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
183
+ for i in range(self.dp_size)
184
+ ],
185
+ dtype=torch.int32,
186
+ device=self.input_ids.device,
187
+ )
188
+ )
189
+ global_num_tokens = self.global_num_tokens_gpu
190
+ gathered_buffer = self.gathered_buffer[:num_tokens]
191
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
192
+ elif self.require_attn_tp_gather:
193
+ self.global_num_tokens_gpu.copy_(
194
+ torch.tensor(
195
+ [num_tokens],
196
+ dtype=torch.int32,
197
+ device=self.input_ids.device,
198
+ )
199
+ )
200
+ self.global_num_tokens_for_logprob_gpu.copy_(
201
+ torch.tensor(
202
+ [num_tokens],
203
+ dtype=torch.int32,
204
+ device=self.input_ids.device,
205
+ )
206
+ )
207
+ global_num_tokens = self.global_num_tokens_gpu
208
+ gathered_buffer = self.gathered_buffer[:num_tokens]
209
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
210
+ else:
211
+ global_num_tokens = None
212
+ gathered_buffer = None
213
+ global_num_tokens_for_logprob = None
214
+
116
215
  spec_info = EagleDraftInput(
117
216
  topk_p=topk_p,
118
217
  topk_index=topk_index,
119
218
  hidden_states=hidden_states,
219
+ capture_hidden_mode=CaptureHiddenMode.LAST,
120
220
  )
121
221
 
122
222
  # Forward batch
@@ -132,11 +232,14 @@ class EAGLEDraftCudaGraphRunner:
132
232
  seq_lens_sum=seq_lens.sum().item(),
133
233
  return_logprob=False,
134
234
  positions=positions,
235
+ global_num_tokens_gpu=global_num_tokens,
236
+ gathered_buffer=gathered_buffer,
135
237
  spec_algorithm=self.model_runner.spec_algorithm,
136
238
  spec_info=spec_info,
137
239
  capture_hidden_mode=(
138
240
  spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
139
241
  ),
242
+ global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
140
243
  )
141
244
 
142
245
  # Attention backend
@@ -146,6 +249,9 @@ class EAGLEDraftCudaGraphRunner:
146
249
 
147
250
  # Run and capture
148
251
  def run_once():
252
+ # Clean intermediate result cache for DP attention
253
+ forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
254
+
149
255
  # Backup two fields, which will be modified in-place in `draft_forward`.
150
256
  output_cache_loc_backup = forward_batch.out_cache_loc
151
257
  hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -183,12 +289,19 @@ class EAGLEDraftCudaGraphRunner:
183
289
  raw_num_token = raw_bs * self.num_tokens_per_bs
184
290
 
185
291
  # Pad
186
- index = bisect.bisect_left(self.capture_bs, raw_bs)
292
+ if self.require_mlp_tp_gather:
293
+ total_batch_size = (
294
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
295
+ if self.model_runner.spec_algorithm.is_eagle()
296
+ else sum(forward_batch.global_num_tokens_cpu)
297
+ )
298
+ index = bisect.bisect_left(self.capture_bs, total_batch_size)
299
+ else:
300
+ index = bisect.bisect_left(self.capture_bs, raw_bs)
187
301
  bs = self.capture_bs[index]
188
302
  if bs != raw_bs:
189
- self.seq_lens.fill_(1)
303
+ self.seq_lens.fill_(self.seq_len_fill_value)
190
304
  self.out_cache_loc.zero_()
191
- self.positions.zero_()
192
305
 
193
306
  num_tokens = bs * self.num_tokens_per_bs
194
307
 
@@ -203,6 +316,13 @@ class EAGLEDraftCudaGraphRunner:
203
316
  self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
204
317
  self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
205
318
 
319
+ if self.require_gathered_buffer:
320
+ self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
321
+ self.global_num_tokens_for_logprob_gpu.copy_(
322
+ forward_batch.global_num_tokens_for_logprob_gpu
323
+ )
324
+ forward_batch.gathered_buffer = self.gathered_buffer
325
+
206
326
  # Attention backend
207
327
  if bs != raw_bs:
208
328
  forward_batch.batch_size = bs
@@ -211,14 +331,16 @@ class EAGLEDraftCudaGraphRunner:
211
331
  forward_batch.positions = self.positions[:num_tokens]
212
332
 
213
333
  # Special handle for seq_len_cpu used when flashinfer mla is used
214
- if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
215
- self.seq_lens_cpu.fill_(1)
334
+ if forward_batch.seq_lens_cpu is not None:
335
+ if bs != raw_bs:
336
+ self.seq_lens_cpu.fill_(self.seq_len_fill_value)
216
337
  self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
217
338
  forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
218
339
 
219
340
  self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
220
341
  forward_batch, bs
221
342
  )
343
+ # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
222
344
 
223
345
  # Replay
224
346
  self.graphs[bs].replay()
@@ -21,6 +21,12 @@ from sglang.srt.model_executor.forward_batch_info import (
21
21
  ForwardMode,
22
22
  )
23
23
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
24
+ from sglang.srt.utils import (
25
+ require_attn_tp_gather,
26
+ require_gathered_buffer,
27
+ require_mlp_sync,
28
+ require_mlp_tp_gather,
29
+ )
24
30
 
25
31
  if TYPE_CHECKING:
26
32
  from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,10 +41,17 @@ class EAGLEDraftExtendCudaGraphRunner:
35
41
  self.output_buffers = {}
36
42
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
37
43
  self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
44
+ self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
45
+ self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
46
+ self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
47
+ self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
38
48
  self.tp_size = self.model_runner.tp_size
39
49
  self.dp_size = model_runner.server_args.dp_size
40
50
  self.speculative_num_steps = model_runner.server_args.speculative_num_steps
41
51
  self.topk = model_runner.server_args.speculative_eagle_topk
52
+ self.enable_profile_cuda_graph = (
53
+ model_runner.server_args.enable_profile_cuda_graph
54
+ )
42
55
  self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
43
56
  self.padded_static_len = -1
44
57
 
@@ -48,7 +61,7 @@ class EAGLEDraftExtendCudaGraphRunner:
48
61
  self.max_num_token = self.max_bs * self.num_tokens_per_bs
49
62
 
50
63
  self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
51
- self.max_num_token
64
+ self.max_bs, self.max_num_token
52
65
  )
53
66
  self.seq_len_fill_value = (
54
67
  self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -83,10 +96,31 @@ class EAGLEDraftExtendCudaGraphRunner:
83
96
 
84
97
  self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
85
98
  self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
86
- self.accept_length = (
87
- torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
99
+ self.accept_length = torch.full(
100
+ (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
88
101
  )
89
102
 
103
+ if self.require_gathered_buffer:
104
+ self.gathered_buffer = torch.zeros(
105
+ (
106
+ self.max_num_token,
107
+ self.model_runner.model_config.hidden_size,
108
+ ),
109
+ dtype=self.model_runner.dtype,
110
+ )
111
+ if self.require_mlp_tp_gather:
112
+ self.global_num_tokens_gpu = torch.zeros(
113
+ (self.dp_size,), dtype=torch.int32
114
+ )
115
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
116
+ (self.dp_size,), dtype=torch.int32
117
+ )
118
+ else:
119
+ assert self.require_attn_tp_gather
120
+ self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
121
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
122
+ (1,), dtype=torch.int32
123
+ )
90
124
  # Capture
91
125
  try:
92
126
  with model_capture_mode():
@@ -97,14 +131,24 @@ class EAGLEDraftExtendCudaGraphRunner:
97
131
  )
98
132
 
99
133
  def can_run(self, forward_batch: ForwardBatch):
100
- batch_size = forward_batch.seq_lens.numel()
134
+ if self.require_mlp_tp_gather:
135
+ cuda_graph_bs = (
136
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
137
+ if self.model_runner.spec_algorithm.is_eagle()
138
+ else sum(forward_batch.global_num_tokens_cpu)
139
+ )
140
+ else:
141
+ cuda_graph_bs = forward_batch.seq_lens.numel()
101
142
 
102
143
  is_bs_supported = (
103
- batch_size in self.graphs
144
+ cuda_graph_bs in self.graphs
104
145
  if self.disable_padding
105
- else batch_size <= self.max_bs
146
+ else cuda_graph_bs <= self.max_bs
106
147
  )
107
148
 
149
+ if self.require_mlp_sync:
150
+ is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
151
+
108
152
  return is_bs_supported
109
153
 
110
154
  def capture(self):
@@ -125,6 +169,53 @@ class EAGLEDraftExtendCudaGraphRunner:
125
169
  positions = self.positions[:num_tokens]
126
170
  hidden_states = self.hidden_states[:num_tokens]
127
171
 
172
+ if self.require_mlp_tp_gather:
173
+ self.global_num_tokens_gpu.copy_(
174
+ torch.tensor(
175
+ [
176
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
177
+ for i in range(self.dp_size)
178
+ ],
179
+ dtype=torch.int32,
180
+ device=self.input_ids.device,
181
+ )
182
+ )
183
+ self.global_num_tokens_for_logprob_gpu.copy_(
184
+ torch.tensor(
185
+ [
186
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
187
+ for i in range(self.dp_size)
188
+ ],
189
+ dtype=torch.int32,
190
+ device=self.input_ids.device,
191
+ )
192
+ )
193
+ global_num_tokens = self.global_num_tokens_gpu
194
+ gathered_buffer = self.gathered_buffer[:num_tokens]
195
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
196
+ elif self.require_attn_tp_gather:
197
+ self.global_num_tokens_gpu.copy_(
198
+ torch.tensor(
199
+ [num_tokens],
200
+ dtype=torch.int32,
201
+ device=self.input_ids.device,
202
+ )
203
+ )
204
+ self.global_num_tokens_for_logprob_gpu.copy_(
205
+ torch.tensor(
206
+ [num_tokens],
207
+ dtype=torch.int32,
208
+ device=self.input_ids.device,
209
+ )
210
+ )
211
+ global_num_tokens = self.global_num_tokens_gpu
212
+ gathered_buffer = self.gathered_buffer[:num_tokens]
213
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
214
+ else:
215
+ global_num_tokens = None
216
+ gathered_buffer = None
217
+ global_num_tokens_for_logprob = None
218
+
128
219
  spec_info = EagleDraftInput(
129
220
  hidden_states=hidden_states,
130
221
  accept_length=accept_length,
@@ -144,6 +235,9 @@ class EAGLEDraftExtendCudaGraphRunner:
144
235
  seq_lens_sum=seq_lens.sum().item(),
145
236
  return_logprob=False,
146
237
  positions=positions,
238
+ global_num_tokens_gpu=global_num_tokens,
239
+ global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
240
+ gathered_buffer=gathered_buffer,
147
241
  spec_algorithm=self.model_runner.spec_algorithm,
148
242
  spec_info=spec_info,
149
243
  capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -164,6 +258,9 @@ class EAGLEDraftExtendCudaGraphRunner:
164
258
 
165
259
  # Run and capture
166
260
  def run_once():
261
+ # Clean intermediate result cache for DP attention
262
+ forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
263
+
167
264
  # Backup two fields, which will be modified in-place in `draft_forward`.
168
265
  output_cache_loc_backup = forward_batch.out_cache_loc
169
266
  hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -200,38 +297,57 @@ class EAGLEDraftExtendCudaGraphRunner:
200
297
  # in the batch, which will not be counted as num_seqs
201
298
  raw_bs = forward_batch.batch_size
202
299
  num_tokens = forward_batch.input_ids.shape[0]
300
+ if self.require_mlp_tp_gather:
301
+ total_batch_size = (
302
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
303
+ if self.model_runner.spec_algorithm.is_eagle()
304
+ else sum(forward_batch.global_num_tokens_cpu)
305
+ )
306
+ index = bisect.bisect_left(self.capture_bs, total_batch_size)
307
+ else:
308
+ index = bisect.bisect_left(self.capture_bs, raw_bs)
203
309
 
204
- index = bisect.bisect_left(self.capture_bs, raw_bs)
205
310
  bs = self.capture_bs[index]
206
311
  if bs * self.num_tokens_per_bs != num_tokens:
207
- self.seq_lens.fill_(1)
208
- self.accept_length.fill_(1)
312
+ self.seq_lens.fill_(self.seq_len_fill_value)
209
313
  self.out_cache_loc.zero_()
314
+ self.accept_length.fill_(1)
315
+ self.extend_seq_lens.fill_(1)
210
316
 
211
317
  # Common inputs
212
318
  self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
213
319
  self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
214
- self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
320
+ if forward_batch.extend_seq_lens is not None:
321
+ self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
215
322
  self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
216
323
  self.positions[:num_tokens].copy_(forward_batch.positions)
217
324
  self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
218
- self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
325
+ if forward_batch.spec_info.accept_length is not None:
326
+ self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
219
327
  self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
220
328
 
329
+ if self.require_gathered_buffer:
330
+ self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
331
+ self.global_num_tokens_for_logprob_gpu.copy_(
332
+ forward_batch.global_num_tokens_for_logprob_gpu
333
+ )
334
+ forward_batch.gathered_buffer = self.gathered_buffer
335
+
221
336
  if forward_batch.seq_lens_cpu is not None:
222
337
  if bs != raw_bs:
223
- self.seq_lens_cpu.fill_(1)
338
+ self.seq_lens_cpu.fill_(self.seq_len_fill_value)
224
339
  self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
225
340
 
226
341
  if bs != raw_bs:
342
+ forward_batch.spec_info.positions = self.positions[:num_tokens]
227
343
  forward_batch.spec_info.accept_length = self.accept_length[:bs]
228
- forward_batch.spec_info.positions = None
229
344
 
230
345
  self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
231
346
  bs=bs,
232
347
  req_pool_indices=self.req_pool_indices,
233
348
  seq_lens=self.seq_lens,
234
- seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
349
+ seq_lens_sum=forward_batch.seq_lens_sum
350
+ + (bs - raw_bs) * self.seq_len_fill_value,
235
351
  encoder_lens=None,
236
352
  forward_mode=ForwardMode.DRAFT_EXTEND,
237
353
  spec_info=forward_batch.spec_info,