sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
9
9
 
10
10
  from sglang.srt.custom_op import CustomOp
11
11
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
12
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
13
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
12
14
  from sglang.srt.layers.quantization.base_config import (
13
15
  FusedMoEMethodBase,
14
16
  LinearMethodBase,
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
24
26
  )
25
27
 
26
28
  if TYPE_CHECKING:
27
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
28
- from sglang.srt.layers.moe.topk import TopKOutput
29
+ from sglang.srt.layers.moe.token_dispatcher import (
30
+ CombineInput,
31
+ StandardDispatchOutput,
32
+ )
29
33
 
30
34
  has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
31
35
 
@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
155
159
  layer: torch.nn.Module,
156
160
  num_experts: int,
157
161
  hidden_size: int,
158
- intermediate_size: int,
162
+ intermediate_size_per_partition: int,
159
163
  params_dtype: torch.dtype,
160
164
  with_bias: bool = False,
161
165
  **extra_weight_attrs,
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
163
167
  self.with_bias = with_bias
164
168
 
165
169
  # Fused gate_up_proj (column parallel)
166
- w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
170
+ w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
167
171
  if self.use_triton_kernels:
168
172
  w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
169
173
  w13_weight = torch.nn.Parameter(
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
175
179
 
176
180
  if self.with_bias:
177
181
  w13_weight_bias = torch.nn.Parameter(
178
- torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
182
+ torch.empty(
183
+ num_experts,
184
+ 2 * intermediate_size_per_partition,
185
+ dtype=torch.float32,
186
+ ),
179
187
  requires_grad=False,
180
188
  )
181
189
  layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
184
192
  # down_proj (row parallel)
185
193
  w2_weight_n, w2_weight_k = (
186
194
  hidden_size,
187
- intermediate_size,
195
+ intermediate_size_per_partition,
188
196
  )
189
197
  if self.use_triton_kernels:
190
198
  w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
222
230
 
223
231
  return
224
232
 
233
+ def create_moe_runner(
234
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
235
+ ):
236
+ self.moe_runner_config = moe_runner_config
237
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
238
+
225
239
  def apply(
226
240
  self,
227
241
  layer: torch.nn.Module,
228
- x: torch.Tensor,
229
- topk_output: TopKOutput,
230
- moe_runner_config: MoeRunnerConfig,
231
- ) -> torch.Tensor:
242
+ dispatch_output: StandardDispatchOutput,
243
+ ) -> CombineInput:
232
244
 
233
245
  return self.forward(
234
- x=x,
235
246
  layer=layer,
236
- topk_output=topk_output,
237
- moe_runner_config=moe_runner_config,
247
+ dispatch_output=dispatch_output,
238
248
  )
239
249
 
240
250
  def forward_cuda(
241
251
  self,
242
252
  layer: torch.nn.Module,
243
- x: torch.Tensor,
244
- topk_output: TopKOutput,
245
- moe_runner_config: MoeRunnerConfig,
246
- ) -> torch.Tensor:
253
+ dispatch_output: StandardDispatchOutput,
254
+ ) -> CombineInput:
255
+
256
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
257
+
258
+ x = dispatch_output.hidden_states
259
+ topk_output = dispatch_output.topk_output
260
+
261
+ moe_runner_config = self.moe_runner_config
247
262
 
248
263
  if self.use_triton_kernels:
249
264
  if self.with_bias:
250
265
  assert self.triton_kernel_moe_with_bias_forward is not None
