sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
         ).squeeze(1)
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
 
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
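The `sinks` plumbing above uses a compatibility trick: the new field is only forwarded as a keyword argument when it is set, so older fa3 kernel builds whose signature lacks it keep working. A minimal, self-contained sketch of that pattern (the `old_kernel`/`new_kernel` callables are hypothetical stand-ins, and the `inspect.signature` check is an extra illustration; the diff itself only gates on `sinks is not None`):

import inspect
from typing import Optional


def old_kernel(q, k, v, causal=True):
    # Stand-in for an older fa3 build whose signature has no `sinks` parameter.
    return "old", causal


def new_kernel(q, k, v, causal=True, sinks=None):
    # Stand-in for a newer build that accepts the extra field.
    return "new", causal, sinks


def call_attention_kernel(kernel, q, k, v, sinks: Optional[list] = None, **base_kwargs):
    # Put new fields into conditional keyword args: only forward `sinks` when it
    # is set and the target kernel actually accepts it, so old builds still work.
    kwargs = {}
    if sinks is not None and "sinks" in inspect.signature(kernel).parameters:
        kwargs["sinks"] = sinks
    return kernel(q, k, v, **base_kwargs, **kwargs)


print(call_attention_kernel(old_kernel, 1, 2, 3, sinks=[0.1]))  # ('old', True)
print(call_attention_kernel(new_kernel, 1, 2, 3, sinks=[0.1]))  # ('new', True, [0.1])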
sglang/srt/layers/attention/flashinfer_backend.py

@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
 
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
 
 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend:
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )
 
     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens, max=self.sliding_window_size + 1
                 )
-
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None
 
             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
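The sliding-window branch above replaces the old length computation with a plain clamp. A minimal sketch of the same arithmetic (the helper name is made up for illustration): each request's KV length is capped at `sliding_window_size + 1`, and the difference to the full length gives where the windowed KV range starts.

import torch


def sliding_window_kernel_lens(seq_lens: torch.Tensor, sliding_window_size: int):
    # Cap each request's KV length at window + 1 tokens and compute where the
    # windowed KV range starts, matching the clamp added in update_sliding_window().
    paged_kernel_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)
    kv_start_idx = seq_lens - paged_kernel_lens
    return paged_kernel_lens, kv_start_idx


lens, start = sliding_window_kernel_lens(torch.tensor([2, 10, 300]), sliding_window_size=128)
print(lens)   # tensor([  2,  10, 129])
print(start)  # tensor([  0,   0, 171])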
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu,
             )
 
     def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )
 
+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
             non_blocking=True,
         )
 
+        if locally_override:
+            global_override_indptr_cpu = None
+
 
 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend:
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
 
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
-
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
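The `global_override_indptr_cpu` fast path added above precomputes the indptr on the host from `seq_lens_cpu`, so flashinfer's plan step does not have to derive it from device tensors. A minimal sketch of that host-side construction (hypothetical helper name, standalone rather than writing into the module-level global):

import torch


def build_indptr_cpu(seq_lens_cpu: torch.Tensor) -> torch.Tensor:
    # Host-side prefix sums [0, l0, l0+l1, ...] for the batch, mirroring how the
    # diff fills global_override_indptr_cpu before wrapper.begin_forward().
    bs = seq_lens_cpu.numel()
    indptr = torch.zeros(bs + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(seq_lens_cpu, dim=0)
    return indptr


print(build_indptr_cpu(torch.tensor([3, 5, 2])))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)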
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
+    def quantize_and_rope_for_fp8(
+        self,
+        q_nope: torch.Tensor,
+        q_rope: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_rope: torch.Tensor,
+        forward_batch: ForwardBatch,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Quantize and apply RoPE for FP8 attention path.
+
+        This function handles the FP8 quantization and RoPE application for MLA attention.
+        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
+        quantizes all components to FP8, and merges the query components into a single tensor.
+
+        Args:
+            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
+                - expected dtype: torch.bfloat16
+            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
+                - expected dtype: torch.bfloat16
+            forward_batch: Forward batch containing position information
+            cos_sin_cache: Precomputed cosine/sine cache for RoPE
+                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
+            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
+
+        Returns:
+            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
+                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
+                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
+                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
+        """
+        attn_dtype = torch.float8_e4m3fn
+        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
+
+        # Allocate output tensors with FP8 dtype
+        # Query output will contain merged nope + rope components
+        q_out = q_rope.new_empty(
+            q_len,
+            num_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=attn_dtype,
+        )
+
+        # Key outputs maintain original shapes but with FP8 dtype
+        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
+        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
+
+        # Apply RoPE and quantize all components in a single fused kernel call
+        # This kernel handles:
+        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
+        # 2. Quantization of all components to FP8 format
+        # 3. Output placement into pre-allocated tensors
+        flashinfer.rope.mla_rope_quantize_fp8(
+            q_rope=q_rope,
+            k_rope=k_rope,
+            q_nope=q_nope,
+            k_nope=k_nope,
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=forward_batch.positions,
+            is_neox=is_neox,
+            quantize_dtype=attn_dtype,
+            # Output tensor slicing: q_out contains [nope_part, rope_part]
+            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
+            k_rope_out=k_rope_out,
+            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
+            k_nope_out=k_nope_out,
+            # Quantization scales (set to 1.0 for no additional scaling)
+            quant_scale_q=1.0,
+            quant_scale_kv=1.0,
+        )
+
+        return q_out, k_nope_out, k_rope_out
+
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # q_nope
+        k: torch.Tensor,  # k_nope
+        v: torch.Tensor,  # not used in this backend
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
+        merge_query = q_rope is not None
+        if self.data_type == torch.float8_e4m3fn:
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
-        if
-
-
-
-
-
-
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
 
         # Prepare query tensor inline
-        if
-            #
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
-            #
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
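`quantize_and_rope_for_fp8` above pre-allocates one query output and lets the fused kernel write the nope and rope halves into slices of it, so no later `torch.cat` is needed. A minimal sketch of just that layout (illustrative sizes, no quantization or RoPE, plain copies standing in for the kernel writes):

import torch

# Illustrative sizes only; the real values come from the model config.
kv_lora_rank, qk_rope_head_dim = 512, 64
q_nope = torch.randn(4, 16, kv_lora_rank)
q_rope = torch.randn(4, 16, qk_rope_head_dim)

# One pre-allocated output; the nope part fills [..., :kv_lora_rank] and the
# rope part fills [..., kv_lora_rank:], matching the q_nope_out/q_rope_out slices.
q_out = q_nope.new_empty(4, 16, kv_lora_rank + qk_rope_head_dim)
q_out[..., :kv_lora_rank] = q_nope   # stands in for q_nope_out=q_out[..., : self.kv_lora_rank]
q_out[..., kv_lora_rank:] = q_rope   # stands in for q_rope_out=q_out[..., self.kv_lora_rank :]
assert torch.equal(q_out, torch.cat([q_nope, q_rope], dim=-1))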
@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Prepare KV cache inline
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-        # TRT-LLM expects single KV data with extra dimension
-        kv_cache = pages.unsqueeze(1)
+        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
 
         # Get metadata
         metadata = (
@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             or self.forward_metadata
         )
 
-        # Scale computation for TRTLLM MLA kernel:
-        #
-        #
-        # -
-        #
+        # Scale computation for TRTLLM MLA kernel BMM1 operation:
+        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
+        # Scale components:
+        # - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
+        # - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
+        # - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
+        # This unified approach works for both FP16 and FP8 quantized attention paths.
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
sglang/srt/layers/attention/vision.py

@@ -245,6 +245,8 @@ class VisionTritonAttention(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -253,6 +255,8 @@ class VisionTritonAttention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
 
         # [b * s, head, head_size]
         output = torch.empty_like(q)
@@ -401,7 +405,11 @@ class VisionAttention(nn.Module):
         # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
-
+                if is_cuda():
+                    # Double prefill throughput by setting attn backend to Triton on CUDA
+                    qkv_backend = "triton_attn"
+                else:
+                    qkv_backend = "sdpa"
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]