sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+
 try:
     from deep_gemm import (
         get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
         m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     )
+    from sgl_kernel import silu_and_mul
+
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )

     use_deep_gemm = True
 except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    ep_gather,
+    ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
+    tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -600,7 +611,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by collum parallel or enabling merged weights
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
         masked_m: torch.Tensor,
         expected_m: int,
+        num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+            if _ENABLE_JIT_DEEPGEMM:
+                return self.forward_deepgemm_contiguous(
+                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+                )
+            else:
+                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         elif resolved_deepep_mode == DeepEPMode.low_latency:
             return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
@@ -969,6 +988,106 @@
         )
         return down_output

+    def forward_deepgemm_contiguous(
+        self,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        topk_idx,
+        topk_weights,
+        num_recv_tokens_per_expert: List[int],
+    ):
+        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if num_recv_tokens_per_expert is None:
+            return hidden_states_fp8.bfloat16()
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states_fp8.bfloat16()
+        M, K = hidden_states_fp8.size()
+        N = self.w13_weight.size(1)
+        scale_block_size = 128
+
+        gather_out = torch.empty_like(
+            hidden_states_fp8,
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+
+        input_tensor = [
+            torch.empty(
+                (all_tokens, K),
+                device=hidden_states_fp8.device,
+                dtype=hidden_states_fp8.dtype,
+            ),
+            torch.empty(
+                (all_tokens, K // 128),
+                device=hidden_states_fp8.device,
+                dtype=torch.float32,
+            ),
+        ]
+        m_indices = torch.empty(
+            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+        )
+        output_index = torch.empty_like(topk_idx)
+
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+        ep_scatter(
+            hidden_states_fp8,
+            hidden_states_scale,
+            topk_idx,
+            num_recv_tokens_per_expert_gpu,
+            expert_start_loc,
+            input_tensor[0],
+            input_tensor[1],
+            m_indices,
+            output_index,
+        )
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+        )
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input, scale_block_size
+        )
+        down_input_scale = tma_align_input_scale(down_input_scale)
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (down_input_fp8, down_input_scale),
+            self.w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+
+        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+
+        return gather_out
+
     def forward_deepgemm_masked(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
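The new forward_deepgemm_contiguous path packs every routed (token, expert) pair into one contiguous activation buffer so that a single grouped GEMM covers all experts. The sketch below is illustrative only: it shows, in plain PyTorch, how the expert_start_loc offsets and the per-row m_indices labels used above can be derived from num_recv_tokens_per_expert; the ep_scatter kernel added in this release additionally copies the FP8 activations and scales into that layout.

    import torch

    # Example per-expert receive counts (already padded to the 128-row expert
    # alignment the dispatcher requests from DeepEP when JIT DeepGEMM is on).
    num_recv_tokens_per_expert = [256, 0, 128, 384]
    counts = torch.tensor(num_recv_tokens_per_expert)

    # Start offset of each expert's contiguous slice in the scattered buffer.
    expert_start_loc = torch.cumsum(counts, dim=0) - counts  # tensor([0, 256, 256, 384])

    # m_indices labels every row of the scattered buffer with its expert id; the
    # contiguous grouped GEMM uses it to pick the matching expert weight matrix.
    m_indices = torch.repeat_interleave(torch.arange(len(counts)), counts).to(torch.int32)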
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -1,14 +1,19 @@
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import DeepEPMode

 try:
     from deep_ep import Buffer

+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
+
     use_deepep = True
 except ImportError:
     use_deepep = False

 from enum import IntEnum, auto
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -78,7 +83,6 @@ class DeepEPBuffer:
                 ),
                 num_rdma_bytes,
             )
-
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        if _ENABLE_JIT_DEEPGEMM:
+            # TODO hard code 128 block quant,use fp8 communication
+            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
        previous_event = Buffer.capture() if self.async_finish else None
        return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            event,
-        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
-        if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+        if _ENABLE_JIT_DEEPGEMM:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
             )
-        else:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
+            event.current_stream_wait() if self.async_finish else ()
+            return (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                None,
+                num_recv_tokens_per_expert_list,
+                None,
+                None,
+                None,
             )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        else:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
             )
+            event.current_stream_wait() if self.async_finish else ()
+            if hidden_states.shape[0] > 0:
+                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+                )
+            else:
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )

-        masked_m = expected_m = None
-
-        return (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            reorder_topk_ids,
-            seg_indptr,
-            masked_m,
-            expected_m,
-        )
+            masked_m = expected_m = None
+            return (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                None,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            )

     def _dispatch_core(
         self,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         previous_event,
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            _,  # num_recv_tokens_per_expert_list
+            num_recv_tokens_per_expert_list,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
+            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
         )

         return (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
+            num_recv_tokens_per_expert_list,
             event,
         )

@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if hidden_states.shape[0] > 0:
-            num_tokens = self.src2dst.shape[0] // self.router_topk
-            output = torch.empty(
-                (num_tokens, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            deepep_post_reorder_triton_kernel[(num_tokens,)](
-                hidden_states,
-                output,
-                self.src2dst,
-                topk_idx,
-                topk_weights,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
+        if _ENABLE_JIT_DEEPGEMM:
+            output = hidden_states
         else:
-            output = torch.zeros(
-                (0, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+            if hidden_states.shape[0] > 0:
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    topk_idx,
+                    topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event

@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):

     def _get_buffer(self):
         DeepEPBuffer.set_dispatch_mode_as_normal()
+
         return DeepEPBuffer.get_deepep_buffer(
             self.group,
             self.hidden_size,
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             topk_idx,
             topk_weights,
             reorder_topk_ids,
+            None,
             seg_indptr,
             masked_m,
             expected_m,
@@ -570,7 +611,8 @@ class DeepEPDispatcher:

     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
-        return self.dispatch_b()
+        ret = self.dispatch_b()
+        return ret

     def dispatch_a(
         self,
@@ -593,7 +635,8 @@ class DeepEPDispatcher:

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
-        return self.combine_b()
+        ret = self.combine_b()
+        return ret

     def combine_a(
         self,
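When _ENABLE_JIT_DEEPGEMM is set, dispatch_a quantizes activations before communication with sglang_per_token_group_quant_fp8(hidden_states, 128), so the FP8 payload plus one float32 scale per 128-element group travel over DeepEP rather than the unquantized activations. The reference below is only a sketch of the math such a per-token-group quantization performs, assuming symmetric scaling to the float8_e4m3 range; the shipped kernel is a fused implementation.

    import torch

    def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
        # x: [num_tokens, hidden]; hidden must be a multiple of group_size.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        tokens, hidden = x.shape
        groups = x.view(tokens, hidden // group_size, group_size).float()
        # One scale per (token, group): amax / fp8_max, kept away from zero.
        scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
        q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return q.view(tokens, hidden), scales.squeeze(-1)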
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
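The keys of this new fused-MoE Triton config are token counts (M); the kernel launcher typically picks the entry whose key is nearest to the actual number of routed tokens. A minimal sketch of that lookup, using one of the new config filenames from the list above purely as a stand-in path:

    import json

    # Hypothetical local path; in the package these files live under
    # sglang/srt/layers/moe/fused_moe_triton/configs/.
    CONFIG = "E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json"

    with open(CONFIG) as f:
        configs = {int(m): cfg for m, cfg in json.load(f).items()}

    def pick_config(m: int) -> dict:
        # Nearest-M selection; e.g. m=200 resolves to the "256" entry.
        return configs[min(configs, key=lambda key: abs(key - m))]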