251
- return self.triton_kernel_moe_with_bias_forward(
266
+ output = self.triton_kernel_moe_with_bias_forward(
252
267
  hidden_states=x,
253
268
  w1=layer.w13_weight,
254
269
  w2=layer.w2_weight,
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
261
276
  )
262
277
  else:
263
278
  assert self.triton_kernel_moe_forward is not None
264
- return self.triton_kernel_moe_forward(
279
+ output = self.triton_kernel_moe_forward(
265
280
  hidden_states=x,
266
281
  w1=layer.w13_weight,
267
282
  w2=layer.w2_weight,
268
283
  topk_output=topk_output,
269
284
  moe_runner_config=moe_runner_config,
270
285
  )
286
+ return StandardCombineInput(hidden_states=output)
271
287
  else:
272
288
  if _use_aiter:
273
289
  assert not moe_runner_config.no_combine, "unsupported"
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
284
300
  topk_weights = torch.ones_like(
285
301
  topk_weights, dtype=torch.float32
286
302
  ) # topk_weights must be FP32 (float32)
287
- return fused_moe(
303
+ output = fused_moe(
288
304
  x,
289
305
  layer.w13_weight,
290
306
  layer.w2_weight,
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
296
312
  else ActivationType.Gelu
297
313
  ),
298
314
  )
315
+ return StandardCombineInput(hidden_states=output)
299
316
  else:
300
- from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
301
- fused_experts,
302
- )
303
317
 
304
- return fused_experts(
305
- hidden_states=x,
306
- w1=layer.w13_weight,
307
- w2=layer.w2_weight,
308
- b1=getattr(layer, "w13_weight_bias", None),
318
+ quant_info = TritonMoeQuantInfo(
319
+ w13_weight=layer.w13_weight,
320
+ w2_weight=layer.w2_weight,
321
+ b13=getattr(layer, "w13_weight_bias", None),
309
322
  b2=getattr(layer, "w2_weight_bias", None),
310
- topk_output=topk_output,
311
- moe_runner_config=moe_runner_config,
312
323
  )
324
+ return self.runner.run(dispatch_output, quant_info)
313
325
 
314
326
  def forward_cpu(
315
327
  self,
316
328
  layer: torch.nn.Module,
317
- x: torch.Tensor,
318
- topk_output: TopKOutput,
319
- moe_runner_config: MoeRunnerConfig,
320
- ) -> torch.Tensor:
329
+ dispatch_output: StandardDispatchOutput,
330
+ ) -> CombineInput:
331
+
332
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
333
+
334
+ x = dispatch_output.hidden_states
335
+ topk_output = dispatch_output.topk_output
336
+
337
+ moe_runner_config = self.moe_runner_config
338
+
321
339
  assert (
322
340
  moe_runner_config.activation == "silu"
323
341
  ), f"activation = {moe_runner_config.activation} is not supported."
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
332
350
  x, topk_weights = apply_topk_weights_cpu(
333
351
  moe_runner_config.apply_router_weight_on_input, topk_weights, x
334
352
  )
335
- return torch.ops.sgl_kernel.fused_experts_cpu(
353
+ output = torch.ops.sgl_kernel.fused_experts_cpu(
336
354
  x,
337
355
  layer.w13_weight,
338
356
  layer.w2_weight,
@@ -348,33 +366,103 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
348
366
  None, # a2_scale
349
367
  True, # is_vnni
350
368
  )
369
+ return StandardCombineInput(hidden_states=output)
351
370
  else:
352
371
  from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
353
372
 
354
- return moe_forward_native(
373
+ output = moe_forward_native(
355
374
  layer,
356
375
  x,
357
376
  topk_output,
358
377
  moe_runner_config,
359
378
  )
379
+ return StandardCombineInput(hidden_states=output)
360
380
 
361
381
  def forward_npu(
362
382
  self,
363
383
  layer: torch.nn.Module,
364
- x: torch.Tensor,
365
- topk_output: TopKOutput,
366
- moe_runner_config: MoeRunnerConfig,
367
- ) -> torch.Tensor:
368
- from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
384
+ dispatch_output: StandardDispatchOutput,
385
+ ) -> CombineInput:
386
+
387
+ import torch_npu
388
+
389
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
390
+
391
+ x = dispatch_output.hidden_states
392
+ topk_weights, topk_ids, _ = dispatch_output.topk_output
393
+
394
+ original_dtype = x.dtype
395
+ num_tokens = x.shape[0]
396
+ topk_weights = topk_weights.to(x.dtype)
397
+ topk_ids = topk_ids.to(torch.int32)
398
+ num_experts = layer.num_experts
399
+ top_k = layer.top_k
400
+ row_idx_len = num_tokens * top_k
401
+ row_idx = (
402
+ torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
403
+ .view(top_k, -1)
404
+ .permute(1, 0)
405
+ .contiguous()
406
+ )
369
407
 
370
- return moe_forward_native(
371
- layer,
372
- x,
373
- topk_output,
374
- moe_runner_config,
408
+ hidden_states, expanded_row_idx, expanded_expert_idx = (
409
+ torch_npu.npu_moe_init_routing(
410
+ x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
411
+ )
412
+ )
413
+
414
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
415
+ expanded_expert_idx, num_experts
375
416
  )
376
417
 
377
- def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
418
+ expert_tokens = expert_tokens.to(torch.int64)
419
+ if layer.w13_weight.shape[-1] == layer.hidden_size:
420
+ w13 = layer.w13_weight.transpose(1, 2)
421
+ w2 = layer.w2_weight.transpose(1, 2)
422
+
423
+ # gmm1: gate_up_proj
424
+ hidden_states = torch_npu.npu_grouped_matmul(
425
+ x=[hidden_states],
426
+ weight=[w13],
427
+ split_item=2,
428
+ group_list_type=0,
429
+ group_type=0,
430
+ group_list=expert_tokens,
431
+ output_dtype=original_dtype,
432
+ )[0]
433
+
434
+ # act_fn:
435
+ if self.moe_runner_config.activation == "silu":
436
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
437
+ else:
438
+ from sglang.srt.layers.activation import GeluAndMul
439
+
440
+ hidden_states = GeluAndMul()(hidden_states)
441
+
442
+ # gmm2: down_proj
443
+ hidden_states = torch_npu.npu_grouped_matmul(
444
+ x=[hidden_states],
445
+ weight=[w2],
446
+ split_item=2,
447
+ group_list_type=0,
448
+ group_type=0,
449
+ group_list=expert_tokens,
450
+ output_dtype=original_dtype,
451
+ )[0]
452
+
453
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
454
+ hidden_states,
455
+ skip1=None,
456
+ skip2=None,
457
+ bias=None,
458
+ scales=topk_weights,
459
+ expanded_src_to_dst_row=expanded_row_idx,
460
+ export_for_source_row=topk_ids,
461
+ )
462
+
463
+ return StandardCombineInput(hidden_states=final_hidden_states)
464
+
465
+ def forward_tpu(self, *args, **kwargs) -> CombineInput:
378
466
  raise NotImplementedError("The TPU backend currently does not support MoE.")
379
467
 
380
468
  forward_native = forward_cpu
@@ -17,12 +17,14 @@ from sglang.srt.layers.quantization.base_config import (
17
17
  from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
18
18
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
19
19
  from sglang.srt.layers.quantization.utils import is_layer_skipped
20
- from sglang.srt.utils import set_weight_attrs
20
+ from sglang.srt.utils import is_npu, set_weight_attrs
21
21
 
22
22
  if TYPE_CHECKING:
23
23
  from sglang.srt.layers.moe import MoeRunnerConfig
24
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
25
- from sglang.srt.layers.moe.topk import StandardTopKOutput
24
+ from sglang.srt.layers.moe.token_dispatcher import (
25
+ CombineInput,
26
+ StandardDispatchOutput,
27
+ )
26
28
 
27
29
  ACTIVATION_SCHEMES = ["static", "dynamic"]
28
30
 
@@ -91,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
91
93
  self, layer: torch.nn.Module, prefix: str
92
94
  ) -> Optional[QuantizeMethodBase]:
93
95
  from sglang.srt.layers.linear import LinearBase
94
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
95
96
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
96
- from sglang.srt.managers.schedule_batch import global_server_args_dict
97
97
 
98
98
  if isinstance(layer, LinearBase):
99
99
  if is_layer_skipped(prefix, self.ignored_layers):
@@ -130,10 +130,10 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
130
130
 
131
131
  def create_weights(
132
132
  self,
133
- layer: EPMoE,
133
+ layer: Module,
134
134
  num_experts: int,
135
135
  hidden_size: int,
136
- intermediate_size: int,
136
+ intermediate_size_per_partition: int,
137
137
  params_dtype: torch.dtype,
138
138
  **extra_weight_attrs,
139
139
  ):
@@ -145,7 +145,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
145
145
  w13_weight = torch.nn.Parameter(
146
146
  torch.empty(
147
147
  num_experts,
148
- intermediate_size * 2,
148
+ intermediate_size_per_partition * 2,
149
149
  hidden_size // 2,
150
150
  dtype=torch.int8,
151
151
  ),
@@ -159,7 +159,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
159
159
  torch.empty(
160
160
  num_experts,
161
161
  hidden_size,
162
- intermediate_size // 2,
162
+ intermediate_size_per_partition // 2,
163
163
  dtype=torch.int8,
164
164
  ),
165
165
  requires_grad=False,
@@ -173,7 +173,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
173
173
  w13_weight_scale = torch.nn.Parameter(
174
174
  torch.zeros(
175
175
  num_experts,
176
- 2 * intermediate_size,
176
+ 2 * intermediate_size_per_partition,
177
177
  hidden_size // self.quant_config.group_size,
178
178
  dtype=torch.float32,
179
179
  ),
@@ -186,7 +186,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
186
186
  torch.zeros(
187
187
  num_experts,
188
188
  hidden_size,
189
- intermediate_size // self.quant_config.group_size,
189
+ intermediate_size_per_partition // self.quant_config.group_size,
190
190
  dtype=torch.float32,
191
191
  ),
192
192
  requires_grad=False,
@@ -220,13 +220,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
220
220
  )
221
221
  self.c_strides1 = torch.full(
222
222
  (num_experts, 3),
223
- 2 * intermediate_size,
223
+ 2 * intermediate_size_per_partition,
224
224
  device=device,
225
225
  dtype=torch.int64,
226
226
  )
227
227
  self.a_strides2 = torch.full(
228
228
  (num_experts, 3),
229
- intermediate_size,
229
+ intermediate_size_per_partition,
230
230
  device=device,
231
231
  dtype=torch.int64,
232
232
  )
@@ -282,30 +282,26 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
282
282
  )
283
283
  layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
284
284
 
285
+ def create_moe_runner(
286
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
287
+ ):
288
+ self.moe_runner_config = moe_runner_config
289
+
285
290
  def apply(
286
291
  self,
287
- layer: EPMoE,
288
- x: torch.Tensor,
289
- topk_output: StandardTopKOutput,
290
- moe_runner_config: MoeRunnerConfig,
291
- ) -> torch.Tensor:
292
+ layer: Module,
293
+ dispatch_output: StandardDispatchOutput,
294
+ ) -> CombineInput:
292
295
 
