sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
5
5
 
6
6
  import torch
7
7
 
8
+ from sglang.srt.layers.dp_attention import DPPaddingMode
8
9
  from sglang.srt.model_executor.cuda_graph_runner import (
9
10
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
10
11
  CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
97
98
  )
98
99
 
99
100
  if self.require_gathered_buffer:
100
- self.gathered_buffer = torch.zeros(
101
- (
102
- self.max_num_token,
103
- self.model_runner.model_config.hidden_size,
104
- ),
105
- dtype=self.model_runner.dtype,
106
- )
107
101
  if self.require_mlp_tp_gather:
108
102
  self.global_num_tokens_gpu = torch.zeros(
109
103
  (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
111
105
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
112
106
  (self.dp_size,), dtype=torch.int32
113
107
  )
108
+ self.gathered_buffer = torch.zeros(
109
+ (
110
+ self.max_num_token * self.dp_size,
111
+ self.model_runner.model_config.hidden_size,
112
+ ),
113
+ dtype=self.model_runner.dtype,
114
+ )
114
115
  else:
115
116
  assert self.require_attn_tp_gather
116
117
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
117
118
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
118
119
  (1,), dtype=torch.int32
119
120
  )
121
+ self.gathered_buffer = torch.zeros(
122
+ (
123
+ self.max_num_token,
124
+ self.model_runner.model_config.hidden_size,
125
+ ),
126
+ dtype=self.model_runner.dtype,
127
+ )
128
+ else:
129
+ self.global_num_tokens_gpu = None
130
+ self.global_num_tokens_for_logprob_gpu = None
131
+ self.gathered_buffer = None
120
132
 
121
133
  # Capture
122
134
  try:
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
130
142
  def can_run(self, forward_batch: ForwardBatch):
131
143
  if self.require_mlp_tp_gather:
132
144
  cuda_graph_bs = (
133
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
145
+ max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
134
146
  if self.model_runner.spec_algorithm.is_eagle()
135
- else sum(forward_batch.global_num_tokens_cpu)
147
+ else max(forward_batch.global_num_tokens_cpu)
136
148
  )
137
149
  else:
138
150
  cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
168
180
  if self.require_mlp_tp_gather:
169
181
  self.global_num_tokens_gpu.copy_(
170
182
  torch.tensor(
171
- [
172
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
173
- for i in range(self.dp_size)
174
- ],
183
+ [num_tokens] * self.dp_size,
175
184
  dtype=torch.int32,
176
185
  device=self.input_ids.device,
177
186
  )
178
187
  )
179
188
  self.global_num_tokens_for_logprob_gpu.copy_(
180
189
  torch.tensor(
181
- [
182
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
183
- for i in range(self.dp_size)
184
- ],
190
+ [num_tokens] * self.dp_size,
185
191
  dtype=torch.int32,
186
192
  device=self.input_ids.device,
187
193
  )
188
194
  )
189
195
  global_num_tokens = self.global_num_tokens_gpu
190
- gathered_buffer = self.gathered_buffer[:num_tokens]
196
+ gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
191
197
  global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
192
198
  elif self.require_attn_tp_gather:
