sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py

@@ -21,9 +21,10 @@ from __future__ import annotations
 
 import logging
 import threading
+import time
 from collections import deque
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
 
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    FINISH_LENGTH,
+    Req,
+    RequestStage,
+    ScheduleBatch,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -140,8 +146,10 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
 
-        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.PREFILL,
             self.scheduler.server_args,
@@ -168,6 +176,7 @@ class PrefillBootstrapQueue:
             pp_rank=self.pp_rank,
         )
         self._process_req(req)
+        req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -254,8 +263,11 @@ class PrefillBootstrapQueue:
 
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
+
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -397,11 +409,11 @@ class SchedulerDisaggregationPrefillMixin:
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
-            req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
+                req.add_latency(RequestStage.PREFILL_FORWARD)
                 self.disagg_prefill_inflight_queue.append(req)
                 if (
                     logits_output is not None
@@ -410,9 +422,16 @@ class SchedulerDisaggregationPrefillMixin:
                     last_hidden_index = (
                         hidden_state_offset + extend_input_len_per_req[i] - 1
                     )
-                    req.hidden_states_tensor = (
-                        logits_output.hidden_states[last_hidden_index].cpu().clone()
-                    )
+                    req.output_topk_p = batch.spec_info.topk_p[i]
+                    req.output_topk_index = batch.spec_info.topk_index[i]
+                    if self.spec_algorithm.is_eagle3():
+                        req.hidden_states_tensor = (
+                            batch.spec_info.hidden_states[i].cpu().clone()
+                        )
+                    else:
+                        req.hidden_states_tensor = (
+                            logits_output.hidden_states[last_hidden_index].cpu().clone()
+                        )
                     hidden_state_offset += extend_input_len_per_req[i]
                 else:
                     req.hidden_states_tensor = None
@@ -432,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
                 )
                 logprob_pt += num_input_logprobs
             self.send_kv_chunk(req, last_chunk=True)
+            req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
 
             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -529,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
             else:
                 assert False, f"Unexpected polling state {poll=}"
 
+        for req in done_reqs:
+            req.time_stats.completion_time = time.perf_counter()
+
         # Stream requests which have finished transfer
         self.stream_output(
             done_reqs,
@@ -537,6 +560,7 @@
         )
         for req in done_reqs:
             req: Req
+            req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
             self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
             req.metadata_buffer_index = -1
 
@@ -665,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
 
         # Either success or failed
@@ -737,10 +760,7 @@
             # send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     pp_outputs = PPProxyTensors(
                         {
                             "next_token_ids": next_token_ids,
@@ -777,7 +797,6 @@
                     next_token_ids=next_pp_outputs["next_token_ids"],
                     extend_input_len_per_req=None,
                     extend_logprob_start_len_per_req=None,
-                    bid=bids[next_mb_id],
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
                 self.process_batch_result_disagg_prefill(
@@ -794,8 +813,6 @@
 
             # carry the outputs to the next stage
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 if pp_outputs:
                     # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
@@ -814,8 +831,10 @@
 
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
 
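Note: the prefill hunks above thread per-stage timing through each request: time.perf_counter() timestamps on req.time_stats plus req.add_latency(RequestStage.PREFILL_PREPARE / PREFILL_BOOTSTRAP / PREFILL_FORWARD / PREFILL_TRANSFER_KV_CACHE) at each stage boundary. Below is a minimal standalone sketch of that accounting pattern; the Req and RequestStage shapes here are illustrative assumptions, not the actual sglang classes.

# Illustrative sketch of stage-latency bookkeeping in the style of
# req.add_latency(RequestStage.X) used above. Hypothetical classes only;
# the real sglang Req / RequestStage implementations differ.
import time
from enum import Enum, auto


class RequestStage(Enum):
    PREFILL_PREPARE = auto()
    PREFILL_BOOTSTRAP = auto()
    PREFILL_FORWARD = auto()
    PREFILL_TRANSFER_KV_CACHE = auto()


class Req:
    def __init__(self) -> None:
        self._last_tick = time.perf_counter()  # last stage boundary
        self.stage_latency: dict = {}          # RequestStage -> seconds

    def add_latency(self, stage: RequestStage) -> None:
        # Attribute the time elapsed since the previous boundary to `stage`.
        now = time.perf_counter()
        elapsed = now - self._last_tick
        self.stage_latency[stage] = self.stage_latency.get(stage, 0.0) + elapsed
        self._last_tick = now


req = Req()
# ... bootstrap work happens here ...
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
# ... prefill forward pass happens here ...
req.add_latency(RequestStage.PREFILL_FORWARD)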
sglang/srt/disaggregation/utils.py

@@ -1,21 +1,17 @@
 from __future__ import annotations
 
-import dataclasses
 import os
 import random
-import threading
-import warnings
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional, Type
 
 import numpy as np
-import requests
 import torch
 import torch.distributed as dist
 
-from sglang.srt.utils import get_ip, is_npu
+from sglang.srt.utils import is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -89,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-        dtype: torch.dtype,
+        hidden_states_dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -111,7 +107,9 @@
             # We transfer the metadata of first output token to decode
             # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
             self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-
+            self.cached_tokens = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
             self.output_token_logprobs_val = torch.zeros(
                 (size, 16), dtype=torch.float32, device=device
             )
@@ -124,33 +122,49 @@
             self.output_top_logprobs_idx = torch.zeros(
                 (size, max_top_logprobs_num), dtype=torch.int32, device=device
             )
+            # For PD + spec decode
+            self.output_topk_p = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_topk_index = torch.zeros(
+                (size, 16), dtype=torch.int64, device=device
+            )
             self.output_hidden_states = torch.zeros(
-                (size, hidden_size), dtype=dtype, device=device
+                (size, hidden_size), dtype=hidden_states_dtype, device=device
             )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.cached_tokens.data_ptr(),
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_topk_p.data_ptr(),
+            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.cached_tokens.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_topk_p.nbytes,
+            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.cached_tokens[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_topk_p[0].nbytes,
+            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
@@ -158,16 +172,20 @@
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.cached_tokens[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_topk_p[idx],
+            self.output_topk_index[idx],
             self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -190,8 +208,17 @@
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
-        # for PD + spec decode
+        # For PD + spec decode
         if req.hidden_states_tensor is not None:
+            # speculative_eagle_topk should not be greater than 16 currently
+            topk = req.output_topk_p.size(0)
+
+            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_p
+            )
+            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_index
+            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
@@ -217,7 +244,9 @@ class KVClassType(Enum):
     BOOTSTRAP_SERVER = "bootstrap_server"
 
 
-def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+def get_kv_class(
+    transfer_backend: TransferBackend, class_type: KVClassType
+) -> Optional[Type]:
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
     if transfer_backend == TransferBackend.MOONCAKE:
@@ -305,49 +334,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size
 
 
-#########################
-# PDLB Registry
-#########################
-
-
-@dataclasses.dataclass
-class PDRegistryRequest:
-    """A request to register a machine itself to the LB."""
-
-    mode: str
-    registry_url: str
-    bootstrap_port: Optional[int] = None
-
-    def __post_init__(self):
-        if self.mode == "prefill" and self.bootstrap_port is None:
-            raise ValueError("Bootstrap port must be set in PREFILL mode.")
-        elif self.mode == "decode" and self.bootstrap_port is not None:
-            raise ValueError("Bootstrap port must not be set in DECODE mode.")
-        elif self.mode not in ["prefill", "decode"]:
-            raise ValueError(
-                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
-            )
-
-
-def register_disaggregation_server(
-    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
-):
-    boostrap_port = bootstrap_port if mode == "prefill" else None
-    registry_request = PDRegistryRequest(
-        mode=mode,
-        registry_url=f"http://{get_ip()}:{server_port}",
-        bootstrap_port=boostrap_port,
-    )
-    res = requests.post(
-        f"{pdlb_url}/register",
-        json=dataclasses.asdict(registry_request),
-    )
-    if res.status_code != 200:
-        warnings.warn(
-            f"Failed to register disaggregation server: {res.status_code} {res.text}"
-        )
-
-
 #########################
 # Misc
 #########################
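Note on the MetadataBuffers layout above: every per-request field (output_ids, cached_tokens, output_topk_p, output_topk_index, ...) is one row of 16 elements, so even a single first-token id is transferred as a 64-byte row, matching the RDMA minimum called out in the comment. A small standalone check of that row-size arithmetic is below; it is illustrative only, not the sglang class itself.

# Sketch: per-request metadata rows are padded to 16 elements so each row
# meets the 64-byte minimum RDMA transfer size mentioned in the diff.
import torch

size = 8  # number of request slots, illustrative
output_ids = torch.zeros((size, 16), dtype=torch.int32)
output_topk_p = torch.zeros((size, 16), dtype=torch.float32)
output_topk_index = torch.zeros((size, 16), dtype=torch.int64)

# One row is the per-request transfer unit (cf. item_lens in get_buf_infos).
print(output_ids[0].nbytes)         # 16 * 4 bytes -> 64
print(output_topk_p[0].nbytes)      # 16 * 4 bytes -> 64
print(output_topk_index[0].nbytes)  # 16 * 8 bytes -> 128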
sglang/srt/distributed/device_communicators/all_reduce_utils.py

@@ -0,0 +1,16 @@
+MiB = 1024 * 1024
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    9: {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
+    },
+    10: {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    },
+}
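The new table is keyed by compute-capability major version and then by world size, and yields the largest payload in bytes that the symmetric-memory all-reduce path will accept; SymmMemCommunicator below indexes it exactly this way. A tiny illustrative lookup (values copied from the hunk above):

# Illustrative lookup; values mirror SYMM_MEM_ALL_REDUCE_MAX_SIZES above.
MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 32 * MiB, 6: 64 * MiB, 8: 64 * MiB},
    10: {2: 64 * MiB, 4: 32 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}

cc_major, world_size = 9, 8  # e.g. an sm90 GPU with 8 ranks
max_bytes = SYMM_MEM_ALL_REDUCE_MAX_SIZES[cc_major][world_size]
print(max_bytes // MiB)  # -> 64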
sglang/srt/distributed/device_communicators/shm_broadcast.py

@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 from sglang.srt.utils import (
     format_tcp_address,
-    get_ip,
+    get_local_ip_auto,
     get_open_port,
     is_valid_ipv6_address,
 )
@@ -191,7 +191,9 @@ class MessageQueue:
         self.n_remote_reader = n_remote_reader
 
         if connect_ip is None:
-            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
+            connect_ip = (
+                get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
+            )
 
         context = Context()
 
sglang/srt/distributed/device_communicators/symm_mem.py (new file)

@@ -0,0 +1,164 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+)
+from sglang.srt.utils import get_device_capability, is_cuda, is_hip
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+symm_mem_is_available = False
+if _is_hip:
+    symm_mem_is_available = False
+if _is_cuda:
+    symm_mem_is_available = True
+
+
+class SymmMemCommunicator:
+    """
+    Thin wrapper around symmetric-memory collectives.
+
+    This communicator:
+      - Validates device capability and world size.
+      - Allocates a shared symmetric buffer.
+      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
+      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
+
+    If any prerequisite is not met, the instance remains disabled and will
+    decline to perform symmetric-memory all-reduce.
+    """
+
+    # Mapping: compute capability major -> supported world sizes for multimem
+    # If the current (cc_major, world_size) is not listed, we fall back
+    # to the two-shot path.
+    _WORLD_SIZES_MULTIMEM = {
+        9: [4, 6, 8],
+        10: [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
+        """
+        Args:
+            group: Torch process group used for rendezvous and naming.
+            device: Target CUDA device (index, 'cuda:X', or torch.device).
+        """
+
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = torch.cuda.get_device_capability(device)[0]
+        if self.device_capability < 9:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size
+        ]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning(
+                "SymmMemCommunicator: symmetric memory "
+                "multicast operations are not supported."
+            )
+            self.buffer = None
+            self.disabled = True
+            return
+        self.disabled = False
+
+    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+        """
+        Fast-path eligibility check for a given tensor.
+
+        Conditions:
+          - Communicator must be enabled.
+          - dtype must be bfloat16 (matches kernel + buffer dtype).
+          - Total byte size must be 4-byte aligned (hardware requirement).
+          - Payload must be smaller than the symmetric-memory max size.
+
+        Returns:
+            True if the symmetric-memory path can handle this tensor.
+        """
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # enforce 4-byte alignment
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform an in-place sum all-reduce via symmetric memory.
+
+        Args:
+            inp: Input tensor on the target CUDA device (bfloat16).
+            out: Optional output tensor; if omitted, a new tensor is allocated.
+
+        Returns:
+            The reduced tensor (same shape as inp), or None if disabled.
+
+        Implementation details:
+          - Stages 'inp' into the symmetric buffer.
+          - Selects 'multimem' or 'two_shot' kernel based on topology.
+          - Writes the result into 'out' and returns it.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+        return out
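Intended call pattern for the new communicator: construct it with a process group and a device, gate each tensor with should_symm_mem_allreduce(), and fall back to the regular collective otherwise. A hedged sketch follows, assuming a multi-GPU torchrun-style launch on compute capability 9+ hardware; the setup and fallback code here is illustrative, only the SymmMemCommunicator calls come from the file above.

# Illustrative usage of SymmMemCommunicator; assumes a torchrun-style launch
# with one process per GPU on devices of compute capability >= 9.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
group = dist.new_group(list(range(dist.get_world_size())))
comm = SymmMemCommunicator(group, device=rank)

x = torch.ones(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
if comm.should_symm_mem_allreduce(x):
    x = comm.all_reduce(x)           # symmetric-memory fast path
else:
    dist.all_reduce(x, group=group)  # fallback to the default collective

dist.destroy_process_group()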