sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, Optional, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5
5
 
6
6
  import torch
7
7
 
8
- from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
9
8
  from sglang.srt.layers.moe import (
10
9
  get_deepep_mode,
11
10
  get_moe_a2a_backend,
@@ -15,13 +14,10 @@ from sglang.srt.layers.moe import (
15
14
  from sglang.srt.layers.moe.ep_moe.kernels import (
16
15
  ep_gather,
17
16
  ep_scatter,
18
- moe_ep_deepgemm_preprocess,
19
- post_reorder_triton_kernel,
20
17
  silu_and_mul_masked_post_quant_fwd,
21
18
  tma_align_input_scale,
22
19
  )
23
20
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
24
- from sglang.srt.layers.moe.topk import TopKOutput
25
21
  from sglang.srt.layers.quantization import deep_gemm_wrapper
26
22
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
27
23
  from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -29,13 +25,17 @@ from sglang.srt.layers.quantization.fp8_kernel import (
29
25
  is_fp8_fnuz,
30
26
  sglang_per_token_group_quant_fp8,
31
27
  )
32
- from sglang.srt.managers.schedule_batch import global_server_args_dict
28
+ from sglang.srt.layers.quantization.modelopt_quant import (
29
+ CUTEDSL_MOE_NVFP4_DISPATCH,
30
+ ModelOptNvFp4FusedMoEMethod,
31
+ )
33
32
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
33
+ from sglang.srt.offloader import get_offloader
34
+ from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
34
35
  from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
35
36
 
36
37
  if TYPE_CHECKING:
37
38
  from sglang.srt.layers.moe.token_dispatcher import (
38
- AscendDeepEPLLOutput,
39
39
  DeepEPLLOutput,
40
40
  DeepEPNormalOutput,
41
41
  DispatchOutput,
@@ -56,29 +56,13 @@ if _use_aiter:
56
56
  logger = logging.getLogger(__name__)
57
57
 
58
58
 
59
- # TODO(kaixih@nvidia): ideally we should merge this logic into
60
- # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
61
- @torch.compile
62
- def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
63
- temp = x.to(torch.float32).view(torch.int32)
64
- exp = torch.bitwise_right_shift(temp, 23)
65
- mant = torch.bitwise_and(temp, 0x7FFFFF)
66
- is_ru = torch.logical_and(
67
- torch.logical_and((mant > 0), (exp != 0xFE)),
68
- ~torch.logical_and((exp == 0), (mant <= 0x400000)),
69
- )
70
- exp = torch.where(is_ru, exp + 1, exp)
71
- new_x = exp.to(torch.uint8).view(torch.int)
72
- return new_x.transpose(1, 2).contiguous().transpose(1, 2)
73
-
74
-
75
- class EPMoE(FusedMoE):
59
+ class DeepEPMoE(FusedMoE):
76
60
  """
77
- MoE Expert Parallel Impl
78
-
79
-
61
+ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
80
62
  """
81
63
 
64
+ _has_printed = False
65
+
82
66
  def __init__(
83
67
  self,
84
68
  num_experts: int,
@@ -92,272 +76,29 @@ class EPMoE(FusedMoE):
92
76
  prefix: str = "",
93
77
  activation: str = "silu",
94
78
  routed_scaling_factor: Optional[float] = None,
95
- gemm1_alpha: Optional[float] = None,
96
- gemm1_clamp_limit: Optional[float] = None,
97
- with_bias: bool = False,
98
79
  ):
99
80
  super().__init__(
100
81
  num_experts=num_experts,
82
+ top_k=top_k,
101
83
  hidden_size=hidden_size,
102
84
  intermediate_size=intermediate_size,
103
- num_fused_shared_experts=num_fused_shared_experts,
104
85
  layer_id=layer_id,
105
- top_k=top_k,
86
+ num_fused_shared_experts=num_fused_shared_experts,
106
87
  params_dtype=params_dtype,
107
88
  quant_config=quant_config,
108
89
  prefix=prefix,
109
90
  activation=activation,
110
- # apply_router_weight_on_input=apply_router_weight_on_input,
111
91
  routed_scaling_factor=routed_scaling_factor,
112
- gemm1_alpha=gemm1_alpha,
113
- gemm1_clamp_limit=gemm1_clamp_limit,
114
- with_bias=with_bias,
115
92
  )
116
93
 
117
- self.intermediate_size = intermediate_size
118
-
119
94
  if isinstance(quant_config, Fp8Config):
120
95
  self.use_block_quant = getattr(self.quant_method, "block_quant", False)
121
- self.block_shape = (
122
- self.quant_method.quant_config.weight_block_size
123
- if self.use_block_quant
124
- else None
125
- )
126
96
  self.use_fp8_w8a8 = True
127
97
  self.fp8_dtype = torch.float8_e4m3fn
128
- self.activation_scheme = quant_config.activation_scheme
129
98
  else:
130
99
  self.use_fp8_w8a8 = False
131
100
  self.use_block_quant = False
132
- self.block_shape = None
133
- self.activation_scheme = None
134
-
135
- def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
136
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
137
- return self.forward_deepgemm(hidden_states, topk_output)
138
- else:
139
- return super().forward(hidden_states, topk_output)
140
-
141
- def forward_deepgemm(
142
- self,
143
- hidden_states: torch.Tensor,
144
- topk_output: TopKOutput,
145
- ):
146
101
 
147
- self.w13_weight_fp8 = (
148
- self.w13_weight,
149
- (
150
- self.w13_weight_scale_inv
151
- if self.use_block_quant
152
- else self.w13_weight_scale
153
- ),
154
- )
155
- self.w2_weight_fp8 = (
156
- self.w2_weight,
157
- self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
158
- )
159
-
160
- assert self.quant_method is not None
161
- assert self.moe_runner_config.activation == "silu"
162
-
163
- hidden_states_shape = hidden_states.shape
164
- hidden_states_dtype = hidden_states.dtype
165
- hidden_states_device = hidden_states.device
166
-
167
- topk_weights, topk_ids, _ = topk_output
168
-
169
- if not self.use_block_quant:
170
- # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
171
- scale_block_size = 128
172
- w13_weight_scale_n = 2 * (
173
- (self.intermediate_size + scale_block_size - 1) // scale_block_size
174
- )
175
- w13_weight_scale_k = (
176
- hidden_states_shape[-1] + scale_block_size - 1
177
- ) // scale_block_size
178
- w13_weight_scale = (
179
- self.w13_weight_scale.unsqueeze(1)
180
- .repeat_interleave(w13_weight_scale_n, dim=1)
181
- .unsqueeze(2)
182
- .repeat_interleave(w13_weight_scale_k, dim=2)
183
- )
184
- self.w13_weight_fp8 = (
185
- self.w13_weight,
186
- w13_weight_scale,
187
- )
188
- w2_weight_scale_n = (
189
- hidden_states_shape[-1] + scale_block_size - 1
190
- ) // scale_block_size
191
- w2_weight_scale_k = (
192
- self.intermediate_size + scale_block_size - 1
193
- ) // scale_block_size
194
- w2_weight_scale = (
195
- self.w2_weight_scale.unsqueeze(1)
196
- .repeat_interleave(w2_weight_scale_n, dim=1)
197
- .unsqueeze(2)
198
- .repeat_interleave(w2_weight_scale_k, dim=2)
199
- )
200
- self.w2_weight_fp8 = (
201
- self.w2_weight,
202
- w2_weight_scale,
203
- )
204
-
205
- # PreReorder
206
- m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
207
- moe_ep_deepgemm_preprocess(
208
- topk_ids,
209
- self.num_experts,
210
- hidden_states,
211
- self.top_k,
212
- self.start_expert_id,
213
- self.end_expert_id,
214
- self.block_shape,
215
- )
216
- )
217
-
218
- dispose_tensor(hidden_states)
219
-
220
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
221
- b, s_mn, s_k = gateup_input_scale.shape
222
- assert (
223
- s_mn % 4 == 0 and s_k % 4 == 0
224
- ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
225
-
226
- # GroupGemm-0
227
- gateup_input_fp8 = (
228
- gateup_input,
229
- (
230
- _cast_to_e8m0_with_rounding_up(gateup_input_scale)
231
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
232
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
233
- gateup_input_scale
234
- )
235
- ),
236
- )
237
- num_groups, m, k = gateup_input_fp8[0].size()
238
- n = self.w13_weight.size(1)
239
- gateup_output = torch.empty(
240
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
241
- )
242
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
243
- gateup_input_fp8,
244
- self.w13_weight_fp8,
245
- gateup_output,
246
- masked_m,
247
- expected_m,
248
- )
249
- del gateup_input
250
- del gateup_input_fp8
251
-
252
- # Act
253
- down_input = torch.empty(
254
- (
255
- gateup_output.shape[0],
256
- gateup_output.shape[1],
257
- gateup_output.shape[2] // 2,
258
- ),
259
- device=hidden_states_device,
260
- dtype=self.fp8_dtype,
261
- )
262
- scale_block_size = 128
263
- down_input_scale = torch.empty(
264
- (
265
- gateup_output.shape[0],
266
- gateup_output.shape[1],
267
- gateup_output.shape[2] // 2 // scale_block_size,
268
- ),
269
- device=hidden_states_device,
270
- dtype=torch.float32,
271
- )
272
- silu_and_mul_masked_post_quant_fwd(
273
- gateup_output,
274
- down_input,
275
- down_input_scale,
276
- scale_block_size,
277
- masked_m,
278
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
279
- )
280
- del gateup_output
281
-
282
- # GroupGemm-1
283
- n = self.w2_weight.size(1)
284
- down_input_fp8 = (
285
- down_input,
286
- (
287
- down_input_scale
288
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
289
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
290
- ),
291
- )
292
- down_output = torch.empty(
293
- (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
294
- )
295
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
296
- down_input_fp8,
297
- self.w2_weight_fp8,
298
- down_output,
299
- masked_m,
300
- expected_m,
301
- )
302
- del down_input
303
- del down_input_fp8
304
-
305
- # PostReorder
306
- output = torch.empty(
307
- hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
308
- )
309
- post_reorder_triton_kernel[(hidden_states_shape[0],)](
310
- down_output,
311
- output,
312
- src2dst,
313
- topk_ids,
314
- topk_weights,
315
- self.start_expert_id,
316
- self.end_expert_id,
317
- self.top_k,
318
- hidden_states_shape[1],
319
- m_max * self.start_expert_id,
320
- BLOCK_SIZE=512,
321
- )
322
- if self.moe_runner_config.routed_scaling_factor is not None:
323
- output *= self.moe_runner_config.routed_scaling_factor
324
- return output
325
-
326
-
327
- class DeepEPMoE(EPMoE):
328
- """
329
- MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
330
- """
331
-
332
- _has_printed = False
333
-
334
- def __init__(
335
- self,
336
- num_experts: int,
337
- top_k: int,
338
- hidden_size: int,
339
- intermediate_size: int,
340
- layer_id: int,
341
- num_fused_shared_experts: int = 0,
342
- params_dtype: Optional[torch.dtype] = None,
343
- quant_config: Optional[QuantizationConfig] = None,
344
- prefix: str = "",
345
- activation: str = "silu",
346
- routed_scaling_factor: Optional[float] = None,
347
- ):
348
- super().__init__(
349
- num_experts=num_experts,
350
- top_k=top_k,
351
- hidden_size=hidden_size,
352
- intermediate_size=intermediate_size,
353
- layer_id=layer_id,
354
- num_fused_shared_experts=num_fused_shared_experts,
355
- params_dtype=params_dtype,
356
- quant_config=quant_config,
357
- prefix=prefix,
358
- activation=activation,
359
- routed_scaling_factor=routed_scaling_factor,
360
- )
361
102
  self.deepep_mode = get_deepep_mode()
362
103
 
363
104
  # TODO: move to the beginning of the file
@@ -444,9 +185,20 @@ class DeepEPMoE(EPMoE):
444
185
  topk_idx=topk_idx,
445
186
  topk_weights=topk_weights,
446
187
  forward_batch=forward_batch,
188
+ input_global_scale=(
189
+ self.w13_input_scale_quant
190
+ if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
191
+ and self.quant_method.enable_flashinfer_cutedsl_moe
192
+ and CUTEDSL_MOE_NVFP4_DISPATCH
193
+ else None
194
+ ),
447
195
  )
