sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -15
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/offloader.py +115 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -5
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -12
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
@@ -24,13 +24,18 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -179,6 +184,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -210,25 +216,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
-            # A hack to pre-initialize large batch size for dp attention
-            if model_runner.server_args.enable_dp_attention:
-                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
-
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
@@ -273,7 +269,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -331,9 +326,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-
-
-
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -359,7 +361,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
@@ -370,6 +371,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -440,13 +442,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -455,6 +455,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -534,6 +535,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -554,8 +556,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
-
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -617,17 +617,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
-        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-
-
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -646,10 +646,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -693,17 +692,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
-
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -712,40 +707,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-
-                kv_indptr
-                kv_indices
-
-
-
-
-
-
-                sm_scale
-
-
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-
-
-                kv_indices
-
-
-
-
-
-                sm_scale
-
-
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
 
 
@@ -767,14 +761,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tnesor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -788,6 +780,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
+
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -821,12 +814,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices =
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -835,7 +829,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -853,6 +846,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 self.req_to_token,
             )
         )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -867,26 +861,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr
-                kv_indptr
-                kv_indices
-                kv_len_arr
-
-
-
-
-
-                sm_scale
-
-
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )
 
 
@@ -981,7 +969,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1017,7 +1004,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -1034,7 +1020,6 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    general idea:
-    blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
-    fixed number of pages)]
-    max_pages = max_context_len / PAGED_SIZE
-    kv_indices_ptr will store the flat list of the pages used by each request
-    Args:
-    Inputs Arguments (non mutable):
-
-    req_to_token_ptr: Request to token location look up table
-        Shape: [max_batch, max_context_len]
-    req_pool_indices_ptr: Request to pool index look up table. Each request uses
-        one pool.
-        Shape: [batch_size]
-    page_kernel_lens_ptr: sequence lengths per request
-        Shape: [batch_size]
-    kv_indptr: Should be computed based on number of pages used by each request.
-        It is used by flashinfer attention kernels to index into the kv_indices_ptr.
-        per request.
-        Shape: [batch_size + 1]
-        kv_indptr[i] = start index in kv_indices for request i
-    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
-        Can be None. If provided, adds offset to token positions.
-
-    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-        Equal to max_context_len.
-
-    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-    Outputs:
-    kv_indices_ptr: Pointer to output array where KV indices will be stored.
-        Shape:[total-num-pages],
-        where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-    If we have:
-    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
-
-    The kernel will output:
-    If PAGE_SIZE = 1:
-    packed
-    - kv_indptr (passed in as input arg): [0,3,5]
-    - kv_indices = [10, 11, 12, 20, 21]
-    padded - max_pages is 10 tokens per req
-    - kv_indptr (passed in as input arg): [0,10, 20]
-    - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                    20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-
-    If PAGE_SIZE = 2
-    packed:
-    - kv_indptr (passed in as input arg): [0,3,4]
-    - kv_indices = [5,6,10]
-    padded: max_pages is 4
-    - kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
-    - kv_indices = [5, 6, -1, -1,
-                    10, -1, -1, -1]
-    This allows attention kernels to directly access the correct KV cache
-    entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-
-
-
-
-
-
-
-
-
-
-
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
         )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -1,20 +1,12 @@
 """CUTLASS based Fused MoE kernels."""
 
-import functools
-import json
-import logging
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    import sgl_kernel
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,
@@ -157,10 +149,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
 
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
-
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
 
@@ -192,9 +180,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)
 
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
         c2,
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        del down_input
        del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
 
         return down_output