293
- # TODO(ch-wan): move it out of this class
294
296
  from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
297
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
298
+
299
+ x = dispatch_output.hidden_states
300
+ topk_output = dispatch_output.topk_output
295
301
 
296
302
  topk_weights, topk_ids, _ = topk_output
297
- local_topk_ids = topk_ids
298
- if get_moe_expert_parallel_world_size() > 1:
299
- local_topk_ids = torch.where(
300
- topk_ids == -1,
301
- layer.num_experts,
302
- topk_ids,
303
- )
304
303
 
305
304
  output = cutlass_w4a8_moe(
306
- layer.start_expert_id,
307
- layer.end_expert_id,
308
- layer.num_experts,
309
305
  x,
310
306
  layer.w13_weight,
311
307
  layer.w2_weight,
@@ -313,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
313
309
  layer.w2_weight_scale_inv,
314
310
  topk_weights,
315
311
  topk_ids,
316
- local_topk_ids,
317
312
  self.a_strides1,
318
313
  self.b_strides1,
319
314
  self.c_strides1,
@@ -328,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
328
323
  layer.w13_input_scale,
329
324
  layer.w2_input_scale,
330
325
  )
331
- if moe_runner_config.routed_scaling_factor is not None:
332
- output *= moe_runner_config.routed_scaling_factor
333
- return output
326
+ if self.moe_runner_config.routed_scaling_factor is not None:
327
+ output *= self.moe_runner_config.routed_scaling_factor
328
+ return StandardCombineInput(hidden_states=output)
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
5
5
  import torch