448
196
 
449
- def moe_impl(self, dispatch_output: DispatchOutput):
197
+ def moe_impl(
198
+ self,
199
+ dispatch_output: DispatchOutput,
200
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
201
+ ):
450
202
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
451
203
 
452
204
  if _use_aiter:
@@ -454,12 +206,16 @@ class DeepEPMoE(EPMoE):
454
206
  # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
455
207
  return self.forward_aiter(dispatch_output)
456
208
  if _is_npu:
457
- assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
209
+ assert DispatchOutputChecker.format_is_deepep(dispatch_output)
458
210
  return self.forward_npu(dispatch_output)
459
211
  if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
460
212
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
461
213
  return self.forward_deepgemm_contiguous(dispatch_output)
462
214
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
215
+ if get_moe_runner_backend().is_flashinfer_cutedsl():
216
+ return self.forward_flashinfer_cutedsl(
217
+ dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
218
+ )
463
219
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
464
220
  return self.forward_deepgemm_masked(dispatch_output)
465
221
  else:
@@ -473,12 +229,14 @@ class DeepEPMoE(EPMoE):
473
229
  topk_idx: torch.Tensor,
474
230
  topk_weights: torch.Tensor,
475
231
  forward_batch: ForwardBatch,
232
+ overlap_args: Optional[Dict[str, Any]] = None,
476
233
  ):
