sglang 0.4.2.post4__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/global_config.py +2 -0
  2. sglang/srt/entrypoints/engine.py +2 -2
  3. sglang/srt/layers/attention/flashinfer_backend.py +235 -110
  4. sglang/srt/layers/attention/triton_backend.py +358 -72
  5. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  6. sglang/srt/layers/linear.py +12 -5
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  16. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
  17. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  18. sglang/srt/layers/moe/topk.py +1 -1
  19. sglang/srt/layers/quantization/__init__.py +51 -5
  20. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  24. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  28. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  30. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  32. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  33. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  34. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  35. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  36. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  38. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  39. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  40. sglang/srt/managers/detokenizer_manager.py +1 -0
  41. sglang/srt/managers/io_struct.py +4 -0
  42. sglang/srt/managers/schedule_batch.py +16 -3
  43. sglang/srt/managers/scheduler.py +29 -0
  44. sglang/srt/managers/tokenizer_manager.py +6 -0
  45. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  46. sglang/srt/model_executor/cuda_graph_runner.py +12 -1
  47. sglang/srt/model_executor/model_runner.py +12 -2
  48. sglang/srt/models/deepseek_v2.py +17 -7
  49. sglang/srt/server_args.py +20 -1
  50. sglang/srt/speculative/eagle_worker.py +28 -8
  51. sglang/srt/utils.py +7 -0
  52. sglang/version.py +1 -1
  53. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/METADATA +4 -3
  54. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/RECORD +57 -41
  55. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  56. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  57. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/global_config.py CHANGED
@@ -38,5 +38,7 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
 
+        self.enable_flashinfer_mla = False
+
 
 global_config = GlobalConfig()
sglang/srt/entrypoints/engine.py CHANGED
@@ -297,7 +297,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
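
Note: NVLS is no longer hard-coded off; the env var now follows the new enable_nccl_nvls server argument (presumably added by the server_args.py change, file 49 above). A minimal sketch of the bool-to-string conversion this line performs; the dataclass is an illustrative stand-in, not sglang's real ServerArgs:

import os
from dataclasses import dataclass

# Hypothetical stand-in for ServerArgs; only the field used above is modeled.
@dataclass
class Args:
    enable_nccl_nvls: bool = False

for flag in (False, True):
    os.environ["NCCL_NVLS_ENABLE"] = str(int(Args(flag).enable_nccl_nvls))
    print(os.environ["NCCL_NVLS_ENABLE"])  # prints "0", then "1"
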
@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.0.post2",
+            "0.2.1.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper
 
 
 class WrapperDispatch(Enum):
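
Note: BatchMLAPagedAttentionWrapper comes from the flashinfer.mla module, which is why engine.py above bumps the flashinfer_python requirement to 0.2.1.post1. A hedged sketch of guarding the import on older installs (the None fallback is my own convention, not something this diff does):

# Guarded-import sketch; assumes only that flashinfer.mla ships in >= 0.2.1.
try:
    from flashinfer.mla import BatchMLAPagedAttentionWrapper
except ImportError:  # older flashinfer without the MLA wrapper
    BatchMLAPagedAttentionWrapper = None
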
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
 
 
 @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )
 
         # Create indices updater
         if not skip_prefill:
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
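
Note: under CUDA graphs the MLA wrapper is constructed with pre-sliced views of buffers allocated at max batch size, because graph capture freezes tensor addresses. A small sketch of why slicing (rather than reallocating) is safe; the names are illustrative:

import torch

# Buffers are allocated once at the maximum size...
max_bs, num_tokens = 160, 8
qo_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)

# ...and sliced per capture. A slice shares storage with its parent, so
# values written before each graph replay are visible at the captured address.
view = qo_indptr[: num_tokens + 1]
view.fill_(1)
assert qo_indptr[0].item() == 1
assert view.data_ptr() == qo_indptr.data_ptr()
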
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
 
-        logits_soft_cap = layer.logit_cap
+            logits_soft_cap = layer.logit_cap
 
-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
 
-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
+                    causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
                 )
 
-                o, _ = merge_state(o1, s1, o2, s2)
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )
 
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+                    o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
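
Note: in the MLA extend path only the ragged (unpaged) kernel runs, and the value head dim differs from the query/key head dim, so v is viewed with layer.v_head_dim and the output is flattened to tp_q_head_num * v_head_dim. A shape sketch using the 192/128 split that the ragged begin_forward hunk further down passes as head_dim_qk/head_dim_vo:

import torch

num_tokens, num_heads = 4, 16
head_dim_qk, head_dim_vo = 192, 128  # from the begin_forward call below

q = torch.randn(num_tokens, num_heads * head_dim_qk)
v = torch.randn(num_tokens, num_heads * head_dim_vo)

# v gets its own head dim, which is why the new code views it with
# layer.v_head_dim instead of layer.head_dim.
q3 = q.view(-1, num_heads, head_dim_qk)
v3 = v.view(-1, num_heads, head_dim_vo)

# The attention output inherits v's head dim, hence the final reshape:
o = torch.zeros(num_tokens, num_heads, head_dim_vo)
print(o.view(-1, num_heads * head_dim_vo).shape)  # torch.Size([4, 2048])
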
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )
 
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
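
Note: at decode time the MLA kernel consumes q and the cached latent k split at layer.v_head_dim into a compressed-KV ("nope") part and a rope part. A slicing sketch with the 512 + 64 = 576 layout implied by the plan() call further down (the dims are taken from that call, not hard-coded here):

import torch

kv_lora_rank, rope_dim = 512, 64    # per the plan() arguments below
head_dim = kv_lora_rank + rope_dim  # packed head dim: 576

q = torch.randn(8, 16, head_dim)    # [tokens, q heads, 576]
k = torch.randn(8, 1, head_dim)     # [tokens, one shared latent head, 576]

# These four slices line up with the four positional args passed to
# decode_wrapper.run(...) in the hunk above (often written q_nope, q_pe,
# ckv, kpe).
q_nope, q_pe = q[:, :, :kv_lora_rank], q[:, :, kv_lora_rank:]
ckv, kpe = k[:, :, :kv_lora_rank], k[:, :, kv_lora_rank:]
print(q_nope.shape, q_pe.shape)  # torch.Size([8, 16, 512]) torch.Size([8, 16, 64])
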
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )
 
 
 class FlashInferIndicesUpdaterPrefill:
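
Note: the positional plan() arguments encode DeepSeek-style MLA geometry: 512 is the compressed-KV (kv_lora_rank) head dim and 64 the rope head dim, while sm_scale is derived from 192, the qk head dim of the unabsorbed attention (128 nope + 64 rope, matching head_dim_qk=192 in the prefill hunk below), presumably because weight absorption preserves the original dot products. A worked check of the scale:

import math

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # assumed DeepSeek MLA dims
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
print(round(sm_scale, 6))  # 0.072169, i.e. 1/sqrt(192) as in the hunk above
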
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.begin_forward(
-                qo_indptr,
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )
 
-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
-
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -947,7 +1071,7 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
         window_left,
         logits_soft_cap,
         head_dim,
+        head_dim,
         empty_q_data,
         empty_kv_cache,
         stream.cuda_stream,