6
6
  from torch.nn.parameter import Parameter
7
7
 
8
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
9
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
8
10
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
9
11
  from sglang.srt.layers.quantization.base_config import (
10
12
  FusedMoEMethodBase,
@@ -26,8 +28,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
26
28
  from sglang.srt.utils import set_weight_attrs
27
29
 
28
30
  if TYPE_CHECKING:
29
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
30
- from sglang.srt.layers.moe.topk import StandardTopKOutput
31
+ from sglang.srt.layers.moe.token_dispatcher import (
32
+ CombineInput,
33
+ StandardDispatchOutput,
34
+ )
31
35
 
32
36
  _is_fp8_fnuz = is_fp8_fnuz()
33
37
 
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
209
213
  layer: torch.nn.Module,
210
214
  num_experts: int,
211
215
  hidden_size: int,
212
- intermediate_size: int,
216
+ intermediate_size_per_partition: int,
213
217
  params_dtype: torch.dtype,
214
218
  **extra_weight_attrs,
215
219
  ):
@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
218
222
  # WEIGHTS
219
223
  w13_weight = torch.nn.Parameter(
220
224
  torch.empty(
221
- num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
225
+ num_experts,
226
+ 2 * intermediate_size_per_partition,
227
+ hidden_size,
228
+ dtype=fp8_dtype,
222
229
  ),
223
230
  requires_grad=False,
224
231
  )
