sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
         ).squeeze(1)
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.

         Args:
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
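Note on the hunk above: the conditional keyword arguments keep the call compatible with FlashAttention-3 builds whose kernel entry point does not yet accept a `sinks` parameter. A minimal sketch of the pattern, with an illustrative kernel stand-in (not the backend's actual call):

    from typing import Optional
    import torch

    def attention_kernel_stub(q: torch.Tensor, **kwargs) -> torch.Tensor:
        # Stand-in for the real kernel; an older build simply never sees "sinks".
        return q

    def run_attention(q: torch.Tensor, sinks: Optional[torch.Tensor] = None) -> torch.Tensor:
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks  # forwarded only when the new field is actually in use
        return attention_kernel_stub(q, **kwargs)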
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )

             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None

+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+

 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
+               forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                decode_wrappers=self.decode_wrappers,
                encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
+               forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
+               forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
+               forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
+               seq_lens.cpu(),  # may add a little overhead in capture stage
                seq_lens_sum,
                decode_wrappers=decode_wrappers,
                encoder_lens=encoder_lens,
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
+               seq_lens.cpu(),  # may add a little overhead in capture stage
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
+               seq_lens.cpu(),  # may add a little overhead in capture stage
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=prefill_wrappers,
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
+               seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                seq_lens_sum,
                decode_wrappers=self.decode_cuda_graph_metadata[bs],
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
+               seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
+               seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):


 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )

     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens, max=self.sliding_window_size + 1
                 )
-                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None

             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
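The switch from `torch.minimum(seq_lens, torch.tensor(w + 1))` to `torch.clamp(seq_lens, max=w + 1)` in the sliding-window branch is behavior-preserving and avoids materializing a scalar tensor; the new CPU branch then applies the same clamp to `seq_lens_cpu` so the `.sum().item()` no longer forces a device sync. A quick check of the equivalence, with an illustrative window size:

    import torch

    window_size = 4  # illustrative sliding-window size
    seq_lens = torch.tensor([2, 5, 9])

    via_minimum = torch.minimum(seq_lens, torch.tensor(window_size + 1))
    via_clamp = torch.clamp(seq_lens, max=window_size + 1)
    assert torch.equal(via_minimum, via_clamp)  # both yield tensor([2, 5, 5])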
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )

@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu,
             )

     def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )

+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
             non_blocking=True,
         )

+        if locally_override:
+            global_override_indptr_cpu = None
+

 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
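In `call_begin_forward` above, the CPU indptr override is installed only when no caller has already set it, and it is cleared again once `begin_forward` has run, so outer callers that manage the global themselves are left undisturbed. A hedged sketch of that install-then-reset lifecycle, with illustrative names and a try/finally in place of the backend's explicit reset:

    # Sketch: temporarily install a module-level override around a planning call.
    # `plan_fn` stands in for wrapper.begin_forward; names are illustrative.
    _override_indptr_cpu = None

    def plan_with_optional_override(indptr_cpu, plan_fn):
        global _override_indptr_cpu
        locally_override = False
        if indptr_cpu is not None and _override_indptr_cpu is None:
            _override_indptr_cpu = indptr_cpu
            locally_override = True
        try:
            plan_fn()
        finally:
            if locally_override:
                _override_indptr_cpu = None  # leave the global as we found it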
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
         )


-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
-
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

 import torch

@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata

+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        # 2. Quantization of all components to FP8 format
+        # 3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if k is not None and save_kv_cache:
-            cache_loc = forward_batch.out_cache_loc
-            if k_rope is not None:
-                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                    layer, cache_loc, k, k_rope
-                )
-            elif v is not None:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )

         # Prepare query tensor inline
-        if q_rope is not None:
-            # q contains NOPE part (v_head_dim)
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            # q already has both parts
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
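Both decode paths above end up with one query tensor whose last dimension concatenates the no-RoPE and RoPE parts: the FP8 path writes the two slices inside `quantize_and_rope_for_fp8`, while the FP16 path merges them with `torch.cat`. A shape-only sketch of that layout, with illustrative MLA dimensions:

    import torch

    seq_len, num_heads = 2, 16               # illustrative
    kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative

    q_nope = torch.randn(seq_len, num_heads, kv_lora_rank)
    q_rope = torch.randn(seq_len, num_heads, qk_rope_head_dim)

    # FP16 path: merge by concatenation along the last dimension.
    query = torch.cat([q_nope, q_rope], dim=-1)
    assert query.shape == (seq_len, num_heads, kv_lora_rank + qk_rope_head_dim)

    # FP8 path: quantize_and_rope_for_fp8 fills query[..., :kv_lora_rank] with the
    # quantized nope part and query[..., kv_lora_rank:] with the quantized rope part.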
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)

         # Get metadata
         metadata = (
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )

-        # Scale computation for TRTLLM MLA kernel:
-        # - BMM1 scale = q_scale * k_scale * softmax_scale
-        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
-        # - k_scale is read from model checkpoint if available
-        # TODO: Change once fp8 path is supported
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
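The BMM1 scale described in the rewritten comment is just the product of three factors. A one-line worked example, with an illustrative head dimension:

    import math

    q_scale = 1.0                              # fixed for both FP16 and FP8 paths
    k_scale = 1.0                              # default when the checkpoint provides none
    head_dim = 576                             # illustrative
    softmax_scale = 1.0 / math.sqrt(head_dim)  # corresponds to layer.scaling
    bmm1_scale = q_scale * k_scale * softmax_scale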
@@ -245,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -253,6 +255,8 @@ class VisionTritonAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)

         # [b * s, head, head_size]
         output = torch.empty_like(q)
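The new fallback derives `cu_seqlens` from the uniform batch shape when the caller passes None. Assuming `_get_cu_seqlens_for_shape` returns the standard prefix-sum boundaries, it computes something like the sketch below (helper name here is illustrative):

    import torch

    def cu_seqlens_for_shape(bsz: int, seq_len: int, device="cpu") -> torch.Tensor:
        # Boundaries of bsz equal-length sequences packed along dim 0:
        # [0, seq_len, 2 * seq_len, ..., bsz * seq_len]
        return torch.arange(
            0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32, device=device
        )

    # e.g. bsz=3, seq_len=4 -> tensor([0, 4, 8, 12], dtype=torch.int32)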
@@ -401,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-                qkv_backend = "sdpa"
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
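The selection order itself is unchanged: an explicit `mm_attention_backend` server argument wins, then a backend passed in by the model code, and only then the default, which is now `triton_attn` on CUDA and `sdpa` elsewhere. A condensed sketch of that priority chain (function name is illustrative):

    from typing import Optional

    def pick_mm_attention_backend(
        server_arg: Optional[str], passed_qkv_backend: Optional[str], on_cuda: bool
    ) -> str:
        # priority: server_args > passed qkv_backend > platform default
        if server_arg is not None:
            return server_arg
        if passed_qkv_backend is not None:
            return passed_qkv_backend
        return "triton_attn" if on_cuda else "sdpa"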