477
234
  return self.deepep_dispatcher.combine(
478
235
  hidden_states=hidden_states,
479
236
  topk_idx=topk_idx,
480
237
  topk_weights=topk_weights,
481
238
  forward_batch=forward_batch,
239
+ overlap_args=overlap_args,
482
240
  )
483
241
 
484
242
  def forward_aiter(
@@ -534,6 +292,23 @@ class DeepEPMoE(EPMoE):
534
292
  N = self.w13_weight.size(1)
535
293
  scale_block_size = 128
536
294
 
295
+ w13_weight_fp8 = (
296
+ self.w13_weight,
297
+ (
298
+ self.w13_weight_scale_inv
299
+ if self.use_block_quant
300
+ else self.w13_weight_scale
301
+ ),
302
+ )
303
+ w2_weight_fp8 = (
304
+ self.w2_weight,
305
+ (
306
+ self.w2_weight_scale_inv
307
+ if self.use_block_quant
308
+ else self.w2_weight_scale
309
+ ),
310
+ )
311
+
537
312
  hidden_states_fp8_shape = hidden_states_fp8.shape
538
313
  hidden_states_fp8_device = hidden_states_fp8.device
539
314
  hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -564,12 +339,17 @@ class DeepEPMoE(EPMoE):
564
339
  )