@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
226
233
  set_weight_attrs(w13_weight, extra_weight_attrs)
227
234
 
228
235
  w2_weight = torch.nn.Parameter(
229
- torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
236
+ torch.empty(
237
+ num_experts,
238
+ hidden_size,
239
+ intermediate_size_per_partition,
240
+ dtype=fp8_dtype,
241
+ ),
230
242
  requires_grad=False,
231
243
  )
232
244
  layer.register_parameter("w2_weight", w2_weight)
233
245
  set_weight_attrs(w2_weight, extra_weight_attrs)
234
246
 
235
247
  w13_weight_scale = torch.nn.Parameter(
236
- torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
248
+ torch.ones(
249
+ num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
250
+ ),
237
251
  requires_grad=False,
238
252
  )
239
253
  w2_weight_scale = torch.nn.Parameter(
@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
266
280
  layer.w2_weight_scale.data, requires_grad=False
267
281
  )
268
282
 
283
+ def create_moe_runner(
284
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
285
+ ):
286
+ self.moe_runner_config = moe_runner_config
287
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
288
+
269
289
  def apply(
270
290
  self,
271
291
  layer: torch.nn.Module,
272
- x: torch.Tensor,
273
- topk_output: StandardTopKOutput,
274
- moe_runner_config: MoeRunnerConfig,
275
- ) -> torch.Tensor:
276
- from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
292
+ dispatch_output: StandardDispatchOutput,
293
+ ) -> CombineInput:
277
294
 
278
- return fused_experts(
279
- x,
280
- layer.w13_weight,
281
- layer.w2_weight,
282
- topk_output=topk_output,
283
- moe_runner_config=moe_runner_config,
295
+ quant_info = TritonMoeQuantInfo(
296
+ w13_weight=layer.w13_weight,
297
+ w2_weight=layer.w2_weight,
284
298
  use_fp8_w8a8=True,
285
299
  per_channel_quant=True,
286
- w1_scale=(layer.w13_weight_scale),
287
- w2_scale=(layer.w2_weight_scale),
288
- a1_scale=layer.w13_input_scale,
300
+ w13_scale=layer.w13_weight_scale,
301
+ w2_scale=layer.w2_weight_scale,
302
+ a13_scale=layer.w13_input_scale,
289
303
  a2_scale=layer.w2_input_scale,
290
304
  )
305
+ return self.runner.run(dispatch_output, quant_info)