sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashmla_backend.py
@@ -8,7 +8,7 @@ Enable speculative sampling in FlashMLA
  """
 
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 
  import torch
  import triton
@@ -30,8 +30,8 @@ if TYPE_CHECKING:
 
  # FlashMLA only supports pagesize=64
  PAGE_SIZE = 64
- # TODO The current setup is hard-coded and will be changed after integrating with MTP.
- Q_LEN = 1
+
+ # FlashMLA FP8 issue: https://github.com/deepseek-ai/FlashMLA/issues/56
 
 
  @dataclass
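FlashMLA addresses the KV cache in fixed 64-token pages, so sequence lengths are padded up to a whole number of pages before block tables are built. A minimal sketch of that page arithmetic (illustrative only, not package code; the hunks below use triton.cdiv for the same ceiling division):

# Illustrative sketch: padded page count per sequence, mirroring
# triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) in the hunks below.
PAGE_SIZE = 64  # FlashMLA only supports pagesize=64

def pages_needed(seq_len: int, page_size: int = PAGE_SIZE) -> int:
    # Ceiling division: a 130-token sequence needs 3 pages (64 + 64 + 2).
    return -(-seq_len // page_size)

assert pages_needed(130) == 3
assert pages_needed(64) == 1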
@@ -52,7 +52,7 @@ class FlashMLADecodeMetadata:
 
 
  class FlashMLABackend(FlashInferMLAAttnBackend):
-     """Flashinfer attention kernels."""
+     """Flashmla attention kernels."""
 
      def __init__(
          self,
@@ -82,42 +82,72 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
          self.q_data_type = model_runner.dtype
          self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
 
+         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
      def init_forward_metadata(self, forward_batch: ForwardBatch):
 
          bs = forward_batch.batch_size
-         spec_info = forward_batch.spec_info
          if forward_batch.forward_mode.is_decode_or_idle():
-             if spec_info is None:
-                 max_seqlen_pad = triton.cdiv(
-                     forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
-                 )
-                 block_kv_indices = torch.full(
-                     (bs, max_seqlen_pad),
-                     -1,
-                     dtype=torch.int32,
-                     device=forward_batch.seq_lens.device,
-                 )
-                 create_flashmla_kv_indices_triton[(bs,)](
-                     self.req_to_token,
-                     forward_batch.req_pool_indices,
-                     forward_batch.seq_lens,
-                     None,
-                     block_kv_indices,
-                     self.req_to_token.stride(0),
-                     max_seqlen_pad,
-                 )
-                 mla_metadata, num_splits = get_mla_metadata(
-                     forward_batch.seq_lens.to(torch.int32),
-                     Q_LEN * self.num_q_heads,
-                     1,
-                 )
-                 self.forward_metadata = FlashMLADecodeMetadata(
-                     mla_metadata,
-                     num_splits,
-                     block_kv_indices,
-                 )
-             else:
-                 super().init_forward_metadata(forward_batch)
+             max_seqlen_pad = triton.cdiv(
+                 forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+             )
+             block_kv_indices = torch.full(
+                 (bs, max_seqlen_pad),
+                 -1,
+                 dtype=torch.int32,
+                 device=forward_batch.seq_lens.device,
+             )
+             create_flashmla_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 None,
+                 block_kv_indices,
+                 self.req_to_token.stride(0),
+                 max_seqlen_pad,
+             )
+             mla_metadata, num_splits = get_mla_metadata(
+                 forward_batch.seq_lens.to(torch.int32),
+                 self.num_q_heads,
+                 1,
+             )
+             self.forward_metadata = FlashMLADecodeMetadata(
+                 mla_metadata,
+                 num_splits,
+                 block_kv_indices,
+             )
+         elif forward_batch.forward_mode.is_target_verify():
+             seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
+             seq_lens = forward_batch.seq_lens + self.num_draft_tokens
+
+             max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+             block_kv_indices = torch.full(
+                 (bs, max_seqlen_pad),
+                 -1,
+                 dtype=torch.int32,
+                 device=seq_lens.device,
+             )
+             create_flashmla_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 forward_batch.req_pool_indices,
+                 seq_lens,
+                 None,
+                 block_kv_indices,
+                 self.req_to_token.stride(0),
+                 max_seqlen_pad,
+             )
+             mla_metadata, num_splits = get_mla_metadata(
+                 seq_lens.to(torch.int32),
+                 self.num_draft_tokens * self.num_q_heads,
+                 1,
+             )
+
+             # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
+             self.forward_metadata = FlashMLADecodeMetadata(
+                 mla_metadata,
+                 num_splits,
+                 block_kv_indices,
+             )
          else:
              super().init_forward_metadata(forward_batch)
 
@@ -136,11 +166,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
          else:
              cuda_graph_kv_indices = block_kv_indices
 
-         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-             Q_LEN * self.num_q_heads,
-             1,
-         )
+         if self.num_draft_tokens:
+             self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+                 torch.ones(
+                     max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                 ),
+                 self.num_draft_tokens * self.num_q_heads,
+                 1,
+             )
+         else:
+             self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+                 torch.ones(
+                     max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                 ),
+                 self.num_q_heads,
+                 1,
+             )
          self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
      def init_forward_metadata_capture_cuda_graph(
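The two code paths in these hunks differ only in how many query tokens each request contributes to get_mla_metadata: one for plain decode, speculative_num_draft_tokens for target-verify. A minimal sketch of that bookkeeping (illustrative only; q_tokens_per_request is a made-up helper, not part of sglang):

def q_tokens_per_request(num_q_heads: int, num_draft_tokens: int, is_target_verify: bool) -> int:
    # Query tokens per request times attention heads, matching the second
    # argument passed to get_mla_metadata in the hunks above and below.
    tokens = num_draft_tokens if is_target_verify else 1
    return tokens * num_q_heads

# Plain decode: one query token per request.
assert q_tokens_per_request(num_q_heads=16, num_draft_tokens=4, is_target_verify=False) == 16
# Target verify: all draft tokens are verified together.
assert q_tokens_per_request(num_q_heads=16, num_draft_tokens=4, is_target_verify=True) == 64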
@@ -154,31 +195,54 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
          spec_info: Optional[SpecInfo],
      ):
          if forward_mode.is_decode_or_idle():
-             if spec_info is None:
-                 max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
-
-                 create_flashmla_kv_indices_triton[(bs,)](
-                     self.req_to_token,
-                     req_pool_indices,
-                     seq_lens,
-                     None,
-                     self.cuda_graph_kv_indices,
-                     self.req_to_token.stride(0),
-                     self.cuda_graph_kv_indices.stride(0),
-                 )
-                 mla_metadata, num_splits = get_mla_metadata(
-                     seq_lens.to(torch.int32),
-                     Q_LEN * self.num_q_heads,
-                     1,
-                 )
-                 self.cuda_graph_mla_metadata.copy_(mla_metadata)
-                 self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
-                 self.forward_metadata = FlashMLADecodeMetadata(
-                     self.cuda_graph_mla_metadata,
-                     self.cuda_graph_num_splits[: bs + 1],
-                     self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
-                 )
+             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
 
+             create_flashmla_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 req_pool_indices,
+                 seq_lens,
+                 None,
+                 self.cuda_graph_kv_indices,
+                 self.req_to_token.stride(0),
+                 self.cuda_graph_kv_indices.stride(0),
+             )
+             mla_metadata, num_splits = get_mla_metadata(
+                 seq_lens.to(torch.int32),
+                 self.num_q_heads,
+                 1,
+             )
+             self.cuda_graph_mla_metadata.copy_(mla_metadata)
+             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+             self.forward_metadata = FlashMLADecodeMetadata(
+                 self.cuda_graph_mla_metadata,
+                 self.cuda_graph_num_splits[: bs + 1],
+                 self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+             )
+         elif forward_mode.is_target_verify():
+             seq_lens = seq_lens + self.num_draft_tokens
+             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+             create_flashmla_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 req_pool_indices,
+                 seq_lens,
+                 None,
+                 self.cuda_graph_kv_indices,
+                 self.req_to_token.stride(0),
+                 self.cuda_graph_kv_indices.stride(0),
+             )
+             mla_metadata, num_splits = get_mla_metadata(
+                 seq_lens.to(torch.int32),
+                 self.num_draft_tokens * self.num_q_heads,
+                 1,
+             )
+             self.cuda_graph_mla_metadata.copy_(mla_metadata)
+             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+             self.forward_metadata = FlashMLADecodeMetadata(
+                 self.cuda_graph_mla_metadata,
+                 self.cuda_graph_num_splits[: bs + 1],
+                 self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+             )
          else:
              super().init_forward_metadata_capture_cuda_graph(
                  bs,
@@ -218,7 +282,32 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
              )
              mla_metadata, num_splits = get_mla_metadata(
                  seq_lens.to(torch.int32),
-                 Q_LEN * self.num_q_heads,
+                 self.num_q_heads,
+                 1,
+             )
+             self.cuda_graph_mla_metadata.copy_(mla_metadata)
+             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+             self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
+             self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+             self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                 :bs, :max_seqlen_pad
+             ]
+         elif forward_mode.is_target_verify():
+             seq_lens = seq_lens[:bs] + self.num_draft_tokens
+             seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
+             max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+             create_flashmla_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 req_pool_indices[:bs],
+                 seq_lens,
+                 None,
+                 self.cuda_graph_kv_indices,
+                 self.req_to_token.stride(0),
+                 self.cuda_graph_kv_indices.stride(0),
+             )
+             mla_metadata, num_splits = get_mla_metadata(
+                 seq_lens.to(torch.int32),
+                 self.num_draft_tokens * self.num_q_heads,
                  1,
              )
              self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -228,7 +317,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
              self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                  :bs, :max_seqlen_pad
              ]
-
          else:
              super().init_forward_metadata_replay_cuda_graph(
                  bs,
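The capture/replay hunks above follow the usual CUDA-graph contract: metadata buffers are allocated once at the maximum batch size, and replay only copy_()s fresh values into them so the captured kernels keep reading the same storage. A generic sketch of that pattern (illustrative, not sglang code):

import torch

# Illustrative preallocate-then-copy_ pattern used by the cuda-graph paths above.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_bs = 8
num_splits_buf = torch.zeros(max_bs + 1, dtype=torch.int32, device=device)  # allocated once

def replay_update(fresh_num_splits: torch.Tensor, bs: int) -> None:
    # Overwrite in place; graph-captured kernels keep pointing at num_splits_buf.
    num_splits_buf[: bs + 1].copy_(fresh_num_splits)

replay_update(torch.arange(5, dtype=torch.int32, device=device), bs=4)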
@@ -268,17 +356,191 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
          k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
          reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+         if self.data_type == torch.float8_e4m3fn:
+             reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+             o, _ = flash_mla_with_kvcache(
+                 q=reshape_q_fp8,
+                 k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                 block_table=self.forward_metadata.block_kv_indices[:bs],
+                 cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
+                 tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                 num_splits=self.forward_metadata.num_splits,
+                 softmax_scale=layer.scaling,
+                 causal=True,
+                 descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+                 descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+             )
+
+             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+         else:
+             # todo: need check all causal True or False?
+             o, _ = flash_mla_with_kvcache(
+                 q=reshape_q,
+                 k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                 block_table=self.forward_metadata.block_kv_indices[:bs],
+                 cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
+                 tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                 num_splits=self.forward_metadata.num_splits,
+                 softmax_scale=layer.scaling,
+                 causal=True,
+             )
+
+             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+     ):
+         if (
+             forward_batch.forward_mode == ForwardMode.EXTEND
+             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+         ):
+             return super().forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
+         else:
+             cache_loc = forward_batch.out_cache_loc
+
+             if k is not None:
+                 assert v is not None
+                 if save_kv_cache:
+                     forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+             bs = forward_batch.batch_size
+             k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+             reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+             if self.data_type == torch.float8_e4m3fn:
+                 reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+                 o, _ = flash_mla_with_kvcache(
+                     q=reshape_q_fp8,
+                     k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                     block_table=self.forward_metadata.block_kv_indices[:bs],
+                     cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+                     + self.num_draft_tokens,
+                     head_dim_v=self.kv_lora_rank,
+                     tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                     num_splits=self.forward_metadata.num_splits,
+                     softmax_scale=layer.scaling,
+                     causal=True,
+                     descale_q=torch.ones(
+                         (1), dtype=torch.float32, device=reshape_q.device
+                     ),
+                     descale_k=torch.ones(
+                         (1), dtype=torch.float32, device=reshape_q.device
+                     ),
+                 )
+             else:
+                 o, _ = flash_mla_with_kvcache(
+                     q=reshape_q,
+                     k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                     block_table=self.forward_metadata.block_kv_indices[:bs],
+                     cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+                     + self.num_draft_tokens,
+                     head_dim_v=self.kv_lora_rank,
+                     tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                     num_splits=self.forward_metadata.num_splits,
+                     softmax_scale=layer.scaling,
+                     causal=True,
+                 )
+             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
 
-         o, _ = flash_mla_with_kvcache(
-             q=reshape_q,
-             k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-             block_table=self.forward_metadata.block_kv_indices,
-             cache_seqlens=forward_batch.seq_lens.to(torch.int32),
-             head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
-             tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
-             num_splits=self.forward_metadata.num_splits,
-             softmax_scale=layer.scaling,
-             causal=False,
+ # TODO: multi step kv indices optimization
+ class FlashMLAMultiStepDraftBackend:
+     """
+     Wrap multiple flashmla attention backends as one for multiple consecutive
+     draft decoding steps.
+     """
+
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+         topk: int,
+         speculative_num_steps: int,
+     ):
+         from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+         if topk > 1:
+             raise ValueError(
+                 f"Currently FlashMLA only supports topk=1 for speculative decoding"
+             )
+         self.topk = topk
+         self.speculative_num_steps = speculative_num_steps
+         max_bs = model_runner.req_to_token_pool.size * self.topk
+         self.kv_indptr = torch.zeros(
+             (
+                 self.speculative_num_steps,
+                 max_bs + 1,
+             ),
+             dtype=torch.int32,
+             device=model_runner.device,
          )
 
-         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+         self.attn_backends = []
+         for i in range(self.speculative_num_steps):
+             self.attn_backends.append(
+                 FlashMLABackend(
+                     model_runner,
+                     skip_prefill=True,
+                     kv_indptr_buf=self.kv_indptr[i],
+                     kv_last_page_len_buf=None,
+                 )
+             )
+
+     def common_template(
+         self,
+         forward_batch: ForwardBatch,
+         call_fn: Callable,
+     ):
+         assert forward_batch.spec_info is not None
+
+         for i in range(self.speculative_num_steps - 1):
+             call_fn(i, forward_batch)
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         def call_fn(i, forward_batch):
+             assert forward_batch.spec_info is not None
+             self.attn_backends[i].init_forward_metadata(forward_batch)
+
+         self.common_template(forward_batch, call_fn)
+
+     def init_cuda_graph_state(self, max_bs: int):
+         for i in range(self.speculative_num_steps):
+             self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+
+     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+         def call_fn(i, forward_batch):
+             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                 forward_batch.batch_size,
+                 forward_batch.batch_size * self.topk,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 encoder_lens=None,
+                 forward_mode=ForwardMode.DECODE,
+                 spec_info=forward_batch.spec_info,
+             )
+
+         self.common_template(forward_batch, call_fn)
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, forward_batch: ForwardBatch, bs: int
+     ):
+         def call_fn(i, forward_batch):
+             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                 bs,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 seq_lens_sum=-1,
+                 encoder_lens=None,
+                 forward_mode=ForwardMode.DECODE,
+                 spec_info=forward_batch.spec_info,
+                 seq_lens_cpu=forward_batch.seq_lens_cpu,
+             )
+
+         self.common_template(forward_batch, call_fn)
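A rough usage sketch of the new wrapper, reflecting the constraints enforced in its __init__ above (illustrative only; model_runner stands in for a real sglang ModelRunner instance):

# Hypothetical usage sketch, not package code.
draft_backend = FlashMLAMultiStepDraftBackend(
    model_runner,
    topk=1,                    # topk > 1 raises ValueError in __init__
    speculative_num_steps=3,   # one FlashMLABackend is built per draft step
)
# Per-step metadata is then initialized through the shared callback template:
# draft_backend.init_forward_metadata(forward_batch)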
sglang/srt/layers/attention/triton_backend.py
@@ -155,6 +155,9 @@ class TritonAttnBackend(AttentionBackend):
          seq_lens: torch.Tensor,
      ):
          num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+         # NOTE(alcanderian): Considering speculative_decodeing,
+         # num_kv_splits.shape[0] will be topk * real_num_token.
+         # And the real_num_token is num_seq in decoding phase.
          num_group = num_token // num_seq
 
          assert (
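A worked example of the relationship the new comment describes, with illustrative numbers only:

# Illustrative: during speculative decoding the split buffer carries topk rows
# per sequence, so num_token = topk * num_seq and num_group recovers topk.
num_seq, topk = 2, 4
num_token = topk * num_seq          # num_kv_splits.shape[0]
num_group = num_token // num_seq    # topk query rows share each sequence's KV
assert num_group == topk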
sglang/srt/layers/attention/utils.py
@@ -77,8 +77,8 @@ def create_flashmla_kv_indices_triton(
      ) * PAGED_SIZE
      paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
 
-     mask = paged_offset <= num_paged * PAGED_SIZE
-     mask_out = paged_offset_out <= num_paged
+     mask = paged_offset < num_paged * PAGED_SIZE
+     mask_out = paged_offset_out < num_paged
 
      data = tl.load(
          req_to_token_ptr
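The change from <= to < fixes an off-by-one in the bounds masks: valid offsets run from 0 up to, but not including, the total count. A small worked example (illustrative, plain Python rather than Triton):

# Illustrative off-by-one check for the mask fix above.
PAGED_SIZE = 64
num_paged = 3                                    # pages owned by this request
assert len(range(num_paged * PAGED_SIZE)) == 192  # valid offsets are 0..191

offset = num_paged * PAGED_SIZE                  # 192, one past the end
assert not (offset < num_paged * PAGED_SIZE)     # new mask correctly rejects it
assert offset <= num_paged * PAGED_SIZE          # old mask would have admitted it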
sglang/srt/layers/attention/vision.py
@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
          flatten_batch: bool = False,
      ) -> Optional[torch.Tensor]:
          r"""
-         Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
+         Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
          Args:
              s: sequence length
              cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
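A minimal sketch of what a non-causal (1, 1, s, s) mask built from cumulative sequence lengths looks like (illustrative only, assuming a block-diagonal construction; this is not the package's implementation):

import torch

# Illustrative: tokens attend only within their own packed segment.
def block_diagonal_mask(s: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    mask = torch.zeros(1, 1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[..., start:end, start:end] = True
    return mask

# Two segments of lengths 3 and 2 packed into s = 5.
m = block_diagonal_mask(5, torch.tensor([0, 3, 5]))
assert m.shape == (1, 1, 5, 5) and not m[0, 0, 0, 4]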