sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json
@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5,
+        "USE_TMA": false
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "USE_TMA": false
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4,
+        "USE_TMA": false
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5,
+        "USE_TMA": false
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4,
+        "USE_TMA": false
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": false
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3,
+        "USE_TMA": true
+    }
+}
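Note: the integer keys in the tuning table above are batch sizes (the number of tokens M seen by the MoE layer); at runtime the entry whose key is closest to the actual M is used to launch the Triton kernel. A minimal standalone sketch of that lookup, assuming the table has been saved as a hypothetical my_moe_config.json (not sglang code, but it mirrors the min(..., key=lambda x: abs(x - M)) lookup in try_get_optimal_moe_config further down in this diff):

import json

def pick_moe_config(path: str, m: int) -> dict:
    # Load the tuning table and coerce its string keys back to batch sizes.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Pick the entry tuned for the batch size closest to the actual M.
    return configs[min(configs.keys(), key=lambda x: abs(x - m))]

# Example: M=300 falls closest to the "256" key, whose USE_TMA flag is true on this device.
cfg = pick_moe_config("my_moe_config.json", 300)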
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -23,7 +23,11 @@ from sglang.srt.utils import (
 )
 
 from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
-from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
+from .fused_moe_triton_kernels import (
+    invoke_fused_moe_kernel,
+    moe_sum_reduce_triton,
+    support_tensor_descriptor,
+)
 from .moe_align_block_size import moe_align_block_size
 
 if TYPE_CHECKING:
@@ -36,7 +40,7 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
-    from sgl_kernel import gelu_and_mul, silu_and_mul
+    from sgl_kernel import gelu_and_mul, moe_sum_reduce, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
@@ -78,6 +82,7 @@ def inplace_fused_experts(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -106,6 +111,7 @@ def inplace_fused_experts(
         routed_scaling_factor,
         gemm1_alpha,
         gemm1_limit,
+        filter_expert,
     )
 
 
@@ -134,6 +140,7 @@ def inplace_fused_experts_fake(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> None:
     pass
 
@@ -172,6 +179,7 @@ def outplace_fused_experts(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -200,6 +208,7 @@ def outplace_fused_experts(
         routed_scaling_factor=routed_scaling_factor,
         gemm1_alpha=gemm1_alpha,
         gemm1_limit=gemm1_limit,
+        filter_expert=filter_expert,
     )
 
 
@@ -229,6 +238,7 @@ def outplace_fused_experts_fake(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -263,6 +273,10 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
+    filter_expert = (
+        moe_runner_config.num_experts is None
+        or moe_runner_config.num_experts != moe_runner_config.num_local_experts
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -290,6 +304,7 @@ def fused_experts(
             moe_runner_config.routed_scaling_factor,
             moe_runner_config.gemm1_alpha,
             moe_runner_config.gemm1_clamp_limit,
+            filter_expert,
         )
         return hidden_states
     else:
@@ -319,6 +334,7 @@ def fused_experts(
             routed_scaling_factor=moe_runner_config.routed_scaling_factor,
             gemm1_alpha=moe_runner_config.gemm1_alpha,
             gemm1_limit=moe_runner_config.gemm1_clamp_limit,
+            filter_expert=filter_expert,
         )
 
 
@@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
     return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
 
 
+@functools.lru_cache()
+def _down_moe_use_tma():
+    return support_tensor_descriptor()
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -363,6 +384,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -402,25 +424,27 @@ def fused_experts_impl(
         topk_ids.shape[1],
         config_dtype,
         block_shape=block_shape,
+        return_down_config=True,
     )
 
-    config = get_config_func(M)
-
-    cache = torch.empty(
-        M * topk_ids.shape[1] * max(N, w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
+    config, (down_config, max_block_m) = get_config_func(M)
+    down_moe_use_tma = (
+        _down_moe_use_tma()
+        and down_config is not None
+        and down_config.pop("USE_TMA", False)
     )
-    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
-        (M, topk_ids.shape[1], N),
+    topk = topk_ids.shape[1]
+    max_padded_tokens = (
+        min(M * topk, E + 1) * (max_block_m - 1) if down_moe_use_tma else 0
    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
+    total_tokens = M * topk + max_padded_tokens
+    cache = torch.empty(
+        total_tokens * max(N, w2.shape[1]),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
-        (M, topk_ids.shape[1], w2.shape[1]),
+    intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
+        (M, topk, w2.shape[1]),
     )
 
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
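Note: the cache sizing in the hunk above reserves worst-case padding for the TMA path, where the down-projection GEMM appears to expect expert-sorted, block-aligned rows. A small standalone sketch of that bound under the assumption that each expert segment is rounded up to the down kernel's BLOCK_SIZE_M (function name and numbers are illustrative, not sglang code):

def worst_case_padding(m: int, topk: int, num_experts: int, max_block_m: int) -> int:
    # At most min(m * topk, num_experts + 1) expert segments exist, and each
    # can waste up to (max_block_m - 1) rows when rounded up to a full block.
    return min(m * topk, num_experts + 1) * (max_block_m - 1)

# E.g. M=256 tokens, topk=8, E=257 experts, and max BLOCK_SIZE_M=64 in the table above:
# min(2048, 258) * 63 = 16254 extra rows on top of the 2048 real (token, expert) rows.
assert worst_case_padding(256, 8, 257, 64) == 16254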
@@ -428,7 +452,7 @@ def fused_experts_impl(
     if no_combine:
         assert not inplace
         out_hidden_states = torch.empty(
-            (num_tokens, topk_ids.shape[1], w2.shape[1]),
+            (num_tokens, topk, w2.shape[1]),
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
@@ -453,12 +477,28 @@ def fused_experts_impl(
             # chunk. Note that in most cases we only have one chunk
             # so the cache size and config are already set correctly and
             # do not need to be adjusted.
-            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[
-                : tokens_in_chunk * topk_ids.shape[1]
-            ]
+            config, (down_config, _) = get_config_func(tokens_in_chunk)
+            down_moe_use_tma = (
+                _down_moe_use_tma()
+                and down_config is not None
+                and down_config.pop("USE_TMA", False)
+            )
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
-            config = get_config_func(tokens_in_chunk)
+
+            padded_tokens = (
+                min(tokens_in_chunk * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1)
+                if down_moe_use_tma
+                else 0
+            )
+            total_tokens = tokens_in_chunk * topk + padded_tokens
+            intermediate_cache1 = cache[: total_tokens * N].view(
+                (total_tokens, N),
+            )
+            intermediate_cache2 = torch.empty(
+                (total_tokens, N // 2),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
@@ -490,6 +530,8 @@ def fused_experts_impl(
             use_int4_w4a16=use_int4_w4a16,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
+            c_sorted=down_moe_use_tma,
+            filter_expert=filter_expert,
         )
         if activation == "silu":
             if gemm1_alpha is not None:
@@ -536,7 +578,7 @@ def fused_experts_impl(
             num_tokens_post_padded,
             not apply_router_weight_on_input,
             1,
-            config,
+            down_config or config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
@@ -544,6 +586,9 @@ def fused_experts_impl(
             use_int4_w4a16=use_int4_w4a16,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
+            a_use_tma=down_moe_use_tma,
+            b_use_tma=down_moe_use_tma,
+            filter_expert=filter_expert,
         )
 
         if routed_scaling_factor is None:
@@ -569,11 +614,12 @@ def fused_experts_impl(
                     routed_scaling_factor,
                 )
             else:
-                moe_sum_reduce_triton(
+                moe_sum_reduce(
                     intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                     routed_scaling_factor,
                 )
+
         elif _is_hip:
             if _use_aiter:
                 moe_sum(
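Note: filter_expert is derived once in fused_experts and threaded through to the kernel wrappers above. A minimal sketch of that derivation with a hypothetical stand-in for the runner config (field names match the diff; the reading that filtering can be skipped only when every routed expert is local to this rank is an interpretation, not sglang documentation):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Cfg:  # hypothetical stand-in for moe_runner_config
    num_experts: Optional[int]
    num_local_experts: int

def needs_expert_filter(cfg: _Cfg) -> bool:
    # Keep filtering on unless the global and local expert counts are known to match.
    return cfg.num_experts is None or cfg.num_experts != cfg.num_local_experts

assert needs_expert_filter(_Cfg(num_experts=257, num_local_experts=64))       # expert-parallel shard
assert not needs_expert_filter(_Cfg(num_experts=257, num_local_experts=257))  # all experts local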
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import triton
 
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import get_device_name, is_hip
 
 logger = logging.getLogger(__name__)
@@ -21,6 +22,7 @@ def get_config_file_name(
     dtype: Optional[str],
     block_shape: Optional[int] = None,
     per_channel_quant: bool = False,
+    down_moe: bool = False,
 ) -> str:
     device_name = get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -28,7 +30,8 @@ def get_config_file_name(
         "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
     )
     per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
+    down_moe_selector = "_down" if down_moe else ""
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}{down_moe_selector}.json"
 
 
 @functools.lru_cache
@@ -39,6 +42,7 @@ def get_moe_configs(
     block_n: Optional[int] = 0,
     block_k: Optional[int] = 0,
     per_channel_quant: bool = False,
+    down_moe: bool = False,
 ) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -48,13 +52,23 @@ def get_moe_configs(
     kernel on a given batch size bs, the closest batch size in the grid should
     be picked and the associated configuration chosen to invoke the kernel.
     """
+    if get_global_server_args().enable_deterministic_inference:
+        logger.warning(
+            "Deterministic inference is enabled, using default MoE kernel config."
+        )
+        return None
     # Supported Triton versions, should be sorted from the newest to the oldest
     supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]
 
     # First look up if an optimized configuration is available in the configs
     # directory
     json_file_name = get_config_file_name(
-        E, N, dtype, [block_n, block_k], per_channel_quant
+        E,
+        N,
+        dtype,
+        [block_n, block_k],
+        per_channel_quant,
+        down_moe=down_moe,
     )
 
     # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
@@ -122,6 +136,14 @@ def get_default_config(
     is_marlin: bool,
     block_shape: Optional[List[int]] = None,
 ) -> Dict[str, int]:
+    if get_global_server_args().enable_deterministic_inference:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        return config
     if dtype == "fp8_w8a8":
         if block_shape is None:
             config = {
@@ -177,9 +199,12 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
     block_shape: Optional[List[int]] = None,
+    return_down_config: bool = False,
 ):
     from sglang.srt.layers.moe.fused_moe_triton import get_config
 
+    down_config = None
+    max_block_m = None
     override_config = get_config()
     if override_config:
         config = override_config
@@ -188,7 +213,7 @@ def try_get_optimal_moe_config(
         E, _, N = w2_shape
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
-        configs = get_moe_configs(E, N, dtype, block_n, block_k)
+        configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=False)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -199,6 +224,21 @@ def try_get_optimal_moe_config(
             config = get_default_config(
                 M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
             )
+        if return_down_config:
+            down_configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=True)
+            if down_configs:
+                down_config = down_configs[
+                    min(down_configs.keys(), key=lambda x: abs(x - M))
+                ]
+                down_config = dict(**down_config)
+                max_block_m = max(
+                    [cfg["BLOCK_SIZE_M"] for cfg in down_configs.values()]
+                )
+    if return_down_config:
+        assert (
+            down_config is None or config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"]
+        )
+        return config, (down_config, max_block_m)
     return config
 
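Note: with down_moe=True, get_config_file_name appends the "_down" suffix, which is how the new JSON shipped in this release is resolved. A standalone sketch that reproduces the name; unlike the real function, the device name is passed in here instead of being queried via get_device_name():

def config_file_name(E, N, device_name, dtype=None, block_shape=None,
                     per_channel_quant=False, down_moe=False) -> str:
    # Mirrors the return statement of get_config_file_name above.
    dtype_sel = "" if not dtype else f",dtype={dtype}"
    block_sel = "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    pcq_sel = ",per_channel_quant=True" if per_channel_quant else ""
    down_sel = "_down" if down_moe else ""
    return f"E={E},N={N},device_name={device_name}{dtype_sel}{block_sel}{pcq_sel}{down_sel}.json"

print(config_file_name(257, 256, "NVIDIA_H20", "fp8_w8a8", [128, 128], down_moe=True))
# E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json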