sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -5,16 +5,23 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-    logger = logging.getLogger(__name__)
+
+    try:
+        from deep_gemm import ceil_div
+    except ImportError:
+        logger.error(f"Failed to import ceil_div from deep_gemm.")
+
+    import triton.language as tl
 
 
 @triton.jit
@@ -654,10 +661,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)
 
         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
@@ -707,3 +711,334 @@ def grouped_gemm_triton(
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token num of per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # block size of quantization
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]
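For illustration only (not part of the released diff): a minimal sketch of the padding rule that the new get_tma_aligned_size helper applies, assuming fp32 scale elements (4 bytes) and using a local ceil_div stand-in for the one imported from deep_gemm. tma_align_input_scale then copies the (m, k) scale tensor into a column-major buffer whose M axis is padded to this size.

# Sketch of the TMA alignment rule from get_tma_aligned_size (assumed fp32 scales).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def tma_aligned_m(m: int, element_size: int = 4) -> int:
    alignment = 16 // element_size  # TMA addresses must be 16-byte aligned
    return ceil_div(m, alignment) * alignment

print(tma_aligned_m(113))  # 116: 113 fp32 rows pad up to the next multiple of 4
print(tma_aligned_m(128))  # 128: already aligned, unchanged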
sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+
 try:
     from deep_gemm import (
         get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
         m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     )
+    from sgl_kernel import silu_and_mul
+
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
 
     use_deep_gemm = True
 except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    ep_gather,
+    ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
+    tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
         masked_m: torch.Tensor,
         expected_m: int,
+        num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+            if _ENABLE_JIT_DEEPGEMM:
+                return self.forward_deepgemm_contiguous(
+                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+                )
+            else:
+                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         elif resolved_deepep_mode == DeepEPMode.low_latency:
             return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_deepgemm_contiguous(
+        self,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        topk_idx,
+        topk_weights,
+        num_recv_tokens_per_expert: List[int],
+    ):
+        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if num_recv_tokens_per_expert is None:
+            return hidden_states_fp8.bfloat16()
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states_fp8.bfloat16()
+        M, K = hidden_states_fp8.size()
+        N = self.w13_weight.size(1)
+        scale_block_size = 128
+
+        gather_out = torch.empty_like(
+            hidden_states_fp8,
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+
+        input_tensor = [
+            torch.empty(
+                (all_tokens, K),
+                device=hidden_states_fp8.device,
+                dtype=hidden_states_fp8.dtype,
+            ),
+            torch.empty(
+                (all_tokens, K // 128),
+                device=hidden_states_fp8.device,
+                dtype=torch.float32,
+            ),
+        ]
+        m_indices = torch.empty(
+            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+        )
+        output_index = torch.empty_like(topk_idx)
+
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+        ep_scatter(
+            hidden_states_fp8,
+            hidden_states_scale,
+            topk_idx,
+            num_recv_tokens_per_expert_gpu,
+            expert_start_loc,
+            input_tensor[0],
+            input_tensor[1],
+            m_indices,
+            output_index,
+        )
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+        )
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input, scale_block_size
+        )
+        down_input_scale = tma_align_input_scale(down_input_scale)
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (down_input_fp8, down_input_scale),
+            self.w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+
+        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+
+        return gather_out
+
     def forward_deepgemm_masked(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
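For illustration only (not part of the released diff): a plain-PyTorch sketch of the combine that the new ep_gather kernel performs at the end of the contiguous DeepGEMM path. Each token's output is the weighted sum of its selected experts' rows; negative expert ids mark unused top-k slots, and input_index records where each (token, expert) pair was placed by ep_scatter. Shapes here are assumed for the example.

import torch

def ep_gather_reference(expert_out, topk_ids, topk_weight, input_index):
    # expert_out: (scattered_tokens, hidden); the other three: (num_tokens, topk)
    num_tokens, topk = topk_ids.shape
    out = torch.zeros(num_tokens, expert_out.shape[1], dtype=expert_out.dtype)
    for t in range(num_tokens):
        for j in range(topk):
            if topk_ids[t, j] >= 0:  # skip unused top-k slots
                out[t] += topk_weight[t, j] * expert_out[input_index[t, j]]
    return out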