sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
@@ -22,29 +22,55 @@ if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

- from flash_attn_interface import flash_attn_with_kvcache
+ from sgl_kernel.flash_attn import flash_attn_with_kvcache


  @dataclass
  class FlashAttentionMetadata:
- """Metadata for decode operations to avoid redundant computations."""
+ """Metadata to be init once in the model forward pass,
+ each layer's forward pass can reuse the metadata."""

+ # Cumulative sequence lengths for query
  cu_seqlens_q: torch.Tensor = None
+ # Cumulative sequence lengths for key
  cu_seqlens_k: torch.Tensor = None
+ # Maximum sequence length for query
  max_seq_len_q: int = 0
+ # Maximum sequence length for key
  max_seq_len_k: int = 0
+ # Window size (typically used by Gemma)
  window_size: tuple = (-1, -1)
+ # Page table, the index of KV Cache Tables/Blocks
  page_table: torch.Tensor = None
+ # Sequence lengths for the forward batch
  cache_seqlens_int32: torch.Tensor = None


  class FlashAttentionBackend(AttentionBackend):
- """FlashAttention backend implementation."""
+ """FlashAttention backend implementation.
+
+ Note about the init:
+ - If no spec decoding
+ - FlashAttentionBackend will be init once when the server starts.
+ - If spec decoding
+ - FlashAttentionBackend will be init once for the target worker
+ - FlashAttentionMultiStepBackend will be once for the draft worker
+ - It will spawn num_steps FlashAttentionBackend for the draft worker
+
+ Note about CUDA Graph:
+ - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+ - We don't support CUDA Graph for Extend and Draft Extend.
+ - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
+ - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
+ """

  def __init__(
  self,
  model_runner: ModelRunner,
  skip_prefill: bool = False,
+ topk=0,
+ speculative_num_steps=0,
+ step_id=0,
  ):
  super().__init__()

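Note: the cu_seqlens_k / cache_seqlens_int32 / max_seq_len_k fields above are the usual varlen-attention bookkeeping, built once per forward pass and reused by every layer: per-request KV lengths plus their zero-prefixed prefix sums. A minimal standalone sketch of that computation (build_kv_metadata is an illustrative helper, not an sglang function):

import torch

def build_kv_metadata(seq_lens: torch.Tensor):
    # seq_lens: per-request KV lengths, shape (batch_size,)
    cache_seqlens_int32 = seq_lens.to(torch.int32)
    # Zero-prefixed prefix sums: [0, l0, l0+l1, ...]
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seq_len_k = int(seq_lens.max().item())
    return cache_seqlens_int32, cu_seqlens_k, max_seq_len_k

# seq_lens = torch.tensor([3, 5, 2]) -> cu_seqlens_k = tensor([0, 3, 8, 10], dtype=int32)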
@@ -53,56 +79,121 @@ class FlashAttentionBackend(AttentionBackend):
  and model_runner.model_config.is_encoder_decoder
  ), "Sliding window and cross attention are not supported together"

- # Initialize metadata
  self.forward_metadata: FlashAttentionMetadata = None
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
  self.decode_cuda_graph_metadata = {}
+ self.target_verify_metadata = {}
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
  self.page_size = model_runner.page_size
  self.use_mla = (
  model_runner.model_config.attention_arch == AttentionArch.MLA
  ) and (not global_server_args_dict["disable_mla"])
+ self.skip_prefill = skip_prefill
+
+ # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+ assert (
+ topk <= 1
+ ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+ self.topk = 1
+ self.step_id = step_id
+ self.speculative_num_steps = speculative_num_steps

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Initialize forward metadata to cache repetitive calculations."""
- # Create metadata based on forward mode
  metadata = FlashAttentionMetadata()
-
- # Get sequence information
  seqlens_in_batch = forward_batch.seq_lens
- # Precompute int32 version of sequence lengths
- metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  batch_size = len(seqlens_in_batch)
  device = seqlens_in_batch.device
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
- )
- # Precompute maximum sequence length
- metadata.max_seq_len_k = seqlens_in_batch.max().item()
- # Precompute page table
- metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
- forward_batch.req_pool_indices, : metadata.max_seq_len_k
- ]
-
- # Precompute strided indices
- # [0, page_size, 2 * page_size, ...]
- if self.page_size > 1:
- self.strided_indices = torch.arange(
- 0, metadata.page_table.shape[1], self.page_size, device=self.device
- )
- metadata.page_table = (
- metadata.page_table[:, self.strided_indices] // self.page_size
+ if forward_batch.forward_mode.is_decode():
+ # Skip Prefill or Draft Decode
+ # Note: Draft Decode will be ran on the Draft Worker
+ if forward_batch.spec_info is not None:
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+ metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+ self.step_id + 1
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ cache_loc = forward_batch.out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T
+
+ for idx, single_seq_len in enumerate(seq_lens_with_decode):
+ real_bsz_start_idx = idx
+ real_bsz_end_idx = idx + 1
+ metadata.page_table[
+ real_bsz_start_idx:real_bsz_end_idx,
+ (single_seq_len - (self.step_id + 1)) : single_seq_len,
+ ] = cache_loc[
+ real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+ ]
+ else: # Normal Decode without Spec Decoding
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ elif forward_batch.forward_mode.is_target_verify():
+ # Note: Target Verify will be ran on the Target Worker
+ draft_token_num = forward_batch.spec_info.draft_token_num
+ metadata.cache_seqlens_int32 = (
+ forward_batch.seq_lens + draft_token_num
+ ).to(torch.int32)
+ metadata.max_seq_len_q = draft_token_num
+ metadata.max_seq_len_k = (
+ forward_batch.seq_lens_cpu.max().item() + draft_token_num
  )
-
- if forward_batch.forward_mode == ForwardMode.DECODE:
- # Precompute cumulative sequence lengths
  metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ 0,
+ batch_size * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
  )
- else:
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ elif forward_batch.forward_mode.is_extend_or_draft_extend():
+ # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ # Precompute maximum sequence length
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ # Precompute page table
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
  # Precompute cumulative sequence lengths
- if any(forward_batch.extend_prefix_lens_cpu):
+ if (
+ any(forward_batch.extend_prefix_lens_cpu)
+ or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+ ):
  extend_seq_lens = forward_batch.extend_seq_lens
  metadata.cu_seqlens_q = torch.nn.functional.pad(
  torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
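Note: in the target-verify branch above, every request contributes exactly draft_token_num query tokens, so cu_seqlens_q reduces to an arithmetic progression built with a strided torch.arange. A small illustration with made-up numbers (4 requests, 5 draft tokens each):

import torch

batch_size, draft_token_num = 4, 5  # illustrative values, not from the diff
cu_seqlens_q = torch.arange(
    0, batch_size * draft_token_num + 1, draft_token_num, dtype=torch.int32
)
# tensor([ 0,  5, 10, 15, 20], dtype=torch.int32)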
@@ -111,6 +202,15 @@ class FlashAttentionBackend(AttentionBackend):
  else:
  metadata.cu_seqlens_q = metadata.cu_seqlens_k
  metadata.max_seq_len_q = metadata.max_seq_len_k
+
+ # Precompute strided indices
+ if self.page_size > 1:
+ self.strided_indices = torch.arange(
+ 0, metadata.page_table.shape[1], self.page_size, device=self.device
+ )
+ metadata.page_table = (
+ metadata.page_table[:, self.strided_indices] // self.page_size
+ )
  self.forward_metadata = metadata

  def forward_extend(
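Note: the strided-indices block moved here converts a token-granular page table into page/block granularity when page_size > 1: keep every page_size-th column, then divide the token slot index by page_size. A toy example (all values hypothetical):

import torch

page_size = 4
# Token-slot indices for one request whose KV cache occupies pages 2 and 5.
page_table = torch.tensor([[8, 9, 10, 11, 20, 21, 22, 23]])

strided_indices = torch.arange(0, page_table.shape[1], page_size)
block_table = page_table[:, strided_indices] // page_size
# tensor([[2, 5]])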
@@ -122,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
-
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -157,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend):

  page_table = metadata.page_table

- # # Use Flash Attention for prefill
+ # Use Flash Attention for prefill
  if not self.use_mla:
  # Do multi-head attention
  kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -263,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend):
  if layer.sliding_window_size is not None
  else (-1, -1)
  )
-
  page_table = metadata.page_table

  if not self.use_mla:
@@ -281,8 +379,6 @@ class FlashAttentionBackend(AttentionBackend):

  # Pre-reshape query tensor
  q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
- # Run attention with precomputed values
  o = flash_attn_with_kvcache(
  q=q_reshaped,
  k_cache=key_cache,
@@ -334,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=layer.k_scale,
  v_descale=layer.v_scale,
  )
-
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def init_cuda_graph_state(self, max_bs: int):
@@ -346,7 +441,6 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
- # Initialize fixed size tensors for decode operations
  self.decode_cuda_graph_metadata = {
  # Page table for token mapping (batch_size, max_context_len)
  "page_table": torch.zeros(
@@ -355,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  device=self.device,
  ),
+ "page_table_draft_decode": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "strided_indices": torch.arange(
+ 0, self.max_context_len, self.page_size, device=self.device
+ ),
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "cu_seqlens_q": torch.arange(
+ 0, max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ }
+
+ self.target_verify_metadata = {
+ "page_table": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "cu_seqlens_q": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "max_seqlen_q": 0,
  "strided_indices": torch.arange(
  0, self.max_context_len, self.page_size, device=self.device
  ),
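Note: init_cuda_graph_state now preallocates every buffer the decode and target-verify graphs will ever touch, sized for max_bs with page tables padded to whole pages. This is the usual CUDA-graph constraint: captured tensors must keep their addresses, so replay can only write into them in place. A minimal sketch of that pattern (names and sizes are illustrative):

import torch

max_bs = 8  # illustrative capacity
cache_seqlens_buf = torch.zeros(max_bs, dtype=torch.int32)  # allocated once, captured by the graph

def update_for_replay(seq_lens: torch.Tensor) -> torch.Tensor:
    bs = seq_lens.shape[0]
    # In-place copy into a fixed slice: same storage, same data_ptr(),
    # so a previously captured CUDA graph still reads the right values.
    cache_seqlens_buf[:bs].copy_(seq_lens.to(torch.int32))
    return cache_seqlens_buf[:bs]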
@@ -372,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
- # Get sequence information
- metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- batch_size = len(seq_lens)
  device = seq_lens.device
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
- )
- # Precompute maximum sequence length
- metadata.max_seq_len_k = seq_lens.max().item()
- # Precompute page table
- metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
- req_pool_indices, :
- ]
- if forward_mode == ForwardMode.DECODE:
- # Precompute cumulative sequence lengths
- metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ if forward_mode.is_decode():
+ if spec_info is not None:
+ # Draft Decode
+ metadata.cu_seqlens_q = torch.arange(
+ 0, bs + 1, dtype=torch.int32, device=device
+ )
+ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+ "cache_seqlens"
+ ][:bs]
+
+ metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+ : bs + 1
+ ]
+
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+ metadata.page_table = self.decode_cuda_graph_metadata[
+ "page_table_draft_decode"
+ ][req_pool_indices, :]
+ else:
+ # Normal Decode
+ # Get sequence information
+ metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+ batch_size = len(seq_lens)
+ device = seq_lens.device
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ )
+ # Precompute maximum sequence length
+ metadata.max_seq_len_k = seq_lens.max().item()
+ # Precompute page table
+ metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+ req_pool_indices, :
+ ]
+ # Precompute cumulative sequence lengths
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ self.decode_cuda_graph_metadata[bs] = metadata
+ elif forward_mode.is_target_verify():
+ draft_token_num = spec_info.draft_token_num
+
+ metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+ :bs
+ ]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + draft_token_num).to(torch.int32)
  )
- else:
- raise ValueError("Do not support Prefill Mode cuda graph")
- self.decode_cuda_graph_metadata[bs] = metadata
+
+ metadata.max_seq_len_q = draft_token_num
+ metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+ metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+ torch.arange(
+ 0,
+ bs * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
+ )
+ ]
+ cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+ cu_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ metadata.cu_seqlens_k = cu_k
+ metadata.page_table = self.target_verify_metadata["page_table"][
+ req_pool_indices, :
+ ]
+
+ self.target_verify_metadata[bs] = metadata
+
  self.forward_metadata = metadata

  def init_forward_metadata_replay_cuda_graph(
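Note: the draft-decode paths above consistently add (step_id + 1) to the committed sequence lengths, because at draft step i each request already has i + 1 speculated tokens sitting in the KV cache beyond its verified length. With made-up numbers:

import torch

seq_lens = torch.tensor([7, 12])  # committed lengths (illustrative)
step_id = 2                       # third draft step
cache_seqlens = (seq_lens + (step_id + 1)).to(torch.int32)
# tensor([10, 15], dtype=torch.int32) -> KV lengths the attention kernel should see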
@@ -405,30 +594,159 @@ class FlashAttentionBackend(AttentionBackend):
  forward_mode: ForwardMode,
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  seq_lens_cpu: Optional[torch.Tensor],
+ out_cache_loc: torch.Tensor = None,
  ):
  # """Initialize forward metadata for replaying CUDA graph."""
- metadata = self.decode_cuda_graph_metadata[bs]
+ device = seq_lens.device
+ seq_lens = seq_lens[:bs]
+ req_pool_indices = req_pool_indices[:bs]
+ seq_lens_cpu = seq_lens_cpu[:bs]
+ if forward_mode.is_decode():
+ metadata = self.decode_cuda_graph_metadata[bs]
+
+ if spec_info is not None:
+ # Draft Decode
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len + (self.step_id + 1)
+
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + (self.step_id + 1)).to(torch.int32)
+ )

- # For CPU operations
- max_len = seq_lens_cpu[:bs].max().item()
- metadata.max_seq_len_k = max_len
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)

- # For GPU operations
- seq_lens_in_batch = seq_lens[:bs]
- metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
- )
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+ else:
+ # Normal Decode
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len
+
+ metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ )
+
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ :,
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+ ]
+ page_indices = page_indices[req_pool_indices] // self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ metadata.page_table[:, max_seq_pages:].fill_(0)
+
+ elif forward_mode.is_target_verify():
+ metadata = self.target_verify_metadata[bs]
+ draft_token_num = spec_info.draft_token_num
+
+ metadata.cu_seqlens_q.copy_(
+ torch.arange(
+ 0,
+ bs * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
+ )
+ )
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + draft_token_num).to(torch.int32)
+ )
+
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

- max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
- page_indices = self.req_to_token[
- :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
- ]
- page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
- metadata.page_table[:, max_seq_pages:].fill_(0)
  self.forward_metadata = metadata

  def get_cuda_graph_seq_len_fill_value(self):
  """Get the fill value for sequence length in CUDA graph."""
  return 0
+
+
+ class FlashAttentionMultiStepBackend:
+
+ def __init__(
+ self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+ ):
+ self.model_runner = model_runner
+ self.topk = topk
+ self.speculative_num_steps = speculative_num_steps
+
+ self.attn_backends = []
+ for i in range(self.speculative_num_steps):
+ self.attn_backends.append(
+ FlashAttentionBackend(
+ model_runner,
+ topk=self.topk,
+ speculative_num_steps=self.speculative_num_steps,
+ step_id=i,
+ )
+ )
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata(forward_batch)
+
+ def init_cuda_graph_state(self, max_bs: int):
+ for i in range(self.speculative_num_steps):
+ self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+ def init_forward_metadata_capture_cuda_graph(
+ self,
+ forward_batch: ForwardBatch,
+ ):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+ forward_batch.batch_size,
+ forward_batch.batch_size * self.topk,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ )
+
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+ bs,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
+ out_cache_loc=forward_batch.out_cache_loc,
+ )
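Note: FlashAttentionMultiStepBackend simply fans out to one FlashAttentionBackend per draft step, each constructed with its own step_id so the per-step (step_id + 1) offsets above line up. A stripped-down sketch of that wiring with a stand-in class (the real constructor needs a ModelRunner):

class _StubStepBackend:
    # Stand-in for FlashAttentionBackend; only the speculative fields are modeled.
    def __init__(self, step_id: int, speculative_num_steps: int):
        self.step_id = step_id
        self.speculative_num_steps = speculative_num_steps

speculative_num_steps = 4  # illustrative
attn_backends = [
    _StubStepBackend(step_id=i, speculative_num_steps=speculative_num_steps)
    for i in range(speculative_num_steps)
]
assert [b.step_id for b in attn_backends] == [0, 1, 2, 3]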
sglang/srt/layers/attention/flashinfer_backend.py
@@ -14,7 +14,6 @@ from functools import partial
  from typing import TYPE_CHECKING, Callable, List, Optional, Union

  import torch
- import triton

  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+ from sglang.srt.utils import is_flashinfer_available, next_power_of_2

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
  self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+ self.page_size = model_runner.page_size

  max_bs = model_runner.req_to_token_pool.size * self.topk
  self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
  )

  assert forward_batch.spec_info is not None
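Note: swapping triton.next_power_of_2 for sglang's own next_power_of_2 keeps this code path from importing Triton just for a scalar helper. For positive integers the two agree; a pure-Python equivalent (a sketch, not necessarily the exact implementation in sglang.srt.utils):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, matching triton.next_power_of_2 for n >= 1.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(x) for x in (1, 3, 5, 16, 17)] == [1, 4, 8, 16, 32]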
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
  )

  def call_fn(i, forward_batch):
- assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
  forward_batch.spec_info.kv_indptr = (
  forward_batch.spec_info.kv_indptr.clone()
  )
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  self.device = model_runner.device
  self.skip_prefill = skip_prefill

- global_config.enable_flashinfer_mla = True
-
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
@@ -797,7 +795,7 @@ class FlashInferMLAMultiStepDraftBackend:
  encoder_lens=None,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
- seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
  )

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
sglang/srt/layers/attention/flashmla_backend.py
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  if forward_batch.forward_mode.is_decode_or_idle():
  if spec_info is None:
  max_seqlen_pad = triton.cdiv(
- forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+ forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
  )
  block_kv_indices = torch.full(
  (bs, max_seqlen_pad),
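Note: here the renamed forward_batch.seq_lens_cpu feeds triton.cdiv, which is just ceiling division: the maximum sequence length is rounded up to whole KV pages before the block index table is allocated. With assumed numbers:

def cdiv(a: int, b: int) -> int:
    # Same result as triton.cdiv: ceiling division.
    return (a + b - 1) // b

PAGE_SIZE = 64                          # assumed page size for the example only
max_seqlen_pad = cdiv(1000, PAGE_SIZE)  # -> 16 pages cover a 1000-token sequence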