sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -46,10 +49,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
+_is_npu = is_npu()
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -67,7 +72,10 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
     # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
-    next_token_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
+    next_token_token_ids_logprobs_val: Optional[
+        List[Union[List[float], torch.Tensor]]
+    ] = None
     next_token_token_ids_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -180,10 +188,13 @@ class LogitsMetadata:
             )
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
-        dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
 
         self.dp_local_start_pos = dp_local_start_pos
-        self.dp_local_num_tokens = dp_local_num_tokens
+        self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
+
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
 
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
@@ -191,10 +202,13 @@
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
-            self.global_dp_buffer_len,
-            self.dp_local_num_tokens,
-            self.global_num_tokens_for_logprob_cpu,
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
         )
 
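A minimal sketch of the slice computation above (illustration only, not code from the package; values are hypothetical): each DP rank's window into the gathered buffer starts at the cumulative token count of the preceding ranks.

import torch

# tokens needing logprobs on each DP rank (hypothetical)
global_num_tokens_for_logprob = torch.tensor([3, 5, 2, 4])
cumtokens = torch.cumsum(global_num_tokens_for_logprob, dim=0)  # [3, 8, 10, 14]

dp_rank = 2
dp_local_start_pos = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])  # 8
dp_local_num_tokens = int(global_num_tokens_for_logprob[dp_rank])  # 2
# rank 2 owns rows [8, 10) of the (global_dp_buffer_len, hidden_size)
# `gathered_buffer` allocated in the hunk above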
@@ -441,7 +455,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
@@ -517,7 +531,12 @@ class LogitsProcessor(nn.Module):
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
-            fused_softcap(logits, self.final_logit_softcapping)
+            if not _is_npu:
+                fused_softcap(logits, self.final_logit_softcapping)
+            else:
+                logits = self.final_logit_softcapping * torch.tanh(
+                    logits / self.final_logit_softcapping
+                )
 
         return logits
 
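The NPU branch added above is the eager form of the softcap that fused_softcap computes in a single kernel: cap * tanh(logits / cap), which smoothly bounds logits to (-cap, cap). A standalone sketch (the cap value is hypothetical):

import torch

logits = torch.randn(4, 8)
cap = 30.0  # final_logit_softcapping
softcapped = cap * torch.tanh(logits / cap)  # every entry now lies in (-cap, cap)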
sglang/srt/layers/moe/__init__.py
@@ -1,4 +1,4 @@
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
 __all__ = [
     "DeepEPMode",
     "MoeA2ABackend",
+    "MoeRunner",
     "MoeRunnerConfig",
     "MoeRunnerBackend",
     "initialize_moe_config",
sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )
 
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
 
     cutlass_w4a8_moe_mm(
         c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )
 
-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)
 
     intermediate_q = torch.empty(
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
     zero_expert_scales[zero_expert_mask] = 0.0
 
     normal_expert_mask = expert_indices >= num_experts
-    expert_indices[normal_expert_mask] = 0
+    expert_indices[normal_expert_mask] = -1
     expert_scales[normal_expert_mask] = 0.0
 
     output = torch.zeros_like(hidden_states).to(hidden_states.device)
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
@@ -31,11 +33,18 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.offloader import get_offloader
+from sglang.srt.utils import (
+    ceil_div,
+    dispose_tensor,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -454,12 +463,14 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if get_moe_runner_backend().is_flashinfer_cutedsl():
+                return self.forward_flashinfer_cutedsl(dispatch_output)
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -534,6 +545,24 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
+        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
+        w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            self.w2_weight,
+            (
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+        )
+
         hidden_states_fp8_shape = hidden_states_fp8.shape
         hidden_states_fp8_device = hidden_states_fp8.device
         hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -564,12 +593,17 @@ class DeepEPMoE(EPMoE):
         )
         output_index = torch.empty_like(topk_idx)
 
-        num_recv_tokens_per_expert_gpu = torch.tensor(
-            num_recv_tokens_per_expert,
-            dtype=torch.int32,
-            pin_memory=True,
-            device="cpu",
-        ).cuda(non_blocking=True)
+        if get_offloader().forbid_copy_engine_usage:
+            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+                num_recv_tokens_per_expert
+            )
+        else:
+            num_recv_tokens_per_expert_gpu = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int32,
+                pin_memory=True,
+                device="cpu",
+            ).cuda(non_blocking=True)
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
 
         ep_scatter(
@@ -594,7 +628,7 @@ class DeepEPMoE(EPMoE):
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+            input_tensor, w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
         down_input = torch.empty(
@@ -624,7 +658,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
-            self.w2_weight_fp8,
+            w2_weight_fp8,
             down_output,
             m_indices,
         )
@@ -639,6 +673,22 @@ class DeepEPMoE(EPMoE):
 
         return gather_out
 
+    def forward_flashinfer_cutedsl(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, masked_m, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        output = self.quant_method.apply_without_routing_weights(
+            layer=self,
+            x=hidden_states,
+            masked_m=masked_m,
+            moe_runner_config=self.moe_runner_config,
+        )
+        return output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
@@ -718,66 +768,127 @@
 
     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1
 
-        pertoken_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )
 
-        group_list_type = 1
-        seg_indptr = seg_indptr.to(torch.int64)
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )
 
-        import torch_npu
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
 
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=torch.int32,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=seg_indptr,
-            activate_left=True,
-            quant_mode=1,
-        )
-
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+            return hidden_states
 
-        return hidden_states
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
 
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
 
-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
+
+
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
 
@@ -790,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
            from sglang.srt.layers.moe.fused_moe_triton.layer import (
                FlashInferFP4MoE,
            )
@@ -800,10 +910,20 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
    except:
        pass
 
-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
        return FlashInferFusedMoE
    if get_moe_runner_backend().is_flashinfer_cutlass():
        return FusedMoE
    if get_moe_expert_parallel_world_size() > 1:
        return EPMoE
    return FusedMoE
+
+
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (new file)
@@ -0,0 +1,156 @@
+from typing import Any, Dict, Optional
+
+import torch
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+from sgl_kernel.gemm import (
+    scaled_fp4_grouped_quant,
+    silu_and_mul_scaled_fp4_grouped_quant,
+)
+
+
+def get_cute_dtype(input: torch.Tensor) -> str:
+    if input.dtype == torch.bfloat16:
+        return "bfloat16"
+    elif input.dtype == torch.float16:
+        return "float16"
+    elif input.dtype == torch.float32:
+        return "float32"
+    else:
+        raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+
+def flashinfer_cutedsl_moe_masked(
+    hidden_states: torch.Tensor,
+    input_global_scale: torch.Tensor,
+    w1: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w1_alpha,
+    w2: torch.Tensor,
+    a2_global_scale: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    w2_alpha,
+    masked_m: torch.Tensor,
+):
+    """
+    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
+    kernels.
+
+    Args:
+        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        input_global_scale (torch.Tensor): (l,)
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): Masked dimension indices
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert (
+        input_global_scale.dtype == torch.float32
+    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
+    assert (
+        w1_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    assert (
+        w1_alpha.dtype == torch.float32
+    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
+    assert (
+        a2_global_scale.dtype == torch.float32
+    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    assert (
+        w2_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    assert (
+        w2_alpha.dtype == torch.float32
+    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    num_experts, m, k = hidden_states.shape
+
+    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+    assert (
+        w1.shape[-1] * 2 == k
+    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+    assert w2.shape[-2:] == (
+        k,
+        n // 2,
+    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
+
+    assert input_global_scale.shape == (
+        num_experts,
+    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+    assert w1_alpha.shape == (
+        num_experts,
+    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
+    assert a2_global_scale.shape == (
+        num_experts,
+    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+    assert w2_alpha.shape == (
+        num_experts,
+    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
+
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m,
+    )
+    gateup_output = torch.empty(
+        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+    )
+    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    assert aq_sf.dtype == torch.float8_e4m3fn
+    assert aq.dtype == torch.uint8
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+
+    c_dtype = get_cute_dtype(hidden_states)
+
+    # Gemm1
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (w1.permute(1, 2, 0), w1_blockscale),
+        gateup_output,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w1_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w1_alpha),
+    )  # in logical [m, n, l]
+
+    # SILU and quantization
+    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+        gateup_output.permute(2, 0, 1),
+        a2_global_scale,
+        masked_m,
+    )
+
+    # Gemm2
+    out = torch.empty_like(hidden_states)
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    grouped_gemm_nt_masked(
+        (diq, diq_sf),
+        (w2.permute(1, 2, 0), w2_blockscale),
+        out,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w2_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w2_alpha),
+    )  # in logical [m, k, l]
+    return out.permute(2, 0, 1)
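The "masked" contract of grouped_gemm_nt_masked above: for expert i, only the first masked_m[i] of the m token rows are valid. A pure-PyTorch reference sketch of that semantics (illustration only, not the CuteDSL kernel; shapes are hypothetical and the fp4 quantization steps are omitted):

import torch

l, m, k, n = 4, 16, 32, 64  # experts, max tokens per expert, hidden, out dim
x = torch.randn(l, m, k)  # stand-in for per-expert activations
w = torch.randn(l, n, k)  # stand-in for per-expert weights
masked_m = torch.tensor([16, 3, 0, 9])  # valid rows per expert

out = torch.zeros(l, m, n)
for i in range(l):
    rows = int(masked_m[i])
    out[i, :rows] = x[i, :rows] @ w[i].T  # rows past masked_m[i] are never touched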
sglang/srt/layers/moe/fused_moe_native.py
@@ -8,16 +8,18 @@ from torch.nn import functional as F
 
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 def fused_moe_forward_native(
     layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_output: StandardTopKOutput,
-    moe_runner_config: MoeRunnerConfig,
+    dispatch_output: StandardDispatchOutput,
 ) -> torch.Tensor:
 
+    x, topk_output = dispatch_output
+    moe_runner_config = layer.moe_runner_config
+
     if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
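Call-site implication of the signature change above (a sketch inferred from this hunk alone; caller code is not shown in the diff): the native path now takes a single StandardDispatchOutput, which unpacks to (x, topk_output), and the runner config is read off the layer instead of being passed in.

out = fused_moe_forward_native(layer, dispatch_output)
# previously: fused_moe_forward_native(layer, x, topk_output, moe_runner_config)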