sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/deepep.py

```diff
@@ -1,17 +1,20 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -23,6 +26,9 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 
+if TYPE_CHECKING:
+    from sglang.srt.single_batch_overlap import CombineOverlapArgs
+
 try:
     from deep_ep import Buffer, Config
 
@@ -40,11 +46,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -56,6 +57,7 @@ class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    # hidden_states_scale
     topk_idx: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]
@@ -79,24 +81,32 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-class AscendDeepEPLLOutput(NamedTuple):
-    """AscendDeepEP low latency dispatch output."""
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
 
-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
-    topk_weights: torch.Tensor
-    masked_m: torch.Tensor
-    seg_indptr: torch.Tensor
-    expected_m: int
+
+class DeepEPNormalCombineInput(NamedTuple):
+    """DeepEP normal combine input."""
+
+    pass
 
     @property
-    def format(self) -> DispatchOutputFormat:
-        return DispatchOutputFormat.ASCENT_LL
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_NORMAL
 
 
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
-assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
+class DeepEPLLCombineInput(NamedTuple):
+    """DeepEP low latency combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(DeepEPNormalCombineInput, CombineInput)
+assert isinstance(DeepEPLLCombineInput, CombineInput)
 
 
 class DeepEPDispatchMode(IntEnum):
```
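Note that the asserts above check the classes themselves, not instances. That works when DispatchOutput and CombineInput are runtime-checkable Protocols in token_dispatcher/base.py (base.py is not shown in this diff, so that is an inference): isinstance() on such a Protocol only probes for the required attributes, and the class object already carries the `format` property descriptor. A minimal self-contained sketch of the pattern:

```python
# Minimal sketch of the class-level Protocol check used above; DispatchOutput
# here is a stand-in, not the one in sglang.srt.layers.moe.token_dispatcher.base.
from typing import NamedTuple, Protocol, runtime_checkable


@runtime_checkable
class DispatchOutput(Protocol):
    @property
    def format(self) -> str: ...


class MyOutput(NamedTuple):
    payload: int

    @property
    def format(self) -> str:
        return "standard"


# Passes: isinstance() on a runtime-checkable Protocol only does hasattr()
# checks, and the class object MyOutput exposes `format` (the descriptor).
assert isinstance(MyOutput, DispatchOutput)
```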
```diff
@@ -158,10 +168,19 @@ class DeepEPBuffer:
                 num_rdma_bytes,
             )
 
+        # We should calculate num_qps_per_rank consistently with DeepEP's test script logic:
         if deepep_mode == DeepEPMode.NORMAL:
-            num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
+            num_qps_per_rank = DeepEPConfig.get_instance().num_sms
+        elif deepep_mode == DeepEPMode.LOW_LATENCY:
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
             num_qps_per_rank = num_experts // group.size()
+        elif deepep_mode == DeepEPMode.AUTO:
+            # low-latency and normal modes both need to run
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
+            num_qps_per_rank = max(
+                DeepEPConfig.get_instance().num_sms, num_experts // group.size()
+            )
         else:
             raise NotImplementedError
 
```
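For concreteness, here is the new QP-count arithmetic with illustrative numbers (24 SMs, 256 experts, a 16-rank group; these values are not from the diff): NORMAL uses one QP per SM, LOW_LATENCY one per local expert, and AUTO must serve both paths, hence the max.

```python
# Illustrative numbers only; not taken from the diff.
num_sms = 24        # DeepEPConfig.get_instance().num_sms
num_experts = 256
group_size = 16     # group.size()

num_qps_per_rank = {
    "normal": num_sms,                                # 24
    "low_latency": num_experts // group_size,         # 16
    "auto": max(num_sms, num_experts // group_size),  # 24
}
print(num_qps_per_rank)
```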
```diff
@@ -272,12 +291,16 @@ class _DeepEPDispatcherImplBase:
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
         )
+        # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024, and its logic
+        # requires the number of tokens sent from one rank to another to be less than it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
 
         self.handle = None
 
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -291,6 +314,7 @@
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         raise NotImplementedError
 
@@ -311,6 +335,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -408,8 +433,13 @@
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -479,10 +509,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
         self.return_recv_hook = return_recv_hook
+        self.device_module = torch.get_device_module()
 
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -494,8 +526,8 @@
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
+            input_global_scale,
             topk_idx,
-            use_fp8=True,
         )
         return (
             hidden_states,
@@ -523,39 +555,41 @@
             masked_m
         )
 
-        if _is_npu:
-            deepep_output = AscendDeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                self.handle[1],
-                expected_m,
-            )
-        else:
-            deepep_output = DeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                expected_m,
-            )
+        deepep_output = DeepEPLLOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
         return deepep_output
 
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
-        use_fp8: bool = False,
     ):
+        use_nvfp4 = use_fp8 = False
+        if input_global_scale is not None:
+            use_nvfp4 = True
+        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
+            use_fp8 = True
+
         buffer = self._get_buffer()
-        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
+                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                **(
+                    dict(x_global_scale=input_global_scale)
+                    if input_global_scale is not None
+                    else dict()
+                ),
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
```
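The dispatch now picks the wire format at call time (NVFP4 when an input_global_scale is provided, otherwise FP8 unless SGLANG_DEEPEP_BF16_DISPATCH is set) and forwards the NVFP4-only keywords through conditional `**` splats, so the call still works against DeepEP builds whose low_latency_dispatch signature lacks those parameters. A small sketch of that splat pattern with a stand-in function:

```python
# `send` is a stand-in, not a DeepEP API; the point is that optional kwargs
# are only passed when active, keeping older signatures compatible.
def send(x, *, use_fp8=False, **extra):
    return x, use_fp8, extra

input_global_scale = None          # set only on the NVFP4 path
use_nvfp4 = input_global_scale is not None

result = send(
    "hidden_states",
    use_fp8=not use_nvfp4,
    **(dict(use_nvfp4=True) if use_nvfp4 else {}),
    **(
        dict(x_global_scale=input_global_scale)
        if input_global_scale is not None
        else {}
    ),
)
print(result)  # ('hidden_states', True, {})
```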
```diff
@@ -564,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
-        return packed_recv_hidden, packed_recv_count, event, hook
+        return packed_recv_hidden, self.packed_recv_count, event, hook
 
     def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_idx,
             topk_weights,
+            overlap_args=overlap_args,
         )
-        return hidden_states, event, hook
+        return hidden_states, event, hook, overlap_args
 
-    def combine_b(self, hidden_states, event, hook):
+    def combine_b(self, hidden_states, event, hook, overlap_args):
         hook() if self.return_recv_hook else event.current_stream_wait()
+
+        if overlap_args is not None:
+            self.device_module.current_stream().wait_stream(overlap_args.stream)
+
         return hidden_states
 
     def _combine_core(
@@ -588,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         buffer = self._get_buffer()
-        combined_hidden_states, event, hook = buffer.low_latency_combine(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            self.handle,
-            async_finish=not self.return_recv_hook,
-            return_recv_hook=self.return_recv_hook,
-        )
-        self.handle = None
+
+        ctx = nullcontext()
+        if overlap_args is not None:
+            overlap_args.stream.wait_event(overlap_args.wait_event)
+            ctx = torch.cuda.stream(overlap_args.stream)
+
+        with ctx:
+            combined_hidden_states, event, hook = buffer.low_latency_combine(
+                x=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                handle=self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+                **(
+                    dict(
+                        overlap=overlap_args.overlap,
+                        src_signals=overlap_args.signal,
+                        src_signal_expect_value=overlap_args.threshold,
+                    )
+                    if overlap_args is not None
+                    else {}
+                ),
+            )
+
+        self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook
 
     def _get_buffer(self):
```
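The overlap path pins down stream ordering in three steps: the combine stream first waits on an event recorded upstream (`wait_event`), the collective is enqueued on that side stream, and combine_b later makes the main stream wait on it (`wait_stream`) before the result is consumed. A self-contained sketch of the same ordering, assuming a CUDA device:

```python
# Stream-ordering sketch mirroring the overlap path above (CUDA required).
import torch

assert torch.cuda.is_available()

main = torch.cuda.current_stream()
side = torch.cuda.Stream()          # stands in for overlap_args.stream

x = torch.randn(1024, 1024, device="cuda")
ready = torch.cuda.Event()          # stands in for overlap_args.wait_event
ready.record(main)                  # producer: x is fully written

side.wait_event(ready)              # side stream may not start before `ready`
with torch.cuda.stream(side):
    y = x @ x                       # the "combine" work runs on the side stream

main.wait_stream(side)              # consumer: main stream blocks until done
print(y.sum().item())
```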
```diff
@@ -669,6 +727,7 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
@@ -676,6 +735,7 @@
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
+            input_global_scale=input_global_scale,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
@@ -698,12 +758,14 @@
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
+            overlap_args=overlap_args,
         )
         self._combine_intermediate_state = forward_batch, inner_state
 
```
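Each dispatcher operation stays split into an `_a` phase that launches asynchronous communication and a `_b` phase that finalizes it, with `_update_stage` enforcing the legal order; compute can be scheduled between the two. A toy version of that staging, not the sglang class:

```python
# Toy two-phase op illustrating the *_a / *_b staging; not the sglang class.
class TwoPhaseOp:
    def __init__(self):
        self._inner_state = None

    def dispatch_a(self, payload):
        # launch async work here; stash whatever phase B needs
        self._inner_state = ("in-flight", payload)

    def dispatch_b(self):
        stage, payload = self._inner_state
        assert stage == "in-flight", "dispatch_b() called before dispatch_a()"
        self._inner_state = None
        return payload


op = TwoPhaseOp()
op.dispatch_a("tokens")
# ... overlapping compute runs here while communication is in flight ...
print(op.dispatch_b())
```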
sglang/srt/layers/moe/token_dispatcher/standard.py

```diff
@@ -1,19 +1,61 @@
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+import torch
+
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 class StandardDispatchOutput(NamedTuple):
     """Standard dispatch output."""
 
+    hidden_states: torch.Tensor
+    topk_output: TopKOutput
+
     @property
     def format(self) -> DispatchOutputFormat:
         return DispatchOutputFormat.STANDARD
 
 
 assert isinstance(StandardDispatchOutput, DispatchOutput)
+
+
+class StandardCombineInput(NamedTuple):
+    """Standard combine input."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.STANDARD
+
+
+assert isinstance(StandardCombineInput, CombineInput)
+
+
+class StandardDispatcher(BaseDispatcher):
+
+    def dispatch(
+        self, hidden_states: torch.Tensor, topk_output: TopKOutput
+    ) -> DispatchOutput:
+        return StandardDispatchOutput(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+
+    def combine(self, combine_input: CombineInput) -> torch.Tensor:
+        if isinstance(combine_input, StandardCombineInput):
+            return combine_input.hidden_states
+        else:
+            # TODO: this branch should be removed in the future
+            assert isinstance(combine_input, torch.Tensor)
+            return combine_input
```
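A minimal round trip through the new StandardDispatcher, assuming BaseDispatcher imposes no extra constructor arguments; the plain tuple below merely stands in for a real TopKOutput:

```python
import torch

from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatcher,
)

dispatcher = StandardDispatcher()
hidden = torch.randn(4, 8)
topk_output = (torch.ones(4, 2), torch.zeros(4, 2, dtype=torch.long))  # stand-in

out = dispatcher.dispatch(hidden, topk_output)  # StandardDispatchOutput
assert out.hidden_states is hidden

# combine() unwraps a StandardCombineInput back to the raw tensor
back = dispatcher.combine(StandardCombineInput(hidden_states=hidden))
assert back is hidden
```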
sglang/srt/layers/moe/topk.py

```diff
@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None
 
 
 # -------------------------------- TopKOutput ---------------------------------------
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -218,11 +224,9 @@
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )
 
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +252,19 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
@@ -256,10 +272,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -330,6 +343,14 @@
             )
             topk_weights = topk_weights / topk_weights_sum
 
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+                get_global_expert_distribution_recorder().on_select_experts(
+                    topk_ids=topk_ids
+                )
+
             return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
```
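The format choice above boils down to a priority chain: an explicit TopKConfig.output_format overrides everything, then the MoE runner backend decides, and STANDARD is the fallback. A condensed sketch with stand-in helpers:

```python
# Stand-in enum and predicates; this mirrors only the priority order above.
from enum import Enum, auto


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()


def resolve(configured, backend_is_triton_kernel, backend_bypasses_topk):
    if configured is not None:
        return configured                      # explicit override wins
    if backend_is_triton_kernel:
        return TopKOutputFormat.TRITON_KERNEL
    if backend_bypasses_topk:                  # trtllm / mxfp4 paths
        return TopKOutputFormat.BYPASSED
    return TopKOutputFormat.STANDARD


assert resolve(None, True, False) is TopKOutputFormat.TRITON_KERNEL
assert resolve(TopKOutputFormat.STANDARD, True, False) is TopKOutputFormat.STANDARD
```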
sglang/srt/layers/moe/utils.py

```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import importlib.util
+import logging
 from enum import Enum
 from functools import lru_cache
 from typing import TYPE_CHECKING, Optional
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
-from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
 
+logger = logging.getLogger(__name__)
+
 
 class MoeA2ABackend(Enum):
 
@@ -44,9 +46,10 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
-    FLASHINFER = "flashinfer_trtllm"
+    FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
@@ -58,11 +61,14 @@ class MoeRunnerBackend(Enum):
         return self == MoeRunnerBackend.TRITON_KERNEL
 
     def is_flashinfer_trtllm(self):
-        return self == MoeRunnerBackend.FLASHINFER
+        return self == MoeRunnerBackend.FLASHINFER_TRTLLM
 
     def is_flashinfer_cutlass(self):
         return self == MoeRunnerBackend.FLASHINFER_CUTLASS
 
+    def is_flashinfer_cutedsl(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
+
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
@@ -102,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
 DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
+IS_SBO_ENABLED: Optional[bool] = None
 TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
 DEEPEP_CONFIG: Optional[str] = None
 DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
@@ -113,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global DEEPEP_MODE
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
+    global IS_SBO_ENABLED
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
 
@@ -121,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
     DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    IS_SBO_ENABLED = server_args.enable_single_batch_overlap
     TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
     DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
         server_args.disable_flashinfer_cutlass_moe_fp4_allgather
@@ -131,7 +140,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
         logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
-        MOE_A2A_BACKEND = MoeA2ABackend(None)
+        MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
 
@@ -139,7 +148,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
         logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
-        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
 
@@ -147,7 +156,7 @@ def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
         logger.warning("DEEPEP_MODE is not initialized, using auto mode")
-        DEEPEP_MODE = DeepEPMode("auto")
+        DEEPEP_MODE = DeepEPMode.AUTO
     return DEEPEP_MODE
 
 
@@ -166,6 +175,13 @@ def is_tbo_enabled() -> bool:
     return IS_TBO_ENABLED
 
 
+def is_sbo_enabled() -> bool:
+    global IS_SBO_ENABLED
+    if IS_SBO_ENABLED is None:
+        IS_SBO_ENABLED = False
+    return IS_SBO_ENABLED
+
+
 def get_tbo_token_distribution_threshold() -> float:
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
```
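The file keeps following its initialize-then-get pattern: initialize_moe_config() fills module-level globals once from ServerArgs, and each getter degrades to a safe default when read before initialization. A stripped-down version of the new is_sbo_enabled() pair:

```python
# Stripped-down sketch of the module-global config pattern; names mirror the
# diff, but this is not the sglang module itself.
from typing import Optional

IS_SBO_ENABLED: Optional[bool] = None


def initialize_moe_config(enable_single_batch_overlap: bool) -> None:
    global IS_SBO_ENABLED
    IS_SBO_ENABLED = enable_single_batch_overlap


def is_sbo_enabled() -> bool:
    global IS_SBO_ENABLED
    if IS_SBO_ENABLED is None:
        IS_SBO_ENABLED = False   # default when read before initialization
    return IS_SBO_ENABLED


assert is_sbo_enabled() is False
initialize_moe_config(True)
assert is_sbo_enabled() is True
```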
sglang/srt/layers/parameter.py

```diff
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu
 
 __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    self.output_dim, tp_rank * shard_size, shard_size
-                )
+                # Padding for special cases like qwen2_5_VL's MLP, which is not 8-aligned
+                start_idx = tp_rank * shard_size
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[self.output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, self.output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        self.output_dim, start_idx, shard_size
+                    )
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
 
             return
         else:
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
+            # Padding for special cases like qwen2_5_VL's MLP, which is not 8-aligned
+            start_idx = tp_rank * shard_size
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[self.input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, self.input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, start_idx, shard_size
+                )
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
```
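The body of pad_or_narrow_weight is not shown in this diff, so the following is only a plausible sketch of what such a routine does for the last, short tensor-parallel shard: take whatever rows remain past start_idx and zero-pad up to shard_size instead of narrowing out of bounds.

```python
# Hypothetical re-implementation, NOT sglang.srt.layers.utils.pad_or_narrow_weight.
import torch


def pad_or_narrow_weight_sketch(
    w: torch.Tensor, dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
    available = w.shape[dim] - start_idx       # rows actually present
    shard = w.narrow(dim, start_idx, available)
    pad_shape = list(w.shape)
    pad_shape[dim] = shard_size - available    # missing tail, filled with zeros
    pad = shard.new_zeros(pad_shape)
    return torch.cat([shard, pad], dim=dim)


w = torch.arange(10.0).reshape(10, 1)
# last shard: rows 8..11 requested, only 8..9 exist -> padded with zeros
print(pad_or_narrow_weight_sketch(w, 0, 8, 4).squeeze(1))  # tensor([8., 9., 0., 0.])
```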