565
340
  output_index = torch.empty_like(topk_idx)
566
341
 
567
- num_recv_tokens_per_expert_gpu = torch.tensor(
568
- num_recv_tokens_per_expert,
569
- dtype=torch.int32,
570
- pin_memory=True,
571
- device="cpu",
572
- ).cuda(non_blocking=True)
342
+ if get_offloader().forbid_copy_engine_usage:
343
+ num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
344
+ num_recv_tokens_per_expert
345
+ )
346
+ else:
347
+ num_recv_tokens_per_expert_gpu = torch.tensor(
348
+ num_recv_tokens_per_expert,
349
+ dtype=torch.int32,
350
+ pin_memory=True,
351
+ device="cpu",
352
+ ).cuda(non_blocking=True)
573
353
  expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
574
354
 
575
355
  ep_scatter(
@@ -594,7 +374,7 @@ class DeepEPMoE(EPMoE):
594
374
  if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
595
375
  input_tensor[1] = tma_align_input_scale(input_tensor[1])
596
376
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
597
- input_tensor, self.w13_weight_fp8, gateup_output, m_indices
377
+ input_tensor, w13_weight_fp8, gateup_output, m_indices
598
378
  )
599
379
  del input_tensor
600
380
  down_input = torch.empty(
@@ -624,7 +404,7 @@ class DeepEPMoE(EPMoE):
624
404
  down_input_scale = tma_align_input_scale(down_input_scale)
625
405
  deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
626
406
  (down_input_fp8, down_input_scale),
627
- self.w2_weight_fp8,
407
+ w2_weight_fp8,
628
408
  down_output,
629
409
  m_indices,
630
410
  )
