sglang-0.4.6.post3-py3-none-any.whl → sglang-0.4.6.post5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
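Note on entries 38 and 175: the monolithic sglang/srt/function_call_parser.py is removed and its contents now live under the sglang/srt/function_call/ package. A hypothetical import update for downstream code (the FunctionCallParser class name is assumed here, not confirmed by this diff; verify against the installed 0.4.6.post5 wheel):

# 0.4.6.post3 (module removed in post5):
#   from sglang.srt.function_call_parser import FunctionCallParser
# 0.4.6.post5 (new package layout, per entries 38-43 above):
from sglang.srt.function_call.function_call_parser import FunctionCallParser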
@@ -8,7 +8,7 @@ Enable speculative sampling in FlashMLA
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 
 import torch
 import triton
@@ -30,8 +30,8 @@ if TYPE_CHECKING:
 
 # FlashMLA only supports pagesize=64
 PAGE_SIZE = 64
-# TODO The current setup is hard-coded and will be changed after integrating with MTP.
-Q_LEN = 1
+
+# FlashMLA FP8 issue: https://github.com/deepseek-ai/FlashMLA/issues/56
 
 
 @dataclass
@@ -52,7 +52,7 @@ class FlashMLADecodeMetadata:
 
 
 class FlashMLABackend(FlashInferMLAAttnBackend):
-    """Flashinfer attention kernels."""
+    """Flashmla attention kernels."""
 
     def __init__(
         self,
@@ -82,42 +82,72 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.q_data_type = model_runner.dtype
         self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
 
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
 
         bs = forward_batch.batch_size
-        spec_info = forward_batch.spec_info
         if forward_batch.forward_mode.is_decode_or_idle():
-            if spec_info is None:
-                max_seqlen_pad = triton.cdiv(
-                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
-                )
-                block_kv_indices = torch.full(
-                    (bs, max_seqlen_pad),
-                    -1,
-                    dtype=torch.int32,
-                    device=forward_batch.seq_lens.device,
-                )
-                create_flashmla_kv_indices_triton[(bs,)](
-                    self.req_to_token,
-                    forward_batch.req_pool_indices,
-                    forward_batch.seq_lens,
-                    None,
-                    block_kv_indices,
-                    self.req_to_token.stride(0),
-                    max_seqlen_pad,
-                )
-                mla_metadata, num_splits = get_mla_metadata(
-                    forward_batch.seq_lens.to(torch.int32),
-                    Q_LEN * self.num_q_heads,
-                    1,
-                )
-                self.forward_metadata = FlashMLADecodeMetadata(
-                    mla_metadata,
-                    num_splits,
-                    block_kv_indices,
-                )
-            else:
-                super().init_forward_metadata(forward_batch)
+            max_seqlen_pad = triton.cdiv(
+                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+            )
+            block_kv_indices = torch.full(
+                (bs, max_seqlen_pad),
+                -1,
+                dtype=torch.int32,
+                device=forward_batch.seq_lens.device,
+            )
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                None,
+                block_kv_indices,
+                self.req_to_token.stride(0),
+                max_seqlen_pad,
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                forward_batch.seq_lens.to(torch.int32),
+                self.num_q_heads,
+                1,
+            )
+            self.forward_metadata = FlashMLADecodeMetadata(
+                mla_metadata,
+                num_splits,
+                block_kv_indices,
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
+            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
+
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            block_kv_indices = torch.full(
+                (bs, max_seqlen_pad),
+                -1,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                seq_lens,
+                None,
+                block_kv_indices,
+                self.req_to_token.stride(0),
+                max_seqlen_pad,
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                self.num_draft_tokens * self.num_q_heads,
+                1,
+            )
+
+            # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
+            self.forward_metadata = FlashMLADecodeMetadata(
+                mla_metadata,
+                num_splits,
+                block_kv_indices,
+            )
         else:
             super().init_forward_metadata(forward_batch)
 
@@ -136,11 +166,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         else:
             cuda_graph_kv_indices = block_kv_indices
 
-        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads,
-            1,
-        )
+        if self.num_draft_tokens:
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+                torch.ones(
+                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                ),
+                self.num_draft_tokens * self.num_q_heads,
+                1,
+            )
+        else:
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+                torch.ones(
+                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                ),
+                self.num_q_heads,
+                1,
+            )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
     def init_forward_metadata_capture_cuda_graph(
@@ -154,31 +195,54 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
-            if spec_info is None:
-                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
-
-                create_flashmla_kv_indices_triton[(bs,)](
-                    self.req_to_token,
-                    req_pool_indices,
-                    seq_lens,
-                    None,
-                    self.cuda_graph_kv_indices,
-                    self.req_to_token.stride(0),
-                    self.cuda_graph_kv_indices.stride(0),
-                )
-                mla_metadata, num_splits = get_mla_metadata(
-                    seq_lens.to(torch.int32),
-                    Q_LEN * self.num_q_heads,
-                    1,
-                )
-                self.cuda_graph_mla_metadata.copy_(mla_metadata)
-                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
-                self.forward_metadata = FlashMLADecodeMetadata(
-                    self.cuda_graph_mla_metadata,
-                    self.cuda_graph_num_splits[: bs + 1],
-                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
-                )
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
 
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                self.num_q_heads,
+                1,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata = FlashMLADecodeMetadata(
+                self.cuda_graph_mla_metadata,
+                self.cuda_graph_num_splits[: bs + 1],
+                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+            )
+        elif forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                self.num_draft_tokens * self.num_q_heads,
+                1,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata = FlashMLADecodeMetadata(
+                self.cuda_graph_mla_metadata,
+                self.cuda_graph_num_splits[: bs + 1],
+                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+            )
         else:
             super().init_forward_metadata_capture_cuda_graph(
                 bs,
@@ -218,7 +282,32 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads,
+                self.num_q_heads,
+                1,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+        elif forward_mode.is_target_verify():
+            seq_lens = seq_lens[:bs] + self.num_draft_tokens
+            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                self.num_draft_tokens * self.num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -228,7 +317,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                 :bs, :max_seqlen_pad
             ]
-
         else:
             super().init_forward_metadata_replay_cuda_graph(
                 bs,
@@ -268,17 +356,191 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+        if self.data_type == torch.float8_e4m3fn:
+            reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+            o, _ = flash_mla_with_kvcache(
+                q=reshape_q_fp8,
+                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                block_table=self.forward_metadata.block_kv_indices[:bs],
+                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
+                tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                num_splits=self.forward_metadata.num_splits,
+                softmax_scale=layer.scaling,
+                causal=True,
+                descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+                descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            # todo: need check all causal True or False?
+            o, _ = flash_mla_with_kvcache(
+                q=reshape_q,
+                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                block_table=self.forward_metadata.block_kv_indices[:bs],
+                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
+                tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                num_splits=self.forward_metadata.num_splits,
+                softmax_scale=layer.scaling,
+                causal=True,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        if (
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+        ):
+            return super().forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+            bs = forward_batch.batch_size
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+            reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+            if self.data_type == torch.float8_e4m3fn:
+                reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+                o, _ = flash_mla_with_kvcache(
+                    q=reshape_q_fp8,
+                    k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                    block_table=self.forward_metadata.block_kv_indices[:bs],
+                    cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+                    + self.num_draft_tokens,
+                    head_dim_v=self.kv_lora_rank,
+                    tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                    num_splits=self.forward_metadata.num_splits,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    descale_q=torch.ones(
+                        (1), dtype=torch.float32, device=reshape_q.device
+                    ),
+                    descale_k=torch.ones(
+                        (1), dtype=torch.float32, device=reshape_q.device
+                    ),
+                )
+            else:
+                o, _ = flash_mla_with_kvcache(
+                    q=reshape_q,
+                    k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                    block_table=self.forward_metadata.block_kv_indices[:bs],
+                    cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+                    + self.num_draft_tokens,
+                    head_dim_v=self.kv_lora_rank,
+                    tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+                    num_splits=self.forward_metadata.num_splits,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
 
-        o, _ = flash_mla_with_kvcache(
-            q=reshape_q,
-            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-            block_table=self.forward_metadata.block_kv_indices,
-            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
-            head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
-            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
-            num_splits=self.forward_metadata.num_splits,
-            softmax_scale=layer.scaling,
-            causal=False,
+# TODO: multi step kv indices optimization
+class FlashMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashmla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        if topk > 1:
+            raise ValueError(
+                f"Currently FlashMLA only supports topk=1 for speculative decoding"
+            )
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
         )
 
-        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashMLABackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=None,
+                )
+            )
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        call_fn: Callable,
+    ):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, call_fn)
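A minimal sketch of the size bookkeeping behind the target-verify branches added above, assuming the PAGE_SIZE=64 constraint stated in the hunk; names and values here are illustrative, not sglang code:

import math

PAGE_SIZE = 64  # FlashMLA only supports pagesize=64

def verify_metadata_shapes(seq_lens, num_draft_tokens, num_q_heads):
    # Each request's KV length grows by the draft tokens being verified.
    extended = [s + num_draft_tokens for s in seq_lens]
    # The block table is padded to the longest request, in whole pages.
    max_seqlen_pad = math.ceil(max(extended) / PAGE_SIZE)
    # get_mla_metadata is sized for num_draft_tokens * num_q_heads queries
    # per sequence, since every draft token issues its own query.
    queries_per_seq = num_draft_tokens * num_q_heads
    return extended, max_seqlen_pad, queries_per_seq

# Two requests with 100 and 130 cached tokens, 4 draft tokens, 16 query heads:
print(verify_metadata_shapes([100, 130], 4, 16))  # ([104, 134], 3, 64)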
@@ -155,6 +155,9 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
     ):
         num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        # NOTE(alcanderian): Considering speculative_decodeing,
+        # num_kv_splits.shape[0] will be topk * real_num_token.
+        # And the real_num_token is num_seq in decoding phase.
         num_group = num_token // num_seq
 
         assert (
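A one-line illustration of the grouping described in the NOTE above (plain Python with illustrative values): under speculative decoding the kernel sees topk draft branches per sequence, so num_group resolves to topk.

topk, num_seq = 4, 8
num_token = topk * num_seq        # num_kv_splits.shape[0] in the decode phase
num_group = num_token // num_seq  # one group per draft branch
assert num_group == topk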
@@ -919,7 +919,7 @@ def _fwd_kernel(
 
         e_max = n_e_max
 
-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
 
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
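Why the .to(tl.int64) above matters: the flat index into req_to_token is roughly req_pool_idx * stride + token_offset, which can exceed the int32 range once the pool is large enough. Plain-Python arithmetic with hypothetical pool dimensions:

max_requests, max_context_len = 8192, 1_048_576  # hypothetical pool sizes
stride = max_context_len
worst_case_index = (max_requests - 1) * stride + (max_context_len - 1)
print(worst_case_index > 2**31 - 1)  # True: an int32 index would overflow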
@@ -70,13 +71,14 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
 
     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
        paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
 
-        mask = paged_offset <= num_paged * PAGED_SIZE
-        mask_out = paged_offset_out <= num_paged
+        mask = paged_offset < num_paged * PAGED_SIZE
+        mask_out = paged_offset_out < num_paged
 
         data = tl.load(
             req_to_token_ptr
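The mask change above is an off-by-one fix: with <=, the slot at index num_paged (one past the last valid page) was also selected. A small plain-Python illustration:

NUM_PAGE_PER_BLOCK = 8
num_paged = 3
paged_offset_out = list(range(NUM_PAGE_PER_BLOCK))
old_mask = [o <= num_paged for o in paged_offset_out]  # selects indices 0..3 (4 slots)
new_mask = [o < num_paged for o in paged_offset_out]   # selects indices 0..2 (3 slots)
print(sum(old_mask), sum(new_mask))  # 4 3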
@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
         flatten_batch: bool = False,
     ) -> Optional[torch.Tensor]:
         r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
         Args:
             s: sequence length
             cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask