sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/build_eagle_tree.py

@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available, is_hip
+
+if is_cuda_available() or is_hip():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
-
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
 
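The first hunk guards the sgl_kernel import so that merely importing this module no longer fails on hosts without a CUDA or ROCm build (the old build_tree_kernel alias also disappears; its removal follows below). A minimal sketch of the same guarded-import pattern; the two device checks are simplified stand-ins for sglang's is_cuda_available/is_hip helpers in sglang.srt.utils, and the fallback stub is hypothetical, not sglang code:

import torch

def is_cuda_available() -> bool:
    # Simplified stand-in for sglang.srt.utils.is_cuda_available.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_hip() -> bool:
    # torch.version.hip is None on non-ROCm builds.
    return torch.version.hip is not None

if is_cuda_available() or is_hip():
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )
else:
    # Hypothetical stub so the module still imports on CPU-only hosts.
    def sgl_build_tree_kernel_efficient(*args, **kwargs):
        raise RuntimeError("sgl_build_tree_kernel_efficient requires CUDA or ROCm")

Note the shipped code has no else branch: on unsupported platforms the name is simply left undefined, which only matters if the GPU-only EAGLE path is actually reached.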
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [
-                0,
-                1,
-                2,
-                3,
-                10,
-                11,
-                12,
-                13,
-                20,
-                21,
-                22,
-                30,
-                110,
-                130,
-                150,
-                160,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                217,
-                218,
-                219,
-                220,
-                230,
-                310,
-                311,
-                312,
-                313,
-                314,
-                315,
-                316,
-                317,
-                320,
-                321,
-                322,
-                330,
-                360,
-                380,
-                390,
-                410,
-                411,
-                412,
-                413,
-                414,
-                415,
-                416,
-                417,
-                418,
-                419,
-                420,
-                421,
-                422,
-                423,
-                430,
-                431,
-                440,
-                441,
-                460,
-                470,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [
-                -1,
-                0,
-                1,
-                2,
-                3,
-                4,
-                5,
-                6,
-                7,
-                8,
-                9,
-                10,
-                11,
-                12,
-                20,
-                30,
-                21,
-                13,
-                22,
-                40,
-                23,
-                110,
-                130,
-                160,
-                150,
-                190,
-                120,
-                111,
-                121,
-                200,
-                180,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                220,
-                230,
-                217,
-                310,
-                311,
-                312,
-                313,
-                320,
-                314,
-                321,
-                315,
-                316,
-                317,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974,
-        29896,
-        29906,
-        29889,
-        29974,
-        29946,
-        29896,
-        29946,
-        13,
-        13,
-        22550,
-        4136,
-        16492,
-        8439,
-        29871,
-        29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@
 
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
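To make the surviving code path concrete: build_tree_kernel_efficient_preprocess (second hunk above) keeps the num_verify_tokens - 1 highest-scoring draft candidates, re-sorts the winning indices into positional order, gathers their token ids, and prepends the last verified token as the tree root. A toy reproduction of that selection, with made-up scores and token ids:

import torch

# One request, six flattened draft candidates with cumulative scores.
score_list = torch.tensor([[0.9, 0.1, 0.7, 0.2, 0.8, 0.3]])
ss_token_list = torch.tensor([[101, 102, 103, 104, 105, 106]])
verified_id = torch.tensor([100])  # last verified token per request
num_verify_tokens = 4              # tree size, including the verified root

# Same steps as build_tree_kernel_efficient_preprocess.
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
top_scores_index = torch.sort(top_scores.indices).values

draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
print(draft_tokens)  # tensor([100, 101, 103, 105])

Sorting the indices before the gather keeps the selected tokens in the order the draft model proposed them rather than in score order.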
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
 
-        assert self.disable_padding
-
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -51,6 +52,9 @@
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -169,6 +173,13 @@
             set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +191,9 @@
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +207,33 @@
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
 
-        return self.output_buffers[bs]
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+            if forward_batch.decode_seq_lens_cpu is not None:
+                forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+
+        return out
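Taken together, these hunks drop the old `assert self.disable_padding` restriction: a decode batch smaller than every captured graph size is now padded up to the nearest captured size bs, the graph is replayed, and both the output lists and the mutated forward_batch fields are sliced back to the real raw_bs (that is what _postprocess_output_to_raw_bs does). A minimal CPU-only sketch of the pad-replay-truncate pattern; the buffer names and the stand-in "graph" computation are illustrative, not sglang APIs:

import torch

CAPTURED_BS = [1, 2, 4, 8]   # batch sizes the graphs were captured for
MAX_BS = CAPTURED_BS[-1]

# Static buffers owned by the runner; a captured graph always reads and
# writes the first `bs` rows of these.
seq_lens = torch.ones(MAX_BS, dtype=torch.int32)
hidden = torch.zeros(MAX_BS, 16)
output_buf = torch.zeros(MAX_BS, 16)

def run_padded(real_hidden: torch.Tensor) -> torch.Tensor:
    raw_bs = real_hidden.shape[0]
    bs = next(b for b in CAPTURED_BS if b >= raw_bs)  # smallest graph that fits

    if bs != raw_bs:
        # Reset padding rows to benign values, as replay() does with
        # seq_lens.fill_(1) / out_cache_loc.zero_() / positions.zero_().
        seq_lens.fill_(1)
        hidden.zero_()
    hidden[:raw_bs].copy_(real_hidden)

    # Stand-in for self.graphs[bs].replay(): compute over the padded slice,
    # writing into the preallocated output buffer.
    output_buf[:bs] = hidden[:bs] * 2.0

    # Slice back to the real batch size, like _postprocess_output_to_raw_bs.
    return output_buf[:raw_bs]

print(run_padded(torch.ones(3, 16)).shape)  # torch.Size([3, 16])

The same truncation is applied to the forward_batch views that were temporarily widened for the attention-backend metadata call, so callers never observe the padded shapes.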