sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py

@@ -4,7 +4,7 @@
  # Adapted from
  # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
  # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
- """vLLM distributed state.
+ """Distributed state.
  It takes over the control of the distributed environment from PyTorch.
  The typical workflow is:

@@ -53,16 +53,26 @@ from sglang.srt.utils import (

  _is_npu = is_npu()
  _is_cpu = is_cpu()
+ _supports_custom_op = supports_custom_op()

  IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")


+ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+ # use int value instead of ReduceOp.SUM to support torch compile
+ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
  @dataclass
  class GraphCaptureContext:
  stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream


- TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+ @dataclass
+ class P2PWork:
+ work: Optional[torch.distributed.Work]
+ payload: Optional[torch.Tensor]


  def _split_tensor_dict(
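The hunk above pairs each asynchronous point-to-point send with the tensor it transmits: the new `P2PWork` dataclass keeps the payload referenced until the matching `torch.distributed` work handle has been waited on, so the buffer cannot be reclaimed while an `isend` is still in flight. Below is a minimal sketch of that pattern, assuming a process group has already been initialized (for example via `torchrun` with the `gloo` backend); the `send_async` and `wait_all` helper names are illustrative, not part of sglang.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist


@dataclass
class P2PWork:
    # Async work handle returned by dist.isend / dist.irecv.
    work: Optional[dist.Work]
    # Keep the payload alive until the work handle is waited on, otherwise
    # the buffer could be released while the transfer is still in flight.
    payload: Optional[torch.Tensor]


def send_async(tensor: torch.Tensor, dst: int) -> P2PWork:
    # isend returns immediately; the transfer completes in the background.
    work = dist.isend(tensor, dst)
    return P2PWork(work=work, payload=tensor)


def wait_all(works: List[P2PWork]) -> None:
    # Block until every outstanding send has finished, then drop the payloads.
    for w in works:
        if w.work is not None:
            w.work.wait()
        w.payload = None
```

In the diff, `send_object` and `send_tensor_dict` return a list of such `P2PWork` entries when called with `async_send=True`, so the caller can wait on them later.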
@@ -114,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
  _groups[group.unique_name] = weakref.ref(group)


- if supports_custom_op():
+ if _supports_custom_op:

  def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
  assert group_name in _groups, f"Group {group_name} is not found."

@@ -205,12 +215,14 @@ class GroupCoordinator:
  use_pynccl: bool # a hint of whether to use PyNccl
  use_pymscclpp: bool # a hint of whether to use PyMsccl
  use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
+ use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
  use_message_queue_broadcaster: (
  bool # a hint of whether to use message queue broadcaster
  )
  # communicators are only created for world size > 1
  pynccl_comm: Optional[Any] # PyNccl communicator
  ca_comm: Optional[Any] # Custom allreduce communicator
+ symm_mem_comm: Optional[Any] # Symm mem communicator
  mq_broadcaster: Optional[Any] # shared memory broadcaster

  def __init__(

@@ -221,6 +233,7 @@ class GroupCoordinator:
  use_pynccl: bool,
  use_pymscclpp: bool,
  use_custom_allreduce: bool,
+ use_torch_symm_mem: bool,
  use_hpu_communicator: bool,
  use_xpu_communicator: bool,
  use_npu_communicator: bool,

@@ -269,12 +282,13 @@ class GroupCoordinator:
  self.use_pynccl = use_pynccl
  self.use_pymscclpp = use_pymscclpp
  self.use_custom_allreduce = use_custom_allreduce
+ self.use_torch_symm_mem = use_torch_symm_mem
  self.use_hpu_communicator = use_hpu_communicator
  self.use_xpu_communicator = use_xpu_communicator
  self.use_npu_communicator = use_npu_communicator
  self.use_message_queue_broadcaster = use_message_queue_broadcaster

- # lazy import to avoid documentation build error
+ # Lazy import to avoid documentation build error
  from sglang.srt.distributed.device_communicators.custom_all_reduce import (
  CustomAllreduce,
  )

@@ -284,6 +298,9 @@ class GroupCoordinator:
  from sglang.srt.distributed.device_communicators.pynccl import (
  PyNcclCommunicator,
  )
+ from sglang.srt.distributed.device_communicators.symm_mem import (
+ SymmMemCommunicator,
+ )

  if is_hip():
  from sglang.srt.distributed.device_communicators.quick_all_reduce import (

@@ -332,6 +349,13 @@ class GroupCoordinator:
  except Exception as e:
  logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+ self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+ if self.use_torch_symm_mem and self.world_size > 1:
+ self.symm_mem_comm = SymmMemCommunicator(
+ group=self.cpu_group,
+ device=self.device,
+ )
+
  # Create communicator for other hardware backends
  from sglang.srt.distributed.device_communicators.hpu_communicator import (
  HpuCommunicator,

@@ -436,6 +460,7 @@ class GroupCoordinator:
  # custom allreduce | enabled | enabled |
  # PyNccl | disabled| enabled |
  # PyMscclpp | disabled| enabled |
+ # TorchSymmMem | disabled| enabled |
  # torch.distributed | enabled | disabled|
  #
  # Note: When custom quick allreduce is enabled, a runtime check

@@ -489,14 +514,12 @@ class GroupCoordinator:

  if input_.is_cpu:
  if is_shm_available(input_.dtype, self.world_size, self.local_size):
- torch.ops.sgl_kernel.shm_allreduce(
- input_, torch.distributed.ReduceOp.SUM
- )
+ torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
  else:
  torch.distributed.all_reduce(input_, group=self.device_group)
  return input_

- if not supports_custom_op():
+ if not _supports_custom_op:
  self._all_reduce_in_place(input_)
  return input_

@@ -522,23 +545,29 @@ class GroupCoordinator:

  outplace_all_reduce_method = None
  if (
- self.qr_comm is not None
- and not self.qr_comm.disabled
- and self.qr_comm.should_quick_allreduce(input_)
- ):
- outplace_all_reduce_method = "qr"
- elif (
  self.ca_comm is not None
  and not self.ca_comm.disabled
  and self.ca_comm.should_custom_ar(input_)
  ):
  outplace_all_reduce_method = "ca"
+ elif (
+ self.qr_comm is not None
+ and not self.qr_comm.disabled
+ and self.qr_comm.should_quick_allreduce(input_)
+ ):
+ outplace_all_reduce_method = "qr"
  elif (
  self.pymscclpp_comm is not None
  and not self.pymscclpp_comm.disabled
  and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
  ):
  outplace_all_reduce_method = "pymscclpp"
+ elif (
+ self.symm_mem_comm is not None
+ and not self.symm_mem_comm.disabled
+ and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+ ):
+ outplace_all_reduce_method = "symm_mem"
  if outplace_all_reduce_method is not None:
  return torch.ops.sglang.outplace_all_reduce(
  input_,
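The reordered branch chain above is effectively a priority list: the first communicator that exists, is not disabled, and accepts the tensor handles the out-of-place all-reduce, with custom allreduce now checked before quick allreduce and the new torch symmetric-memory backend tried last; if none match, the in-place path is used. A simplified sketch of that dispatch logic, with hypothetical communicator objects standing in for `ca_comm`, `qr_comm`, `pymscclpp_comm`, and `symm_mem_comm`:

```python
from typing import Callable, Optional, Sequence, Tuple

import torch


def pick_all_reduce_method(
    input_: torch.Tensor,
    candidates: Sequence[Tuple[str, Optional[object], Callable[[torch.Tensor], bool]]],
) -> Optional[str]:
    """Return the first usable out-of-place all-reduce backend, by priority.

    `candidates` is ordered like the chain in the hunk ("ca", "qr",
    "pymscclpp", "symm_mem"); each predicate decides whether that backend
    wants this tensor (size, dtype, contiguity, ...).
    """
    for name, comm, should_use in candidates:
        # A backend is usable only if it was constructed, not disabled,
        # and its predicate accepts this particular tensor.
        if comm is not None and not getattr(comm, "disabled", True) and should_use(input_):
            return name
    # None means: fall back to the in-place torch.distributed / PyNccl path.
    return None
```

The only behavioral changes in the hunk itself are the ordering and the added symmetric-memory branch; the per-backend predicates are unchanged.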
@@ -552,16 +581,20 @@ class GroupCoordinator:
  def _all_reduce_out_place(
  self, input_: torch.Tensor, outplace_all_reduce_method: str
  ) -> torch.Tensor:
- qr_comm = self.qr_comm
  ca_comm = self.ca_comm
+ qr_comm = self.qr_comm
  pymscclpp_comm = self.pymscclpp_comm
+ symm_mem_comm = self.symm_mem_comm
  assert any([qr_comm, ca_comm, pymscclpp_comm])
- if outplace_all_reduce_method == "qr":
- assert not qr_comm.disabled
- out = qr_comm.quick_all_reduce(input_)
- elif outplace_all_reduce_method == "ca":
+ if outplace_all_reduce_method == "ca":
  assert not ca_comm.disabled
  out = ca_comm.custom_all_reduce(input_)
+ elif outplace_all_reduce_method == "qr":
+ assert not qr_comm.disabled
+ out = qr_comm.quick_all_reduce(input_)
+ elif outplace_all_reduce_method == "symm_mem":
+ assert not symm_mem_comm.disabled
+ out = symm_mem_comm.all_reduce(input_)
  else:
  assert not pymscclpp_comm.disabled
  out = pymscclpp_comm.all_reduce(input_)

@@ -636,7 +669,7 @@ class GroupCoordinator:
  )

  def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
- if _is_npu or not supports_custom_op():
+ if _is_npu or not _supports_custom_op:
  self._all_gather_into_tensor(output, input)
  else:
  torch.ops.sglang.reg_all_gather_into_tensor(

@@ -696,15 +729,13 @@ class GroupCoordinator:
  )

  # All-gather.
- if input_.is_cpu and is_shm_available(
- input_.dtype, self.world_size, self.local_size
- ):
- return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
  if input_.is_cpu:
- torch.distributed.all_gather_into_tensor(
- output_tensor, input_, group=self.device_group
- )
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
+ return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+ else:
+ torch.distributed.all_gather_into_tensor(
+ output_tensor, input_, group=self.device_group
+ )
  else:
  self.all_gather_into_tensor(output_tensor, input_)

@@ -860,76 +891,89 @@ class GroupCoordinator:
  torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
  return objs

- def send_object(self, obj: Any, dst: int) -> None:
- """Send the input object list to the destination rank."""
- """NOTE: `dst` is the local rank of the destination rank."""
+ def send_object(
+ self,
+ obj: Any,
+ dst: int,
+ async_send: bool = False,
+ ) -> List[P2PWork]:
+ """
+ Send the input object list to the destination rank.
+ This function uses the CPU group for all communications.

- assert dst < self.world_size, f"Invalid dst rank ({dst})"
+ TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+ use other functions (e.g., send), or implement a new function (e.g., send_object_device).

+ NOTE: `dst` is the local rank of the destination rank.
+ """
+
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
  assert dst != self.rank_in_group, (
  "Invalid destination rank. Destination rank is the same "
  "as the current rank."
  )
+ send_func = torch.distributed.isend if async_send else torch.distributed.send

  # Serialize object to tensor and get the size as well
- object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
- device=torch.cuda.current_device()
- )
-
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
  size_tensor = torch.tensor(
- [object_tensor.numel()],
- dtype=torch.long,
- device=torch.cuda.current_device(),
+ [object_tensor.numel()], dtype=torch.long, device="cpu"
  )

  # Send object size
- torch.distributed.send(
- size_tensor, dst=self.ranks[dst], group=self.device_group
+ p2p_work = []
+ size_work = send_func(
+ size_tensor,
+ self.ranks[dst],
+ group=self.cpu_group,
  )
+ if async_send:
+ p2p_work.append(P2PWork(size_work, size_tensor))

- # Send object
- torch.distributed.send(
- object_tensor, dst=self.ranks[dst], group=self.device_group
+ object_work = send_func(
+ object_tensor,
+ self.ranks[dst],
+ group=self.cpu_group,
  )
+ if async_send:
+ p2p_work.append(P2PWork(object_work, object_tensor))

- return None
+ return p2p_work

- def recv_object(self, src: int) -> Any:
+ def recv_object(
+ self,
+ src: int,
+ ) -> Any:
  """Receive the input object list from the source rank."""
  """NOTE: `src` is the local rank of the source rank."""

  assert src < self.world_size, f"Invalid src rank ({src})"
-
  assert (
  src != self.rank_in_group
  ), "Invalid source rank. Source rank is the same as the current rank."

- size_tensor = torch.empty(
- 1, dtype=torch.long, device=torch.cuda.current_device()
- )
+ size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

  # Receive object size
- rank_size = torch.distributed.recv(
- size_tensor, src=self.ranks[src], group=self.device_group
+ # We have to use irecv here to make it work for both isend and send.
+ work = torch.distributed.irecv(
+ size_tensor, src=self.ranks[src], group=self.cpu_group
  )
+ work.wait()

  # Tensor to receive serialized objects into.
- object_tensor = torch.empty( # type: ignore[call-overload]
+ object_tensor: Any = torch.empty( # type: ignore[call-overload]
  size_tensor.item(), # type: ignore[arg-type]
  dtype=torch.uint8,
- device=torch.cuda.current_device(),
+ device="cpu",
  )

- rank_object = torch.distributed.recv(
- object_tensor, src=self.ranks[src], group=self.device_group
+ work = torch.distributed.irecv(
+ object_tensor, src=self.ranks[src], group=self.cpu_group
  )
+ work.wait()

- assert (
- rank_object == rank_size
- ), "Received object sender rank does not match the size sender rank."
-
- obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
-
+ obj = pickle.loads(object_tensor.numpy())
  return obj

  def broadcast_tensor_dict(
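The rewritten `send_object`/`recv_object` above keep the pickled payload on the CPU and move it over the CPU group in two steps: a one-element length tensor followed by the serialized bytes, with `irecv(...).wait()` on the receiving side so the receiver works whether the sender used `send` or `isend`. A stripped-down sketch of that protocol, assuming a `gloo` process group is already initialized and omitting the async bookkeeping; the helper names are illustrative, not sglang APIs.

```python
import pickle
from typing import Any

import torch
import torch.distributed as dist


def send_object_cpu(obj: Any, dst: int, group: dist.ProcessGroup) -> None:
    # Serialize to a flat uint8 tensor that stays on the CPU.
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
    size = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
    # First tell the peer how many bytes to expect, then send the bytes.
    dist.send(size, dst=dst, group=group)
    dist.send(payload, dst=dst, group=group)


def recv_object_cpu(src: int, group: dist.ProcessGroup) -> Any:
    size = torch.empty(1, dtype=torch.long, device="cpu")
    # irecv + wait pairs with either send or isend on the other side.
    dist.irecv(size, src=src, group=group).wait()
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device="cpu")
    dist.irecv(payload, src=src, group=group).wait()
    return pickle.loads(payload.numpy())
```

Keeping the exchange on the CPU group also removes the `.cuda(...)` staging of the pickled bytes that the old code performed.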
@@ -1019,12 +1063,13 @@ class GroupCoordinator:
  tensor_dict: Dict[str, Union[torch.Tensor, Any]],
  dst: Optional[int] = None,
  all_gather_group: Optional["GroupCoordinator"] = None,
- ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+ async_send: bool = False,
+ ) -> Optional[List[P2PWork]]:
  """Send the input tensor dictionary.
  NOTE: `dst` is the local rank of the source rank.
  """
  # Bypass the function if we are using only 1 GPU.
- if not torch.distributed.is_initialized() or self.world_size == 1:
+ if self.world_size == 1:
  return tensor_dict

  all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size

@@ -1049,7 +1094,10 @@ class GroupCoordinator:
  # 1. Superior D2D transfer bandwidth
  # 2. Ability to overlap send and recv operations
  # Thus the net performance gain justifies this approach.
- self.send_object(metadata_list, dst=dst)
+
+ send_func = torch.distributed.isend if async_send else torch.distributed.send
+ p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
  for tensor in tensor_list:
  if tensor.numel() == 0:
  # Skip sending empty tensors.

@@ -1059,15 +1107,11 @@ class GroupCoordinator:
  if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
  tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

- if tensor.is_cpu:
- # use metadata_group for CPU tensors
- torch.distributed.send(
- tensor, dst=self.ranks[dst], group=metadata_group
- )
- else:
- # use group for GPU tensors
- torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
- return None
+ comm_group = metadata_group if tensor.is_cpu else group
+ work = send_func(tensor, self.ranks[dst], group=comm_group)
+ if async_send:
+ p2p_works.append(P2PWork(work, tensor))
+ return p2p_works

  def recv_tensor_dict(
  self,

@@ -1113,17 +1157,15 @@ class GroupCoordinator:
  orig_shape = tensor.shape
  tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

- if tensor.is_cpu:
- # use metadata_group for CPU tensors
- torch.distributed.recv(
- tensor, src=self.ranks[src], group=metadata_group
- )
- else:
- # use group for GPU tensors
- torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+ # We have to use irecv here to make it work for both isend and send.
+ comm_group = metadata_group if tensor.is_cpu else group
+ work = torch.distributed.irecv(
+ tensor, src=self.ranks[src], group=comm_group
+ )
+ work.wait()
+
  if use_all_gather:
- # do the allgather
- tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
+ tensor = all_gather_group.all_gather(tensor, dim=0)
  tensor = tensor.reshape(orig_shape)

  tensor_dict[key] = tensor

@@ -1201,6 +1243,7 @@ def init_world_group(
  use_pynccl=False,
  use_pymscclpp=False,
  use_custom_allreduce=False,
+ use_torch_symm_mem=False,
  use_hpu_communicator=False,
  use_xpu_communicator=False,
  use_npu_communicator=False,

@@ -1216,11 +1259,14 @@ def init_model_parallel_group(
  use_message_queue_broadcaster: bool = False,
  group_name: Optional[str] = None,
  use_mscclpp_allreduce: Optional[bool] = None,
+ use_symm_mem_allreduce: Optional[bool] = None,
  ) -> GroupCoordinator:
  if use_custom_allreduce is None:
  use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
  if use_mscclpp_allreduce is None:
  use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
+ if use_symm_mem_allreduce is None:
+ use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
  return GroupCoordinator(
  group_ranks=group_ranks,
  local_rank=local_rank,

@@ -1228,6 +1274,7 @@ def init_model_parallel_group(
  use_pynccl=not _is_npu,
  use_pymscclpp=use_mscclpp_allreduce,
  use_custom_allreduce=use_custom_allreduce,
+ use_torch_symm_mem=use_symm_mem_allreduce,
  use_hpu_communicator=True,
  use_xpu_communicator=True,
  use_npu_communicator=True,

@@ -1313,6 +1360,7 @@ logger = logging.getLogger(__name__)

  _ENABLE_CUSTOM_ALL_REDUCE = True
  _ENABLE_MSCCLPP_ALL_REDUCE = False
+ _ENABLE_SYMM_MEM_ALL_REDUCE = False


  def set_custom_all_reduce(enable: bool):

@@ -1325,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
  _ENABLE_MSCCLPP_ALL_REDUCE = enable


+ def set_symm_mem_all_reduce(enable: bool):
+ global _ENABLE_SYMM_MEM_ALL_REDUCE
+ _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+
+
  def init_distributed_environment(
  world_size: int = -1,
  rank: int = -1,

@@ -1461,43 +1514,49 @@ def initialize_model_parallel(
  _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

  moe_ep_size = expert_model_parallel_size
-
  moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
  global _MOE_EP
  assert _MOE_EP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_tp_size):
- st = i * tensor_model_parallel_size + j
- en = (i + 1) * tensor_model_parallel_size + j
- ranks = list(range(st, en, moe_tp_size))
- group_ranks.append(ranks)

- _MOE_EP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_ep",
- )
+ if moe_ep_size == tensor_model_parallel_size:
+ _MOE_EP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_tp_size):
+ st = i * tensor_model_parallel_size + j
+ en = (i + 1) * tensor_model_parallel_size + j
+ ranks = list(range(st, en, moe_tp_size))
+ group_ranks.append(ranks)
+ _MOE_EP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_ep",
+ )

  global _MOE_TP
  assert _MOE_TP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_ep_size):
- st = i * tensor_model_parallel_size + j * moe_tp_size
- en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
- ranks = list(range(st, en))
- group_ranks.append(ranks)

- _MOE_TP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_tp",
- )
+ if moe_tp_size == tensor_model_parallel_size:
+ _MOE_TP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_ep_size):
+ st = i * tensor_model_parallel_size + j * moe_tp_size
+ en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+ ranks = list(range(st, en))
+ group_ranks.append(ranks)
+ _MOE_TP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_tp",
+ )

  # Build the pipeline model-parallel groups.
  num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
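The restructured block above only builds dedicated MoE groups when they differ from the tensor-parallel group; when `moe_ep_size` or `moe_tp_size` equals `tensor_model_parallel_size`, `_MOE_EP` or `_MOE_TP` simply aliases `_TP`. The rank arithmetic is unchanged: within each TP group, expert-parallel groups take every `moe_tp_size`-th rank, and MoE tensor-parallel groups take contiguous slices of `moe_tp_size` ranks. A small standalone sketch of that grouping (the function name is illustrative):

```python
from typing import List, Tuple


def moe_group_ranks(
    world_size: int, tp_size: int, moe_ep_size: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Compute MoE expert-parallel and MoE tensor-parallel rank groups."""
    assert tp_size % moe_ep_size == 0
    moe_tp_size = tp_size // moe_ep_size
    num_tp_groups = world_size // tp_size

    ep_groups: List[List[int]] = []
    tp_groups: List[List[int]] = []
    for i in range(num_tp_groups):
        # EP groups: stride by moe_tp_size inside one TP group.
        for j in range(moe_tp_size):
            st = i * tp_size + j
            en = (i + 1) * tp_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))
        # MoE-TP groups: contiguous chunks of moe_tp_size ranks.
        for j in range(moe_ep_size):
            st = i * tp_size + j * moe_tp_size
            en = i * tp_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))
    return ep_groups, tp_groups


# With 8 GPUs, tp=8, ep=4: EP groups are [0, 2, 4, 6] and [1, 3, 5, 7];
# MoE-TP groups are [0, 1], [2, 3], [4, 5], [6, 7].
print(moe_group_ranks(world_size=8, tp_size=8, moe_ep_size=4))
```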
@@ -1583,6 +1642,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
  _TP = old_tp_group


+ def get_world_size():
+ """Return world size for the world group."""
+ return get_world_group().world_size
+
+
+ def get_world_rank():
+ """Return my rank for the world group."""
+ return get_world_group().rank_in_group
+
+
  def get_tensor_model_parallel_world_size():
  """Return world size for the tensor model parallel group."""
  return get_tp_group().world_size

@@ -1593,6 +1662,16 @@ def get_tensor_model_parallel_rank():
  return get_tp_group().rank_in_group


+ def get_pipeline_model_parallel_world_size():
+ """Return world size for the pipeline model parallel group."""
+ return get_pp_group().world_size
+
+
+ def get_pipeline_model_parallel_rank():
+ """Return my rank for the pipeline model parallel group."""
+ return get_pp_group().rank_in_group
+
+
  def get_moe_expert_parallel_world_size():
  """Return world size for the moe expert parallel group."""
  return get_moe_ep_group().world_size