193
199
  self.global_num_tokens_gpu.copy_(
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
233
239
  return_logprob=False,
234
240
  positions=positions,
235
241
  global_num_tokens_gpu=global_num_tokens,
242
+ dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
236
243
  gathered_buffer=gathered_buffer,
237
244
  spec_algorithm=self.model_runner.spec_algorithm,
238
245
  spec_info=spec_info,
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
290
297
 
291
298
  # Pad
292
299
  if self.require_mlp_tp_gather:
293
- total_batch_size = (
294
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
300
+ max_num_tokens = max(forward_batch.global_num_tokens_cpu)
301
+ max_batch_size = (
302
+ max_num_tokens // self.num_tokens_per_bs
295
303
  if self.model_runner.spec_algorithm.is_eagle()
296
- else sum(forward_batch.global_num_tokens_cpu)
304
+ else max_num_tokens
297
305
  )
298
- index = bisect.bisect_left(self.capture_bs, total_batch_size)
306
+ index = bisect.bisect_left(self.capture_bs, max_batch_size)
299
307
  else:
300
308
  index = bisect.bisect_left(self.capture_bs, raw_bs)
301
309
  bs = self.capture_bs[index]
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
316
324
  self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
317
325
  self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
318
326
 
327
+ # TODO(ch-wan): support num_token_non_padded
319
328
  if self.require_gathered_buffer:
320
- self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
321
- self.global_num_tokens_for_logprob_gpu.copy_(
322
- forward_batch.global_num_tokens_for_logprob_gpu
323
- )
324
- forward_batch.gathered_buffer = self.gathered_buffer
329
+ self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
330
+ self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
325
331
 
326
332
  # Attention backend
327
333
  if bs != raw_bs:
@@ -330,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
330
336
  forward_batch.req_pool_indices = self.req_pool_indices[:bs]
331
337
  forward_batch.positions = self.positions[:num_tokens]
332
338
 
333
- # Special handle for seq_len_cpu used when flashinfer mla is used
334
339
  if forward_batch.seq_lens_cpu is not None:
335
340
  if bs != raw_bs:
336
341
  self.seq_lens_cpu.fill_(self.seq_len_fill_value)
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
5
5
 
6
6
  import torch
7
7
 
8
+ from sglang.srt.layers.dp_attention import DPPaddingMode
8
9
  from sglang.srt.model_executor.cuda_graph_runner import (
9
10
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
10
11
  CudaGraphRunner,
@@ -84,7 +85,15 @@ class EAGLEDraftExtendCudaGraphRunner:
84
85
  self.hidden_states = torch.zeros(
85
86
  (
86
87
  self.max_num_token,
87
- self.model_runner.model_config.hidden_size * 3,
88
+ (
89
+ self.model_runner.model_config.hf_config.target_hidden_size
90
+ * 3
91
+ if hasattr(
92
+ self.model_runner.model_config.hf_config,
93
+ "target_hidden_size",
94
+ )
95
+ else self.model_runner.model_config.hidden_size * 3
96
+ ),
88
97
  ),
89
98
  dtype=self.model_runner.dtype,
90
99
  )
@@ -101,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
101
110
  )
102
111
 
103
112
  if self.require_gathered_buffer:
104
- self.gathered_buffer = torch.zeros(
105
- (
106
- self.max_num_token,
107
- self.model_runner.model_config.hidden_size,
108
- ),
109
- dtype=self.model_runner.dtype,
110
- )
111
113
  if self.require_mlp_tp_gather:
112
114
  self.global_num_tokens_gpu = torch.zeros(
113
115
  (self.dp_size,), dtype=torch.int32
@@ -115,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
115
117
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
116
118
  (self.dp_size,), dtype=torch.int32
117
119
  )
120
+ self.gathered_buffer = torch.zeros(
121
+ (
122
+ self.max_num_token * self.dp_size,
123
+ self.model_runner.model_config.hidden_size,
124
+ ),
125
+ dtype=self.model_runner.dtype,
126
+ )
118
127
  else:
119
128
  assert self.require_attn_tp_gather
120
129
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
121
130
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
122
131
  (1,), dtype=torch.int32
123
132
  )
133
+ self.gathered_buffer = torch.zeros(
134
+ (
135
+ self.max_num_token,
136
+ self.model_runner.model_config.hidden_size,
137
+ ),
138
+ dtype=self.model_runner.dtype,
139
+ )
140
+ else:
141
+ self.global_num_tokens_gpu = None
142
+ self.global_num_tokens_for_logprob_gpu = None
143
+ self.gathered_buffer = None
144
+
124
145
  # Capture
125
146
  try:
126
147
  with model_capture_mode():
@@ -133,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
133
154
  def can_run(self, forward_batch: ForwardBatch):
134
155
  if self.require_mlp_tp_gather:
135
156
  cuda_graph_bs = (
136
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
157
+ max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
137
158
  if self.model_runner.spec_algorithm.is_eagle()
138
- else sum(forward_batch.global_num_tokens_cpu)
159
+ else max(forward_batch.global_num_tokens_cpu)
139
160
  )
140
161
  else:
141
162
  cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -172,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
172
193
  if self.require_mlp_tp_gather:
173
194
  self.global_num_tokens_gpu.copy_(
174
195
  torch.tensor(
175
- [
176
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
177
- for i in range(self.dp_size)
178
- ],
196
+ [num_tokens] * self.dp_size,
179
197
  dtype=torch.int32,
180
198
  device=self.input_ids.device,
181
199
  )
182
200
  )
183
201
  self.global_num_tokens_for_logprob_gpu.copy_(
184
202
  torch.tensor(
185
- [
186
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
187
- for i in range(self.dp_size)
188
- ],
203
+ [bs] * self.dp_size,
189
204
  dtype=torch.int32,
190
205
  device=self.input_ids.device,
191
206
  )
192
207
  )
193
- global_num_tokens = self.global_num_tokens_gpu
194
- gathered_buffer = self.gathered_buffer[:num_tokens]
195
- global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
208
+ gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
196
209
  elif self.require_attn_tp_gather:
197
210
  self.global_num_tokens_gpu.copy_(
198
211
  torch.tensor(
@@ -203,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
203
216
  )
204
217
  self.global_num_tokens_for_logprob_gpu.copy_(
205
218
  torch.tensor(
206
- [num_tokens],
219
+ [bs],
207
220
  dtype=torch.int32,
208
221
  device=self.input_ids.device,
209
222
  )
210
223
  )
211
- global_num_tokens = self.global_num_tokens_gpu
212
224
  gathered_buffer = self.gathered_buffer[:num_tokens]
213
- global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
214
225
  else:
215
- global_num_tokens = None
216
226
  gathered_buffer = None
217
- global_num_tokens_for_logprob = None
218
227
 
219
228
  spec_info = EagleDraftInput(
220
229
  hidden_states=hidden_states,
@@ -235,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
235
244
  seq_lens_sum=seq_lens.sum().item(),
236
245
  return_logprob=False,
237
246
  positions=positions,
238
- global_num_tokens_gpu=global_num_tokens,
239
- global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
247
+ global_num_tokens_gpu=self.global_num_tokens_gpu,
248
+ global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
249
+ dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
240
250
  gathered_buffer=gathered_buffer,
241
251
  spec_algorithm=self.model_runner.spec_algorithm,
242
252
  spec_info=spec_info,
@@ -298,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
298
308
  raw_bs = forward_batch.batch_size
299
309
  num_tokens = forward_batch.input_ids.shape[0]
300
310
  if self.require_mlp_tp_gather:
301
- total_batch_size = (
302
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
311
+ max_num_tokens = max(forward_batch.global_num_tokens_cpu)
312
+ max_batch_size = (
313
+ max_num_tokens // self.num_tokens_per_bs
303
314
  if self.model_runner.spec_algorithm.is_eagle()
304
- else sum(forward_batch.global_num_tokens_cpu)
315
+ else max_num_tokens
305
316
  )
306
- index = bisect.bisect_left(self.capture_bs, total_batch_size)
317
+ index = bisect.bisect_left(self.capture_bs, max_batch_size)
307
318
  else:
308
319
  index = bisect.bisect_left(self.capture_bs, raw_bs)
309
320
 
@@ -326,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
326
337
  self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
327
338
  self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
328
339
 
340
+ # TODO(ch-wan): support num_token_non_padded
329
341
  if self.require_gathered_buffer:
330
- self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
331
- self.global_num_tokens_for_logprob_gpu.copy_(
332
- forward_batch.global_num_tokens_for_logprob_gpu
333
- )
334
- forward_batch.gathered_buffer = self.gathered_buffer
342
+ self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
343
+ self.global_num_tokens_for_logprob_gpu.fill_(bs)
335
344
 
336
345
  if forward_batch.seq_lens_cpu is not None:
337
346
  if bs != raw_bs:
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import logging
4
5
  import os
5
6
  import time
@@ -70,9 +71,20 @@ class EagleDraftInput:
70
71
  kv_indptr: torch.Tensor = None
71
72
  kv_indices: torch.Tensor = None
72
73
 
74
+ # Shape info for padding
75
+ num_tokens_per_batch: int = -1
76
+ num_tokens_for_logprob_per_batch: int = -1
77
+
78
+ # Inputs for draft extend
79
+ # shape: (b,)
80
+ seq_lens_for_draft_extend: torch.Tensor = None
81
+ req_pool_indices_for_draft_extend: torch.Tensor = None
82
+
73
83
  def prepare_for_extend(self, batch: ScheduleBatch):
84
+
74
85
  if batch.forward_mode.is_idle():
75
86
  return
87
+
76
88
  # Prefill only generate 1 token.
77
89
  assert len(self.verified_id) == len(batch.seq_lens)
78
90
 
@@ -94,7 +106,7 @@ class EagleDraftInput:
94
106
  capture_hidden_mode: CaptureHiddenMode,
95
107
  ):
96
108
  return cls(
97
- verified_id=None,
109
+ verified_id=torch.empty((0,), device=device, dtype=torch.int32),
98
110
  hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
99
111
  topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
100
112
  topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@@ -108,7 +120,10 @@ class EagleDraftInput:
108
120
  batch: ScheduleBatch,
109
121
  speculative_num_steps: int,
110
122
  ):
111
- batch.forward_mode = ForwardMode.DRAFT_EXTEND
123
+
124
+ if batch.forward_mode.is_idle():
125
+ return
126
+
112
127
  batch.input_ids = self.verified_id
113
128
  batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
114
129
  batch.extend_num_tokens = sum(batch.extend_lens)
@@ -315,7 +330,7 @@ class EagleVerifyInput:
315
330
  def verify(
316
331
  self,
317
332
  batch: ScheduleBatch,
318
- logits_output: torch.Tensor,
333
+ logits_output: LogitsProcessorOutput,
319
334
  token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
320
335
  page_size: int,
321
336
  vocab_mask: Optional[torch.Tensor] = None, # For grammar
@@ -362,6 +377,11 @@ class EagleVerifyInput:
362
377
  )
363
378
  accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
364
379
 
380
+ if bs != len(sampling_info):
381
+ sampling_info = copy.deepcopy(sampling_info)
382
+ # NOTE: retrive_index are the indices of the requests that are kept.
383
+ sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
384
+
365
385
  # Apply the custom logit processors if registered in the sampling info.
366
386
  if sampling_info.has_custom_logit_processor:
367
387
  apply_custom_logit_processor(
@@ -593,13 +613,14 @@ class EagleVerifyInput:
593
613
  batch.out_cache_loc = tgt_cache_loc
594
614
  batch.seq_lens.add_(accept_length + 1)
595
615
 
596
- draft_input = EagleDraftInput()
597
- draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
598
- draft_input.verified_id = verified_id
599
- draft_input.accept_length = accept_length
600
- draft_input.accept_length_cpu = accept_length.tolist()
601
- draft_input.seq_lens_for_draft_extend = batch.seq_lens
602
- draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
616
+ draft_input = EagleDraftInput(
617
+ hidden_states=batch.spec_info.hidden_states[accept_index],
618
+ verified_id=verified_id,
619
+ accept_length=accept_length,
620
+ accept_length_cpu=accept_length.tolist(),
621
+ seq_lens_for_draft_extend=batch.seq_lens,
622
+ req_pool_indices_for_draft_extend=batch.req_pool_indices,
623
+ )
603
624
 
604
625
  return EagleVerifyOutput(
605
626
  draft_input=draft_input,
@@ -622,7 +643,6 @@ class EagleVerifyInput:
622
643
  batch.seq_lens.add_(accept_length + 1)
623
644
 
624
645
  accept_length_cpu = accept_length.tolist()
625
- draft_input = EagleDraftInput()
626
646
  if len(unfinished_accept_index) > 0:
627
647
  unfinished_accept_index = torch.cat(unfinished_accept_index)
628
648
  unfinished_index_device = torch.tensor(
@@ -653,18 +673,26 @@ class EagleVerifyInput:
653
673
  next_power_of_2(self.draft_token_num),
654
674
  )
655
675
 
656
- draft_input.hidden_states = batch.spec_info.hidden_states[
657
- unfinished_accept_index
658
- ]
659
- draft_input.verified_id = predict[unfinished_accept_index]
660
- draft_input.accept_length_cpu = draft_input_accept_length_cpu
661
- draft_input.accept_length = accept_length[unfinished_index_device]
662
- draft_input.seq_lens_for_draft_extend = batch.seq_lens[
663
- unfinished_index_device
664
- ]
665
- draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
666
- unfinished_index_device
667
- ]
676
+ draft_input = EagleDraftInput(
677
+ hidden_states=batch.spec_info.hidden_states[
678
+ unfinished_accept_index
679
+ ],
680
+ verified_id=predict[unfinished_accept_index],
681
+ accept_length_cpu=draft_input_accept_length_cpu,
682
+ accept_length=accept_length[unfinished_index_device],
683
+ seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
684
+ req_pool_indices_for_draft_extend=batch.req_pool_indices[
685
+ unfinished_index_device
686
+ ],
687
+ )
688
+ else:
689
+ draft_input = EagleDraftInput.create_idle_input(
690
+ device=batch.device,
691
+ hidden_size=batch.model_config.hidden_size,
692
+ dtype=batch.model_config.dtype,
693
+ topk=self.topk,
694
+ capture_hidden_mode=CaptureHiddenMode.LAST,
695
+ )
668
696
 
669
697
  return EagleVerifyOutput(
670
698
  draft_input=draft_input,