sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
@@ -24,13 +24,18 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -179,6 +184,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -210,25 +216,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
-            # A hack to pre-initialize large batch size for dp attention
-            if model_runner.server_args.enable_dp_attention:
-                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
-
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
@@ -273,7 +269,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -331,9 +326,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        self.cuda_graph_kv_indices = (
-            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
-        )
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -359,7 +361,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -370,6 +371,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -440,13 +442,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                num_pages_per_req, dim=0
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -455,6 +455,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -534,6 +535,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -554,8 +556,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
-
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -617,17 +617,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
-        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buf[:, :, : layer.v_head_dim],
-            k_buf[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -646,10 +646,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-        self.page_size = model_runner.page_size
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -693,17 +692,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                self.kv_indices[: kv_indptr[-1]]
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
-
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -712,40 +707,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                qo_indptr=q_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
-                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices=kv_indices,
-                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
 
 
@@ -767,14 +761,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tensor,
+        req_pool_indices: torch.Tnesor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -788,6 +780,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
+
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -821,12 +814,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = self.kv_indices[: kv_indptr[-1]]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -835,7 +829,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -853,6 +846,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -867,26 +861,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-            if spec_info is not None:
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr=qo_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=True,
-                sm_scale=sm_scale,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.data_type,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )
 
 
@@ -981,7 +969,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
@@ -1017,7 +1004,6 @@
        )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
-
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
@@ -1034,7 +1020,6 @@
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
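Taken together, the hunks above revert the FlashInfer MLA backend from page-aware KV indexing back to one KV-cache slot per token: kv_indptr becomes a plain cumulative sum of sequence lengths instead of page counts, kv_indices is allocated per call rather than sliced from a preallocated buffer, and the wrapper's plan() is passed a literal page size of 1. A minimal sketch of that token-granularity indexing, with made-up lengths (the tensor names mirror the diff; the values are illustrative only):

    import torch

    seq_lens = torch.tensor([3, 2], dtype=torch.int32)  # per-request KV lengths (toy values)
    bs = seq_lens.numel()

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)  # [0, 3, 5]: one slot per cached token

    # One kv_indices entry per token; create_flashinfer_kv_indices_triton fills it
    # from req_to_token before wrapper.plan(...) is called with page size 1.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)

With the page size fixed to 1, the number of pages per request equals its sequence length, which is why the (len + page_size - 1) // page_size rounding disappears from both the decode and prefill index updaters.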
@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    general idea:
-    blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
-    fixed number of pages)]
-    max_pages = max_context_len / PAGED_SIZE
-    kv_indices_ptr will store the flat list of the pages used by each request
-    Args:
-    Inputs Arguments (non mutable):
-
-    req_to_token_ptr: Request to token location look up table
-        Shape: [max_batch, max_context_len]
-    req_pool_indices_ptr: Request to pool index look up table. Each request uses
-        one pool.
-        Shape: [batch_size]
-    page_kernel_lens_ptr: sequence lengths per request
-        Shape: [batch_size]
-    kv_indptr: Should be computed based on number of pages used by each request.
-        It is used by flashinfer attention kernels to index into the kv_indices_ptr
-        per request.
-        Shape: [batch_size + 1]
-        kv_indptr[i] = start index in kv_indices for request i
-    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
-        Can be None. If provided, adds offset to token positions.
-
-    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-        Equal to max_context_len.
-
-    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-    Outputs:
-    kv_indices_ptr: Pointer to output array where KV indices will be stored.
-        Shape: [total_num_pages],
-        where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-    If we have:
-    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
-
-    The kernel will output:
-    If PAGE_SIZE = 1:
-        packed
-        - kv_indptr (passed in as input arg): [0, 3, 5]
-        - kv_indices = [10, 11, 12, 20, 21]
-        padded - max_pages is 10 tokens per req
-        - kv_indptr (passed in as input arg): [0, 10, 20]
-        - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                        20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-
-    If PAGE_SIZE = 2
-        packed:
-        - kv_indptr (passed in as input arg): [0, 3, 4]
-        - kv_indices = [5, 6, 10]
-        padded: max_pages is 4
-        - kv_indptr (passed in as input arg): [0, 4, 8, ...] (note that 4 is the max_pages)
-        - kv_indices = [5, 6, -1, -1,
-                        10, -1, -1, -1]
-    This allows attention kernels to directly access the correct KV cache
-    entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    kv_range = kv_end - kv_start
-    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
-    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
-    req_to_token_block_start = (
-        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
-    )
-    for i in range(num_loops):
-        token_offsets_in_block = (
-            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
-        ) * PAGE_SIZE
-        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
         )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
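In its reverted form, the kernel above simply walks each request's row of req_to_token and copies the first seq_len token slots into the flat kv_indices array at the offset given by kv_indptr (kv_start_idx, when provided, shifts the starting column). A pure-PyTorch reference of that mapping, reusing the example from the removed docstring; this is a readability sketch only, not the Triton kernel itself:

    import torch

    # Toy values taken from the removed docstring's example.
    req_to_token = torch.tensor([[10, 11, 12, -1], [20, 21, -1, -1]], dtype=torch.int32)
    req_pool_indices = torch.tensor([0, 1])
    page_kernel_lens = torch.tensor([3, 2])
    kv_indptr = torch.tensor([0, 3, 5])

    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for pid in range(req_pool_indices.numel()):
        pool = int(req_pool_indices[pid])      # which req_to_token row this request uses
        length = int(page_kernel_lens[pid])    # how many KV entries it has
        start = int(kv_indptr[pid])            # where its slice begins in kv_indices
        kv_indices[start : start + length] = req_to_token[pool, :length]

    # kv_indices -> tensor([10, 11, 12, 20, 21], dtype=torch.int32)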
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -1,20 +1,12 @@
 """CUTLASS based Fused MoE kernels."""
 
-import functools
-import json
-import logging
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    import sgl_kernel
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,
@@ -157,10 +149,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
 
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
-
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
 
@@ -192,9 +180,6 @@
     silu_and_mul(c1, intermediate)
 
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
         c2,
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -708,9 +705,7 @@
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,7 +717,6 @@
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
 
         return down_output