@@ -639,6 +419,24 @@ class DeepEPMoE(EPMoE):
639
419
 
640
420
  return gather_out
641
421
 
422
+ def forward_flashinfer_cutedsl(
423
+ self,
424
+ dispatch_output: DeepEPLLOutput,
425
+ down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
426
+ ):
427
+ hidden_states, _, _, masked_m, _ = dispatch_output
428
+ assert self.quant_method is not None
429
+ assert self.moe_runner_config.activation == "silu"
430
+
431
+ output = self.quant_method.apply_without_routing_weights(
432
+ layer=self,
433
+ x=hidden_states,
434
+ masked_m=masked_m,
435
+ moe_runner_config=self.moe_runner_config,
436
+ down_gemm_overlap_args=down_gemm_overlap_args,
437
+ )
438
+ return output
439
+
642
440
  def forward_deepgemm_masked(
643
441
  self,
644
442
  dispatch_output: DeepEPLLOutput,
@@ -718,66 +516,176 @@ class DeepEPMoE(EPMoE):
718
516
 
719
517
  def forward_npu(
720
518
  self,
721
- dispatch_output: DeepEPLLOutput,
519
+ dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
722
520
  ):
723
- if TYPE_CHECKING:
724
- assert isinstance(dispatch_output, AscendDeepEPLLOutput)
725
- hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
726
521
  assert self.quant_method is not None
727
522
  assert self.moe_runner_config.activation == "silu"
728
523
 
524
+ import torch_npu
525
+
526
+ from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
527
+
729
528
  # NOTE: Ascend's Dispatch & Combine does not support FP16
730
529
  output_dtype = torch.bfloat16
530
+ group_list_type = 1
731
531
 
732
- pertoken_scale = hidden_states[1]
733
- hidden_states = hidden_states[0]
532
+ def _forward_normal(dispatch_output: DeepEPNormalOutput):
533
+ if TYPE_CHECKING:
534
+ assert isinstance(dispatch_output, DeepEPNormalOutput)
535
+ hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
734
536
 
735
- group_list_type = 1
736
- seg_indptr = seg_indptr.to(torch.int64)
537
+ if isinstance(hidden_states, tuple):
538
+ per_token_scale = hidden_states[1]
539
+ hidden_states = hidden_states[0]
737
540
 
738
- import torch_npu
541
+ group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
542
+ hidden_states.device
543
+ )
544
+ if self.w13_weight.dtype != torch.int8:
545
+ # gmm1: gate_up_proj
546
+ hidden_states = torch_npu.npu_grouped_matmul(
547
+ x=[hidden_states],
548
+ weight=[self.w13_weight.permute(0, 2, 1)],
549
+ # per_token_scale=[per_token_scale],
550
+ split_item=2,
551
+ group_list_type=group_list_type,
552
+ group_type=0,
553
+ group_list=group_list,
554
+ output_dtype=output_dtype,
555
+ )[0]
556
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
557
+ # gmm2: down_proj
558
+ hidden_states = torch_npu.npu_grouped_matmul(
559
+ x=[hidden_states],
560
+ weight=[self.w2_weight.permute(0, 2, 1)],
561
+ split_item=2,
562
+ group_list_type=group_list_type,
563
+ group_type=0,
564
+ group_list=group_list,
565
+ output_dtype=output_dtype,
566
+ )[0]
567
+ else:
568
+ if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
569
+ hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
570
+ hidden_states
571
+ )
572
+ # gmm1: gate_up_proj
573
+ hidden_states = torch_npu.npu_grouped_matmul(
574
+ x=[hidden_states],
575
+ weight=[self.w13_weight],
576
+ scale=[self.w13_weight_scale.to(output_dtype)],
577
+ per_token_scale=[per_token_scale],
578
+ split_item=2,
579
+ group_list_type=group_list_type,
580
+ group_type=0,
581
+ group_list=group_list,
582
+ output_dtype=output_dtype,
583
+ )[0]
584
+
585
+ # act_fn: swiglu
586
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
587
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
588
+ hidden_states
589
+ )
739
590
 
