sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@
4
4
  # Adapted from
5
5
  # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
6
6
  # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
7
- """vLLM distributed state.
7
+ """Distributed state.
8
8
  It takes over the control of the distributed environment from PyTorch.
9
9
  The typical workflow is:
10
10
 
@@ -39,22 +39,31 @@ import torch
39
39
  import torch.distributed
40
40
  from torch.distributed import Backend, ProcessGroup
41
41
 
42
+ from sglang.srt.environ import envs
42
43
  from sglang.srt.utils import (
43
44
  direct_register_custom_op,
44
45
  get_bool_env_var,
45
46
  get_int_env_var,
47
+ get_local_ip_auto,
46
48
  is_cpu,
47
49
  is_cuda_alike,
48
50
  is_hip,
49
51
  is_npu,
50
52
  is_shm_available,
53
+ is_xpu,
51
54
  supports_custom_op,
52
55
  )
53
56
 
54
57
  _is_npu = is_npu()
55
58
  _is_cpu = is_cpu()
59
+ _is_xpu = is_xpu()
60
+ _supports_custom_op = supports_custom_op()
56
61
 
57
- IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
62
+
63
+ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
64
+
65
+ # use int value instead of ReduceOp.SUM to support torch compile
66
+ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
58
67
 
59
68
 
60
69
  @dataclass
@@ -62,10 +71,10 @@ class GraphCaptureContext:
62
71
  stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
63
72
 
64
73
 
65
- TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
66
-
67
- # use int value instead of ReduceOp.SUM to support torch compile
68
- REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
74
+ @dataclass
75
+ class P2PWork:
76
+ work: Optional[torch.distributed.Work]
77
+ payload: Optional[torch.Tensor]
69
78
 
70
79
 
