sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/base.py
@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
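
The new base.py splits MoE execution into three stages keyed by plain string names: a pre-permute hook converts a dispatcher's output into a runner input, the runner core executes the expert GEMMs, and a post-permute hook converts the runner output back into a combine input; a fused function registered for an (A2A backend, runner backend) pair replaces all three stages at once. Below is a minimal sketch of the decorator API. The key names "my_dispatch", "my_runner", and "my_combine" and both function bodies are hypothetical placeholders, not sglang code (and per the TODO in the source, the permute pools do not yet validate key names).

from sglang.srt.layers.moe.moe_runner.base import (
    register_post_permute,
    register_pre_permute,
)


# Hypothetical format/backend names; real registrations use values such as
# "standard" and "deep_gemm" (see the deep_gemm.py diff below).
@register_pre_permute("my_dispatch", "my_runner")
def my_pre_permute(dispatch_output, quant_info, runner_config, running_state):
    # Reorder the dispatcher's output into the runner core's input layout and
    # stash whatever the post-permute step will need in running_state.
    running_state["example"] = None
    raise NotImplementedError("placeholder body")


@register_post_permute("my_runner", "my_combine")
def my_post_permute(runner_output, quant_info, runner_config, running_state):
    # Scatter the runner core's output back into the dispatcher's combine
    # layout, reading what pre-permute saved in running_state.
    raise NotImplementedError("placeholder body")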
sglang/srt/layers/moe/moe_runner/deep_gemm.py
@@ -0,0 +1,304 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import dispose_tensor
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+@dataclass
+class DeepGemmRunnerInput(RunnerInput):
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+    use_masked_gemm: bool
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmRunnerOutput(RunnerOutput):
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    use_fp8: bool
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class DeepGemmRunnerCore(MoeRunnerCore):
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+        assert self.config.activation == "silu"
+
+    def run(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> DeepGemmRunnerOutput:
+
+        if runner_input.use_masked_gemm:
+            hidden_states = self._run_masked_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        else:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+    def _run_masked_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            silu_and_mul_masked_post_quant_fwd,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        masked_m = runner_input.masked_m
+        expected_m = runner_input.expected_m
+
+        w13_weight = quant_info.w13_weight
+        w2_weight = quant_info.w2_weight
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+
+        hidden_states_device = running_state["hidden_states_device"]
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = hidden_states_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+        else:
+            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                hidden_states_scale
+            )
+
+        num_groups, m, k = hidden_states.shape
+        n = w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (hidden_states, hidden_states_scale),
+            (w13_weight, w13_scale),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        dispose_tensor(hidden_states)
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float8_e4m3fn,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = w2_weight.shape[1]
+
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                down_input_scale
+            )
+
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (down_input, down_input_scale),
+            (w2_weight, w2_scale),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+
+        return down_output
+
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        pass
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@register_pre_permute("standard", "deep_gemm")
+def pre_permute_standard_to_deep_gemm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+    hidden_states, topk_output = dispatch_output
+    topk_weights, topk_ids, _ = topk_output
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_dtype = hidden_states.dtype
+    hidden_states_device = hidden_states.device
+    hidden_states_ref = hidden_states
+
+    topk_weights, topk_ids = topk_weights, topk_ids
+
+    # PreReorder
+    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+        moe_ep_deepgemm_preprocess(
+            topk_ids,
+            runner_config.num_local_experts,
+            hidden_states,
+            runner_config.top_k,
+            quant_info.block_shape,
+        )
+    )
+
+    dispose_tensor(hidden_states_ref)
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["src2dst"] = src2dst
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        masked_m=masked_m,
+        expected_m=expected_m,
+        use_masked_gemm=True,
+    )
+
+
+@register_post_permute("deep_gemm", "standard")
+def post_permute_deep_gemm_to_standard(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states_shape = running_state["hidden_states_shape"]
+    hidden_states_dtype = running_state["hidden_states_dtype"]
+    hidden_states_device = running_state["hidden_states_device"]
+    src2dst = running_state["src2dst"]
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+
+    output = torch.empty(
+        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+    )
+    post_reorder_triton_kernel[(hidden_states_shape[0],)](
+        runner_output.hidden_states,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        runner_config.top_k,
+        hidden_states_shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    dispose_tensor(runner_output.hidden_states)
+
+    if runner_config.routed_scaling_factor is not None:
+        output *= runner_config.routed_scaling_factor
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
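
For reference, _cast_to_e8m0_with_rounding_up above rounds an fp32 scale up to a pure power of two: e8m0 stores only the 8-bit exponent field, so any nonzero mantissa bumps the exponent by one (with guards for the maximum exponent 0xFE and for small subnormals). A self-contained sketch of the same rule for normal positive scales, using plain floats instead of the bitwise tensor version:

import math

def e8m0_round_up(scale: float) -> float:
    # e8m0 keeps only a power-of-two exponent, so a normal positive fp32
    # scale rounds *up* to the next power of two unless it already is one.
    return 2.0 ** math.ceil(math.log2(scale))

assert e8m0_round_up(0.75) == 1.0  # nonzero mantissa -> exponent bumped
assert e8m0_round_up(0.5) == 0.5   # exact power of two -> unchanged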
sglang/srt/layers/moe/moe_runner/runner.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_deep_gemm():
+            self.runner_core = DeepGemmRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
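
Taken together, MoeRunner.run prefers a fused function registered for the current (A2A backend, runner backend) pair and otherwise falls back to pre-permute, core run, then post-permute, threading a running_state dict between the hooks. A usage sketch, assuming an initialized sglang MoE context (the class, enum, and config fields are taken from the diffs above; the dispatch output and quant info are placeholders left to the caller):

from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
from sglang.srt.layers.moe.utils import MoeRunnerBackend

config = MoeRunnerConfig(num_local_experts=8, top_k=2, activation="silu")
runner = MoeRunner(MoeRunnerBackend.TRITON, config)

# dispatch_output comes from a token dispatcher (e.g. a StandardDispatchOutput)
# and quant_info from the layer's quantization method:
# combine_input = runner.run(dispatch_output, quant_info)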