740
- # gmm1: gate_up_proj
741
- hidden_states = torch_npu.npu_grouped_matmul(
742
- x=[hidden_states],
743
- weight=[self.w13_weight],
744
- split_item=2,
745
- group_list_type=group_list_type,
746
- group_type=0,
747
- group_list=seg_indptr,
748
- output_dtype=torch.int32,
749
- )[0]
750
-
751
- # act_fn: swiglu
752
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
753
- x=hidden_states,
754
- weight_scale=self.w13_weight_scale.to(torch.float32),
755
- activation_scale=pertoken_scale,
756
- bias=None,
757
- quant_scale=None,
758
- quant_offset=None,
759
- group_index=seg_indptr,
760
- activate_left=True,
761
- quant_mode=1,
762
- )
591
+ # gmm2: down_proj
592
+ hidden_states = torch_npu.npu_grouped_matmul(
593
+ x=[hidden_states],
594
+ weight=[self.w2_weight],
595
+ scale=[self.w2_weight_scale.to(output_dtype)],
596
+ per_token_scale=[swiglu_out_scale],
597
+ split_item=2,
598
+ group_list_type=group_list_type,
599
+ group_type=0,
600
+ group_list=group_list,
601
+ output_dtype=output_dtype,
602
+ )[0]
763
603
 
764
- # gmm2: down_proj
765
- hidden_states = torch_npu.npu_grouped_matmul(
766
- x=[hidden_states],
767
- weight=[self.w2_weight],
768
- scale=[self.w2_weight_scale.to(output_dtype)],
769
- per_token_scale=[swiglu_out_scale],
770
- split_item=2,
771
- group_list_type=group_list_type,
772
- group_type=0,
773
- group_list=seg_indptr,
774
- output_dtype=output_dtype,
775
- )[0]
604
+ return hidden_states
776
605
 
