sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/check_env.py +1 -0
  2. sglang/global_config.py +2 -0
  3. sglang/srt/constrained/outlines_backend.py +4 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/layers/attention/flashinfer_backend.py +265 -147
  6. sglang/srt/layers/attention/triton_backend.py +358 -72
  7. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  8. sglang/srt/layers/linear.py +12 -5
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  20. sglang/srt/layers/moe/topk.py +1 -1
  21. sglang/srt/layers/quantization/__init__.py +51 -5
  22. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  32. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  35. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  37. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  39. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  41. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  49. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  51. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  53. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  54. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  55. sglang/srt/lora/backend/__init__.py +25 -5
  56. sglang/srt/lora/backend/base_backend.py +31 -9
  57. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  58. sglang/srt/lora/backend/triton_backend.py +34 -4
  59. sglang/srt/lora/layers.py +293 -0
  60. sglang/srt/lora/lora.py +101 -326
  61. sglang/srt/lora/lora_manager.py +101 -269
  62. sglang/srt/lora/mem_pool.py +174 -0
  63. sglang/srt/lora/triton_ops/__init__.py +7 -1
  64. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  65. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  66. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  67. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  68. sglang/srt/lora/utils.py +141 -0
  69. sglang/srt/managers/detokenizer_manager.py +1 -0
  70. sglang/srt/managers/io_struct.py +4 -0
  71. sglang/srt/managers/schedule_batch.py +16 -3
  72. sglang/srt/managers/scheduler.py +29 -0
  73. sglang/srt/managers/tokenizer_manager.py +6 -0
  74. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  75. sglang/srt/model_executor/cuda_graph_runner.py +16 -1
  76. sglang/srt/model_executor/model_runner.py +12 -2
  77. sglang/srt/models/deepseek_v2.py +17 -7
  78. sglang/srt/server_args.py +20 -1
  79. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  80. sglang/srt/speculative/eagle_utils.py +64 -21
  81. sglang/srt/speculative/eagle_worker.py +29 -8
  82. sglang/srt/utils.py +7 -0
  83. sglang/version.py +1 -1
  84. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
  85. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
  86. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  87. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  88. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper
 
 
 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
 
 
 @dataclass
@@ -70,6 +74,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -101,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
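Note: the MLA decode path is opt-in and only engages for DeepseekV3ForCausalLM architectures. A minimal sketch of turning it on through the offline engine, assuming the server argument added in sglang/srt/server_args.py in this release (file 78 above) carries the same name as the global_server_args_dict key read here:

    # Hypothetical usage sketch; the kwarg mirrors the "enable_flashinfer_mla"
    # server-args key above and may be spelled differently in practice.
    import sglang as sgl

    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # a DeepseekV3ForCausalLM checkpoint
        enable_flashinfer_mla=True,  # route decode through BatchMLAPagedAttentionWrapper
    )
    print(llm.generate("The capital of France is")["text"])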
@@ -118,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -130,12 +149,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 for _ in range(self.num_wrappers)
             ]
 
-        # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )
 
         # Two wrappers: one for sliding window attention and one for full attention.
@@ -155,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )
 
         # Create indices updater
         if not skip_prefill:
@@ -217,13 +237,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        # Some heuristics to check whether to use ragged forward
-        if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-            use_ragged = True
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-        else:
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
+        else:
+            use_ragged = True
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
@@ -277,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -378,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
 
-        logits_soft_cap = layer.logit_cap
+            logits_soft_cap = layer.logit_cap
 
-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                causal=True,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
             )
+
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
             )
 
-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
                     sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
+                    logits_soft_cap=logits_soft_cap,
                 )
 
-                o, _ = merge_state(o1, s1, o2, s2)
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )
 
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+                    o, _ = merge_state(o1, s1, o2, s2)
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
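Note: the else branch above preserves the previous two-pass extend scheme, now nested one level deeper: new tokens attend to themselves through the causal ragged wrapper (o1, s1), to the cached prefix through the non-causal paged wrapper (o2, s2), and merge_state combines the two partial attentions. A rough reference of what merge_state computes, assuming s1/s2 are the per-row log-sum-exp values returned by forward_return_lse (flashinfer documents these in base 2):

    import math
    import torch

    def merge_state_reference(o1, s1, o2, s2):
        # o*: [tokens, heads, head_dim]; s*: [tokens, heads] log2-sum-exp.
        # Each partial output is weighted by its share of the combined
        # softmax normalizer: w1 = 2**s1 / (2**s1 + 2**s2).
        w1 = torch.sigmoid((s1 - s2) * math.log(2.0)).unsqueeze(-1)
        return w1 * o1 + (1.0 - w1) * o2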
@@ -455,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )
 
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
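Note: in the MLA branch above, q and the cached key buffer are split at layer.v_head_dim into a compressed-KV ("nope") part and a rotary part, the (q_nope, q_pe, ckv, kpe) form that BatchMLAPagedAttentionWrapper.run expects. A sketch with DeepSeek-V3's absorbed-decode dimensions (values come from the model config, not from this diff):

    # Assumed DeepSeek-V3 MLA dimensions: kv_lora_rank=512, qk_rope_head_dim=64.
    head_dim = 512 + 64  # per-head width of q and of each cached-k entry
    v_head_dim = 512     # kv_lora_rank; the split point used above

    q_nope = reshaped_q[:, :, :v_head_dim]  # [tokens, heads, 512]
    q_pe   = reshaped_q[:, :, v_head_dim:]  # [tokens, heads, 64]
    ckv    = reshaped_k[:, :, :v_head_dim]  # [tokens, 1, 512] compressed KV
    k_pe   = reshaped_k[:, :, v_head_dim:]  # [tokens, 1, 64]  shared RoPE key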
@@ -519,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -531,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -612,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -640,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.end_forward()
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )
 
 
 class FlashInferIndicesUpdaterPrefill:
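Note: the literals in wrapper.plan mirror DeepSeek's MLA geometry rather than anything computed locally: 512 is head_dim_ckv (kv_lora_rank) and 64 is head_dim_kpe (qk_rope_head_dim), while the softmax scale follows the pre-absorption query/key head dim, hence the hard-coded 1/sqrt(192). The prefill side below likewise passes head_dim_qk=192 and head_dim_vo=128. As a sketch, with config values assumed (they are not part of this diff):

    import math

    # Assumed DeepSeek-V3 config values:
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    kv_lora_rank = 512

    head_dim_ckv = kv_lora_rank      # the 512 passed to plan()
    head_dim_kpe = qk_rope_head_dim  # the 64 passed to plan()
    sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # 1/sqrt(192)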
@@ -860,31 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
-            wrapper_ragged.begin_forward(
-                qo_indptr,
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )
 
-        # cached part
-        wrapper_paged.end_forward()
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-        )
-
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -924,38 +1044,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]
 
-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
-        for i in range(self.speculative_num_steps):
+
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
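Note: eager-mode draft decoding previously wrote its KV page indices into self.cuda_graph_kv_indices, the same buffer the CUDA-graph path captures; common_template now takes the destination buffer as an argument, and init_forward_metadata allocates a scratch buffer per call while the capture/replay paths keep passing self.cuda_graph_kv_indices. The shape of that scratch allocation, with illustrative values:

    # Illustrative values only; the shape mirrors the allocation above.
    speculative_num_steps, batch_size, topk, max_context_len = 4, 2, 8, 4096
    kv_indices_shape = (speculative_num_steps, batch_size * topk * max_context_len)
    # One row of page indices per draft step; step i is then sliced as
    # kv_indices_buffer[i][: seq_lens_sum * topk + bs * (i + 1)].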
@@ -965,7 +1097,7 @@ class FlashInferMultiStepDraftBackend:
                 )
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +1105,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1126,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1140,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
 @triton.jit
@@ -1070,21 +1201,6 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
 
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
@@ -1114,6 +1230,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
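Note: fast_decode_plan is monkey-patched over the draft wrappers' begin_forward (decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper), see init_forward_metadata_capture_cuda_graph above), and the decode indices updater now calls begin_forward(..., non_blocking=True). The added **kwargs lets the patched function accept and ignore keywords it does not use:

    from functools import partial

    def fast_plan(wrapper, kv_indptr, kv_indices, **kwargs):
        # Sketch of the pattern: extra keywords such as non_blocking=True are
        # swallowed, keeping the patched begin_forward call-compatible.
        ...

    # wrapper.begin_forward = partial(fast_plan, wrapper)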
@@ -1170,6 +1287,7 @@ def fast_decode_plan(
             window_left,
             logits_soft_cap,
             head_dim,
+            head_dim,
             empty_q_data,
             empty_kv_cache,
             stream.cuda_stream,