71
80
  def _split_tensor_dict(
@@ -117,7 +126,7 @@ def _register_group(group: "GroupCoordinator") -> None:
117
126
  _groups[group.unique_name] = weakref.ref(group)
118
127
 
119
128
 
120
- if supports_custom_op():
129
+ if _supports_custom_op:
121
130
 
122
131
  def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
123
132
  assert group_name in _groups, f"Group {group_name} is not found."
@@ -208,12 +217,14 @@ class GroupCoordinator:
208
217
  use_pynccl: bool # a hint of whether to use PyNccl
209
218
  use_pymscclpp: bool # a hint of whether to use PyMsccl
210
219
  use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
220
+ use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
211
221
  use_message_queue_broadcaster: (
212
222
  bool # a hint of whether to use message queue broadcaster
213
223
  )
214
224
  # communicators are only created for world size > 1
215
225
  pynccl_comm: Optional[Any] # PyNccl communicator
216
226
  ca_comm: Optional[Any] # Custom allreduce communicator
227
+ symm_mem_comm: Optional[Any] # Symm mem communicator
217
228
  mq_broadcaster: Optional[Any] # shared memory broadcaster
218
229
 
219
230
  def __init__(
@@ -224,11 +235,15 @@ class GroupCoordinator:
224
235
  use_pynccl: bool,
225
236
  use_pymscclpp: bool,
226
237
  use_custom_allreduce: bool,
238
+ use_torch_symm_mem: bool,
227
239
  use_hpu_communicator: bool,
228
240
  use_xpu_communicator: bool,
229
241
  use_npu_communicator: bool,
230
242
  use_message_queue_broadcaster: bool = False,
231
243
  group_name: Optional[str] = None,
244
+ pynccl_use_current_stream: bool = False,
245
+ torch_compile: Optional[bool] = None,
246
+ gloo_timeout: timedelta = timedelta(seconds=120 * 60),
232
247
  ):
233
248
  # Set group info
234
249
  group_name = group_name or "anonymous"
@@ -246,9 +261,14 @@ class GroupCoordinator:
246
261
  device_group = torch.distributed.new_group(
247
262
  ranks, backend=torch_distributed_backend
248
263
  )
249
- # a group with `gloo` backend, to allow direct coordination between
250
- # processes through the CPU.
251
- cpu_group = torch.distributed.new_group(ranks, backend="gloo")
264
+ # a cpu_group to allow direct coordination between processes through
265
+ # the CPU. The backend is chosen based on `torch_distributed_backend`
266
+ if "mooncake" in torch_distributed_backend:
267
+ cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
268
+ else:
269
+ cpu_group = torch.distributed.new_group(
270
+ ranks, backend="gloo", timeout=gloo_timeout
271
+ )
252
272
  if self.rank in ranks:
253
273
  self.ranks = ranks
254
274
  self.world_size = len(ranks)
@@ -259,25 +279,29 @@ class GroupCoordinator:
259
279
  assert self.cpu_group is not None
260
280
  assert self.device_group is not None
261
281
 
262
- device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
263
282
  if is_cuda_alike():
283
+ device_id = (
284
+ 0 if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() else local_rank
285
+ )
264
286
  self.device = torch.device(f"cuda:{device_id}")
265
287
  elif _is_npu:
266
- self.device = torch.device(f"npu:{device_id}")
288
+ self.device = torch.device(f"npu:{local_rank}")
267
289
  else:
268
290
  self.device = torch.device("cpu")
269
291
  self.device_module = torch.get_device_module(self.device)
270
292
 
271
293
  # Import communicators
272
294
  self.use_pynccl = use_pynccl
295
+ self.pynccl_use_current_stream = pynccl_use_current_stream
273
296
  self.use_pymscclpp = use_pymscclpp
274
297
  self.use_custom_allreduce = use_custom_allreduce
298
+ self.use_torch_symm_mem = use_torch_symm_mem
275
299
  self.use_hpu_communicator = use_hpu_communicator
276
300
  self.use_xpu_communicator = use_xpu_communicator
277
301
  self.use_npu_communicator = use_npu_communicator
278
302
  self.use_message_queue_broadcaster = use_message_queue_broadcaster
279
303
 
280
- # lazy import to avoid documentation build error
304
+ # Lazy import to avoid documentation build error
281
305
  from sglang.srt.distributed.device_communicators.custom_all_reduce import (
282
306
  CustomAllreduce,
283
307
  )
@@ -287,6 +311,9 @@ class GroupCoordinator:
287
311
  from sglang.srt.distributed.device_communicators.pynccl import (
288
312
  PyNcclCommunicator,
289
313
  )
314
+ from sglang.srt.distributed.device_communicators.symm_mem import (
315
+ SymmMemCommunicator,
316
+ )
290
317
 
291
318
  if is_hip():
292
319
  from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -299,6 +326,7 @@ class GroupCoordinator:
299
326
  self.pynccl_comm = PyNcclCommunicator(
300
327
  group=self.cpu_group,
301
328
  device=self.device,
329
+ use_current_stream=pynccl_use_current_stream,
302
330
  )
303
331
 
304
332
  self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -312,10 +340,17 @@ class GroupCoordinator:
312
340
  self.qr_comm: Optional[QuickAllReduce] = None
313
341
  if use_custom_allreduce and self.world_size > 1:
314
342
  # Initialize a custom fast all-reduce implementation.
343
+ if torch_compile is not None and torch_compile:
344
+ # Piecewise CUDA graph needs a larger custom allreduce max_size to
345
+ # avoid illegal CUDA memory access.
346
+ ca_max_size = 256 * 1024 * 1024
347
+ else:
348
+ ca_max_size = 8 * 1024 * 1024
315
349
  try:
316
350
  self.ca_comm = CustomAllreduce(
317
351
  group=self.cpu_group,
318
352
  device=self.device,
353
+ max_size=ca_max_size,
319
354
  )
320
355
  except Exception as e:
321
356
  logger.warning(
@@ -335,6 +370,13 @@ class GroupCoordinator:
335
370
  except Exception as e:
336
371
  logger.warning(f"Failed to initialize QuickAllReduce: {e}")
337
372
 
373
+ self.symm_mem_comm: Optional[SymmMemCommunicator] = None
374
+ if self.use_torch_symm_mem and self.world_size > 1:
375
+ self.symm_mem_comm = SymmMemCommunicator(
376
+ group=self.cpu_group,
377
+ device=self.device,
378
+ )
379
+
338
380
  # Create communicator for other hardware backends
339
381
  from sglang.srt.distributed.device_communicators.hpu_communicator import (
340
382
  HpuCommunicator,
@@ -412,10 +454,13 @@ class GroupCoordinator:
412
454
 
413
455
  @contextmanager
414
456
  def graph_capture(
415
- self, graph_capture_context: Optional[GraphCaptureContext] = None
457
+ self,
458
+ graph_capture_context: Optional[GraphCaptureContext] = None,
459
+ stream: Optional[torch.cuda.Stream] = None,
416
460
  ):
417
461
  if graph_capture_context is None:
418
- stream = self.device_module.Stream()
462
+ if stream is None:
463
+ stream = self.device_module.Stream()
419
464
  graph_capture_context = GraphCaptureContext(stream)
420
465
  else:
421
466
  stream = graph_capture_context.stream
@@ -439,6 +484,7 @@ class GroupCoordinator:
439
484
  # custom allreduce | enabled | enabled |
440
485
  # PyNccl | disabled| enabled |
441
486
  # PyMscclpp | disabled| enabled |
487
+ # TorchSymmMem | disabled| enabled |
442
488
  # torch.distributed | enabled | disabled|
443
489
  #
444
490
  # Note: When custom quick allreduce is enabled, a runtime check
@@ -497,7 +543,7 @@ class GroupCoordinator:
497
543
  torch.distributed.all_reduce(input_, group=self.device_group)
498
544
  return input_
499
545
 
500
- if not supports_custom_op():
546
+ if not _supports_custom_op:
501
547
  self._all_reduce_in_place(input_)
502
548
  return input_
503
549
 
@@ -523,23 +569,29 @@ class GroupCoordinator:
523
569
 
524
570
  outplace_all_reduce_method = None
525
571
  if (
526
- self.qr_comm is not None
527
- and not self.qr_comm.disabled
528
- and self.qr_comm.should_quick_allreduce(input_)
529
- ):
530
- outplace_all_reduce_method = "qr"
531
- elif (
532
572
  self.ca_comm is not None
533
573
  and not self.ca_comm.disabled
534
574
  and self.ca_comm.should_custom_ar(input_)
535
575
  ):
536
576
  outplace_all_reduce_method = "ca"
577
+ elif (
578
+ self.qr_comm is not None
579
+ and not self.qr_comm.disabled
580
+ and self.qr_comm.should_quick_allreduce(input_)
581
+ ):
582
+ outplace_all_reduce_method = "qr"
537
583
  elif (
538
584
  self.pymscclpp_comm is not None
539
585
  and not self.pymscclpp_comm.disabled
540
586
  and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
541
587
  ):
542
588
  outplace_all_reduce_method = "pymscclpp"
589
+ elif (
590
+ self.symm_mem_comm is not None
591
+ and not self.symm_mem_comm.disabled
592
+ and self.symm_mem_comm.should_symm_mem_allreduce(input_)
593
+ ):
594
+ outplace_all_reduce_method = "symm_mem"
543
595
  if outplace_all_reduce_method is not None:
544
596
  return torch.ops.sglang.outplace_all_reduce(
545
597
  input_,
@@ -553,16 +605,20 @@ class GroupCoordinator:
553
605
  def _all_reduce_out_place(
554
606
  self, input_: torch.Tensor, outplace_all_reduce_method: str
555
607
  ) -> torch.Tensor:
556
- qr_comm = self.qr_comm
557
608
  ca_comm = self.ca_comm
609
+ qr_comm = self.qr_comm
558
610
  pymscclpp_comm = self.pymscclpp_comm
611
+ symm_mem_comm = self.symm_mem_comm
559
612
  assert any([qr_comm, ca_comm, pymscclpp_comm])
560
- if outplace_all_reduce_method == "qr":
561
- assert not qr_comm.disabled
562
- out = qr_comm.quick_all_reduce(input_)
563
- elif outplace_all_reduce_method == "ca":
613
+ if outplace_all_reduce_method == "ca":
564
614
  assert not ca_comm.disabled
565
615
  out = ca_comm.custom_all_reduce(input_)
616
+ elif outplace_all_reduce_method == "qr":
617
+ assert not qr_comm.disabled
618
+ out = qr_comm.quick_all_reduce(input_)
619
+ elif outplace_all_reduce_method == "symm_mem":
620
+ assert not symm_mem_comm.disabled
621
+ out = symm_mem_comm.all_reduce(input_)
566
622
  else:
567
623
  assert not pymscclpp_comm.disabled
568
624
  out = pymscclpp_comm.all_reduce(input_)
@@ -571,8 +627,11 @@ class GroupCoordinator:
571
627
 
572
628
  def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
573
629
  pynccl_comm = self.pynccl_comm
630
+ symm_mem_comm = self.symm_mem_comm
574
631
  if pynccl_comm is not None and not pynccl_comm.disabled:
575
632
  pynccl_comm.all_reduce(input_)
633
+ elif symm_mem_comm is not None and not symm_mem_comm.disabled:
634
+ symm_mem_comm.all_reduce(input_)
576
635
  else:
577
636
  torch.distributed.all_reduce(input_, group=self.device_group)
578
637
 
@@ -637,7 +696,7 @@ class GroupCoordinator:
637
696
  )
638
697
 
639
698
  def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
640
- if _is_npu or not supports_custom_op():
699
+ if _is_npu or _is_xpu or not _supports_custom_op:
641
700
  self._all_gather_into_tensor(output, input)
642
701
  else:
643
702
  torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +756,13 @@ class GroupCoordinator:
697
756
  )
698
757
 
699
758
  # All-gather.
700
- if input_.is_cpu and is_shm_available(
701
- input_.dtype, self.world_size, self.local_size
702
- ):
703
- return torch.ops.sgl_kernel.shm_allgather(input_, dim)
704
-
705
759
  if input_.is_cpu:
706
- torch.distributed.all_gather_into_tensor(
707
- output_tensor, input_, group=self.device_group
708
- )
760
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
761
+ return torch.ops.sgl_kernel.shm_allgather(input_, dim)
762
+ else:
763
+ torch.distributed.all_gather_into_tensor(
764
+ output_tensor, input_, group=self.device_group
765
+ )
709
766
  else:
710
767
  self.all_gather_into_tensor(output_tensor, input_)
711
768
 
@@ -861,45 +918,63 @@ class GroupCoordinator:
861
918
  torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
862
919
  return objs
863
920
 
864
- def send_object(self, obj: Any, dst: int) -> None:
865
- """Send the input object list to the destination rank."""
866
- """NOTE: `dst` is the local rank of the destination rank."""
921
+ def send_object(
922
+ self,
923
+ obj: Any,
924
+ dst: int,
925
+ async_send: bool = False,
926
+ ) -> List[P2PWork]:
927
+ """
928
+ Send the input object list to the destination rank.
929
+ This function uses the CPU group for all communications.
867
930
 
868
- assert dst < self.world_size, f"Invalid dst rank ({dst})"
931
+ TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
932
+ use other functions (e.g., send), or implement a new function (e.g., send_object_device).
869
933
 
934
+ NOTE: `dst` is the local rank of the destination rank.
935
+ """
936
+
937
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
870
938
  assert dst != self.rank_in_group, (
871
939
  "Invalid destination rank. Destination rank is the same "
872
940
  "as the current rank."
873
941
  )
942
+ send_func = torch.distributed.isend if async_send else torch.distributed.send
874
943
 
875
944
  # Serialize object to tensor and get the size as well
876
- object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
877
- device=torch.cuda.current_device()
878
- )
879
-
945
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
880
946
  size_tensor = torch.tensor(
881
- [object_tensor.numel()],
882
- dtype=torch.long,
883
- device="cpu",
947
+ [object_tensor.numel()], dtype=torch.long, device="cpu"
884
948
  )
949
+
885
950
  # Send object size
886
- torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
951
+ p2p_work = []
952
+ size_work = send_func(
953
+ size_tensor,
954
+ self.ranks[dst],
955
+ group=self.cpu_group,
956
+ )
957
+ if async_send:
958
+ p2p_work.append(P2PWork(size_work, size_tensor))
887
959
 
888
- # Send object
889
- torch.distributed.send(
960
+ object_work = send_func(
890
961
  object_tensor,
891
- dst=self.ranks[dst],
892
- group=self.device_group,
962
+ self.ranks[dst],
963
+ group=self.cpu_group,
893
964
  )
965
+ if async_send:
966
+ p2p_work.append(P2PWork(object_work, object_tensor))
894
967
 
895
- return None
968
+ return p2p_work
896
969
 
897
- def recv_object(self, src: int) -> Any:
970
+ def recv_object(
971
+ self,
972
+ src: int,
973
+ ) -> Any:
898
974
  """Receive the input object list from the source rank."""
899
975
  """NOTE: `src` is the local rank of the source rank."""
900
976
 
901
977
  assert src < self.world_size, f"Invalid src rank ({src})"
902
-
903
978
  assert (
904
979
  src != self.rank_in_group
905
980
  ), "Invalid source rank. Source rank is the same as the current rank."
@@ -907,27 +982,25 @@ class GroupCoordinator:
907
982
  size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
908
983
 
909
984
  # Receive object size
910
- rank_size = torch.distributed.recv(
985
+ # We have to use irecv here to make it work for both isend and send.
986
+ work = torch.distributed.irecv(
911
987
  size_tensor, src=self.ranks[src], group=self.cpu_group
912
988
  )
989
+ work.wait()
913
990
 
914
991
  # Tensor to receive serialized objects into.
915
- object_tensor = torch.empty( # type: ignore[call-overload]
992
+ object_tensor: Any = torch.empty( # type: ignore[call-overload]
916
993
  size_tensor.item(), # type: ignore[arg-type]
917
994
  dtype=torch.uint8,
918
- device=torch.cuda.current_device(),
995
+ device="cpu",
919
996
  )
920
997
 
921
- rank_object = torch.distributed.recv(
922
- object_tensor, src=self.ranks[src], group=self.device_group
998
+ work = torch.distributed.irecv(
999
+ object_tensor, src=self.ranks[src], group=self.cpu_group
923
1000
  )
1001
+ work.wait()
924
1002
 
925
- assert (
926
- rank_object == rank_size
927
- ), "Received object sender rank does not match the size sender rank."
928
-
929
- obj = pickle.loads(object_tensor.cpu().numpy())
930
-
1003
+ obj = pickle.loads(object_tensor.numpy())
931
1004
  return obj
932
1005
 
933
1006
  def broadcast_tensor_dict(
@@ -1017,12 +1090,13 @@ class GroupCoordinator:
1017
1090
  tensor_dict: Dict[str, Union[torch.Tensor, Any]],
1018
1091
  dst: Optional[int] = None,
1019
1092
  all_gather_group: Optional["GroupCoordinator"] = None,
1020
- ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
1093
+ async_send: bool = False,
1094
+ ) -> Optional[List[P2PWork]]:
1021
1095
  """Send the input tensor dictionary.
1022
1096
  NOTE: `dst` is the local rank of the destination rank.
1023
1097
  """
1024
1098
  # Bypass the function if we are using only 1 GPU.
1025
- if not torch.distributed.is_initialized() or self.world_size == 1:
1099
+ if self.world_size == 1:
1026
1100
  return tensor_dict
1027
1101
 
1028
1102
  all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1121,10 @@ class GroupCoordinator:
1047
1121
  # 1. Superior D2D transfer bandwidth
1048
1122
  # 2. Ability to overlap send and recv operations
1049
1123
  # Thus the net performance gain justifies this approach.
1050
- self.send_object(metadata_list, dst=dst)
1124
+
1125
+ send_func = torch.distributed.isend if async_send else torch.distributed.send
1126
+ p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
1127
+
1051
1128
  for tensor in tensor_list:
1052
1129
  if tensor.numel() == 0:
1053
1130
  # Skip sending empty tensors.
@@ -1057,15 +1134,11 @@ class GroupCoordinator:
1057
1134
  if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
1058
1135
  tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
1059
1136
 
1060
- if tensor.is_cpu:
1061
- # use metadata_group for CPU tensors
1062
- torch.distributed.send(
1063
- tensor, dst=self.ranks[dst], group=metadata_group
1064
- )
1065
- else:
1066
- # use group for GPU tensors
1067
- torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
1068
- return None
1137
+ comm_group = metadata_group if tensor.is_cpu else group
1138
+ work = send_func(tensor, self.ranks[dst], group=comm_group)
1139
+ if async_send:
1140
+ p2p_works.append(P2PWork(work, tensor))
1141
+ return p2p_works
1069
1142
 
1070
1143
  def recv_tensor_dict(
1071
1144
  self,
@@ -1111,17 +1184,15 @@ class GroupCoordinator:
1111
1184
  orig_shape = tensor.shape
1112
1185
  tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
1113
1186
 
1114
- if tensor.is_cpu:
1115
- # use metadata_group for CPU tensors
1116
- torch.distributed.recv(
1117
- tensor, src=self.ranks[src], group=metadata_group
1118
- )
1119
- else:
1120
- # use group for GPU tensors
1121
- torch.distributed.recv(tensor, src=self.ranks[src], group=group)
1187
+ # We have to use irecv here to make it work for both isend and send.
1188
+ comm_group = metadata_group if tensor.is_cpu else group
1189
+ work = torch.distributed.irecv(
1190
+ tensor, src=self.ranks[src], group=comm_group
1191
+ )
1192
+ work.wait()
1193
+
1122
1194
  if use_all_gather:
1123
- # do the allgather
1124
- tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
1195
+ tensor = all_gather_group.all_gather(tensor, dim=0)
1125
1196
  tensor = tensor.reshape(orig_shape)
1126
1197
 
1127
1198
  tensor_dict[key] = tensor
@@ -1199,6 +1270,7 @@ def init_world_group(
1199
1270
  use_pynccl=False,
1200
1271
  use_pymscclpp=False,
1201
1272
  use_custom_allreduce=False,
1273
+ use_torch_symm_mem=False,
1202
1274
  use_hpu_communicator=False,
1203
1275
  use_xpu_communicator=False,
1204
1276
  use_npu_communicator=False,
@@ -1214,23 +1286,31 @@ def init_model_parallel_group(
1214
1286
  use_message_queue_broadcaster: bool = False,
1215
1287
  group_name: Optional[str] = None,
1216
1288
  use_mscclpp_allreduce: Optional[bool] = None,
1289
+ pynccl_use_current_stream: bool = True,
1290
+ use_symm_mem_allreduce: Optional[bool] = None,
1291
+ torch_compile: Optional[bool] = None,
1217
1292
  ) -> GroupCoordinator:
1218
1293
  if use_custom_allreduce is None:
1219
1294
  use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
1220
1295
  if use_mscclpp_allreduce is None:
1221
1296
  use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
1297
+ if use_symm_mem_allreduce is None:
1298
+ use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
1222
1299
  return GroupCoordinator(
1223
1300
  group_ranks=group_ranks,
1224
1301
  local_rank=local_rank,
1225
1302
  torch_distributed_backend=backend,
1226
- use_pynccl=not _is_npu,
1303
+ use_pynccl=not (_is_npu or _is_xpu),
1227
1304
  use_pymscclpp=use_mscclpp_allreduce,
1228
1305
  use_custom_allreduce=use_custom_allreduce,
1306
+ use_torch_symm_mem=use_symm_mem_allreduce,
1229
1307
  use_hpu_communicator=True,
1230
1308
  use_xpu_communicator=True,
1231
1309
  use_npu_communicator=True,
1232
1310
  use_message_queue_broadcaster=use_message_queue_broadcaster,
1233
1311
  group_name=group_name,
1312
+ pynccl_use_current_stream=pynccl_use_current_stream,
1313
+ torch_compile=torch_compile,
1234
1314
  )
1235
1315
 
1236
1316
 
@@ -1287,7 +1367,7 @@ get_pipeline_model_parallel_group = get_pp_group
1287
1367
 
1288
1368
 
1289
1369
  @contextmanager
1290
- def graph_capture():
1370
+ def graph_capture(stream: Optional[torch.cuda.Stream] = None):
1291
1371
  """
1292
1372
  `graph_capture` is a context manager which should surround the code that
1293
1373
  is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -1301,9 +1381,9 @@ def graph_capture():
1301
1381
  in order to explicitly distinguish the kernels to capture
1302
1382
  from other kernels possibly launched on background in the default stream.
1303
1383
  """
1304
- with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
1305
- context
1306
- ):
1384
+ with get_tp_group().graph_capture(
1385
+ stream=stream
1386
+ ) as context, get_pp_group().graph_capture(context):
1307
1387
  yield context
1308
1388
 
1309
1389
 
@@ -1311,6 +1391,7 @@ logger = logging.getLogger(__name__)
1311
1391
 
1312
1392
  _ENABLE_CUSTOM_ALL_REDUCE = True
1313
1393
  _ENABLE_MSCCLPP_ALL_REDUCE = False
1394
+ _ENABLE_SYMM_MEM_ALL_REDUCE = False
1314
1395
 
1315
1396
 
1316
1397
  def set_custom_all_reduce(enable: bool):
@@ -1323,6 +1404,11 @@ def set_mscclpp_all_reduce(enable: bool):
1323
1404
  _ENABLE_MSCCLPP_ALL_REDUCE = enable
1324
1405
 
1325
1406
 
1407
+ def set_symm_mem_all_reduce(enable: bool):
1408
+ global _ENABLE_SYMM_MEM_ALL_REDUCE
1409
+ _ENABLE_SYMM_MEM_ALL_REDUCE = enable
1410
+
1411
+
1326
1412
  def init_distributed_environment(
1327
1413
  world_size: int = -1,
1328
1414
  rank: int = -1,
@@ -1339,6 +1425,17 @@ def init_distributed_environment(
1339
1425
  distributed_init_method,
1340
1426
  backend,
1341
1427
  )
1428
+ if "mooncake" in backend:
1429
+ try:
1430
+ from mooncake import ep as mooncake_ep
1431
+ except ImportError as e:
1432
+ raise ImportError(
1433
+ "Please install mooncake by following the instructions at "
1434
+ "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
1435
+ "to run SGLang with Mooncake Backend."
1436
+ ) from e
1437
+ mooncake_ep.set_host_ip(get_local_ip_auto())
1438
+
1342
1439
  if not torch.distributed.is_initialized():
1343
1440
  assert distributed_init_method is not None, (
1344
1441
  "distributed_init_method must be provided when initializing "
@@ -1384,6 +1481,7 @@ def initialize_model_parallel(
1384
1481
  pipeline_model_parallel_size: int = 1,
1385
1482
  backend: Optional[str] = None,
1386
1483
  duplicate_tp_group: bool = False,
1484
+ torch_compile: Optional[bool] = None,
1387
1485
  ) -> None:
1388
1486
  """
1389
1487
  Initialize model parallel groups.
@@ -1439,6 +1537,8 @@ def initialize_model_parallel(
1439
1537
  "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
1440
1538
  ),
1441
1539
  group_name="tp",
1540
+ pynccl_use_current_stream=duplicate_tp_group,
1541
+ torch_compile=torch_compile,
1442
1542
  )
1443
1543
 
1444
1544
  if duplicate_tp_group:
@@ -1454,16 +1554,18 @@ def initialize_model_parallel(
1454
1554
  "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
1455
1555
  ),
1456
1556
  group_name="pdmux_prefill_tp",
1557
+ pynccl_use_current_stream=True,
1558
+ torch_compile=torch_compile,
1457
1559
  )
1458
- _TP.pynccl_comm.disabled = False
1459
- _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
1560
+ if _TP.pynccl_comm:
1561
+ _TP.pynccl_comm.disabled = False
1562
+ _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
1460
1563
 
1461
1564
  moe_ep_size = expert_model_parallel_size
1462
1565
  moe_tp_size = tensor_model_parallel_size // moe_ep_size
1463
1566
 
1464
1567
  global _MOE_EP
1465
1568
  assert _MOE_EP is None, "expert model parallel group is already initialized"
1466
-
1467
1569
  if moe_ep_size == tensor_model_parallel_size:
1468
1570
  _MOE_EP = _TP
1469
1571
  else:
@@ -1484,7 +1586,6 @@ def initialize_model_parallel(
1484
1586
 
1485
1587
  global _MOE_TP
1486
1588
  assert _MOE_TP is None, "expert model parallel group is already initialized"
1487
-
1488
1589
  if moe_tp_size == tensor_model_parallel_size:
1489
1590
  _MOE_TP = _TP
1490
1591
  else:
@@ -1649,6 +1750,11 @@ def destroy_model_parallel():
1649
1750
  _PP.destroy()
1650
1751
  _PP = None
1651
1752
 
1753
+ global _PDMUX_PREFILL_TP_GROUP
1754
+ if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr]
1755
+ _PDMUX_PREFILL_TP_GROUP.destroy()
1756
+ _PDMUX_PREFILL_TP_GROUP = None
1757
+
1652
1758
 
1653
1759
  def destroy_distributed_environment():
1654
1760
  global _WORLD