777
- return hidden_states
606
+ def _forward_ll(dispatch_output: DeepEPLLOutput):
607
+ if TYPE_CHECKING:
608
+ assert isinstance(dispatch_output, DeepEPLLOutput)
609
+ hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
610
+
611
+ if isinstance(hidden_states, tuple):
612
+ per_token_scale = hidden_states[1]
613
+ hidden_states = hidden_states[0]
614
+
615
+ group_list = group_list.to(torch.int64)
616
+
617
+ if self.w13_weight.dtype != torch.int8:
618
+ # gmm1: gate_up_proj
619
+ hidden_states = torch_npu.npu_grouped_matmul(
620
+ x=[hidden_states],
621
+ weight=[self.w13_weight.permute(0, 2, 1)],
622
+ # per_token_scale=[per_token_scale],
623
+ split_item=2,
624
+ group_list_type=group_list_type,
625
+ group_type=0,
626
+ group_list=group_list,
627
+ output_dtype=output_dtype,
628
+ )[0]
629
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
630
+ # gmm2: down_proj
631
+ hidden_states = torch_npu.npu_grouped_matmul(
632
+ x=[hidden_states],
633
+ weight=[self.w2_weight.permute(0, 2, 1)],
634
+ split_item=2,
635
+ group_list_type=group_list_type,
636
+ group_type=0,
637
+ group_list=group_list,
638
+ output_dtype=output_dtype,
639
+ )[0]
640
+ else:
641
+ # gmm1: gate_up_proj
642
+ hidden_states = torch_npu.npu_grouped_matmul(
643
+ x=[hidden_states],
644
+ weight=[self.w13_weight],
645
+ split_item=2,
646
+ group_list_type=group_list_type,
647
+ group_type=0,
648
+ group_list=group_list,
649
+ output_dtype=torch.int32,
650
+ )[0]
651
+
652
+ # act_fn: swiglu
653
+ hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
654
+ x=hidden_states,
655
+ weight_scale=self.w13_weight_scale.to(torch.float32),
656
+ activation_scale=per_token_scale,
657
+ bias=None,
658
+ quant_scale=None,
659
+ quant_offset=None,
660
+ group_index=group_list,
661
+ activate_left=True,
662
+ quant_mode=1,
663
+ )
778
664
 
665
+ # gmm2: down_proj
666
+ hidden_states = torch_npu.npu_grouped_matmul(
667
+ x=[hidden_states],
668
+ weight=[self.w2_weight],
669
+ scale=[self.w2_weight_scale.to(output_dtype)],
670
+ per_token_scale=[swiglu_out_scale],
671
+ split_item=2,
672
+ group_list_type=group_list_type,
673
+ group_type=0,
674
+ group_list=group_list,
675
+ output_dtype=output_dtype,
676
+ )[0]
779
677
 
780
- def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
678
+ return hidden_states
679
+
680
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
681
+ return _forward_normal(dispatch_output)
682
+ elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
683
+ return _forward_ll(dispatch_output)
684
+ else:
685
+ raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
686
+
687
+
688
+ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
781
689
  if get_moe_a2a_backend().is_deepep():
782
690
  return DeepEPMoE
783
691
 
@@ -790,8 +698,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
790
698
  return FusedMoE
791
699
  try:
792
700
  # Check the quantization argument directly
793
- quantization = global_server_args_dict.get("quantization")
794
- if quantization == "modelopt_fp4":
701
+ if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
795
702
  from sglang.srt.layers.moe.fused_moe_triton.layer import (
796
703
  FlashInferFP4MoE,
797
704
  )
@@ -800,10 +707,18 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
800
707
  except:
801
708
  pass
802
709
 
803
- if should_use_flashinfer_trtllm_moe():
710
+ if should_use_flashinfer_trtllm_moe() and quant_config is not None:
711
+ # FIXME: FlashInferFusedMoE only supports fp8 quant now
804
712
  return FlashInferFusedMoE
805
713
  if get_moe_runner_backend().is_flashinfer_cutlass():
806
714
  return FusedMoE
807
- if get_moe_expert_parallel_world_size() > 1:
808
- return EPMoE
809
715
  return FusedMoE
716
+
717
+
718
+ def copy_list_to_gpu_no_ce(arr: List[int]):
719
+ from sgl_kernel.elementwise import copy_to_gpu_no_ce
720
+
721
+ tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
722
+ tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
723
+ copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
724
+ return tensor_gpu