sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,10 @@ from __future__ import annotations
21
21
 
22
22
  import logging
23
23
  import threading
24
+ import time
24
25
  from collections import deque
25
26
  from http import HTTPStatus
26
- from typing import TYPE_CHECKING, List, Optional
27
+ from typing import TYPE_CHECKING, List, Optional, Type
27
28
 
28
29
  import torch
29
30
 
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
42
43
  poll_and_all_reduce,
43
44
  prepare_abort,
44
45
  )
45
- from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
+ from sglang.srt.managers.schedule_batch import (
47
+ FINISH_LENGTH,
48
+ Req,
49
+ RequestStage,
50
+ ScheduleBatch,
51
+ )
46
52
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
47
53
  from sglang.srt.utils import (
48
54
  DynamicGradMode,
@@ -140,8 +146,10 @@ class PrefillBootstrapQueue:
140
146
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
141
147
  kv_args.gpu_id = self.scheduler.gpu_id
142
148
 
143
- kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
144
- kv_manager = kv_manager_class(
149
+ kv_manager_class: Type[BaseKVManager] = get_kv_class(
150
+ self.transfer_backend, KVClassType.MANAGER
151
+ )
152
+ kv_manager: BaseKVManager = kv_manager_class(
145
153
  kv_args,
146
154
  DisaggregationMode.PREFILL,
147
155
  self.scheduler.server_args,
@@ -168,6 +176,7 @@ class PrefillBootstrapQueue:
168
176
  pp_rank=self.pp_rank,
169
177
  )
170
178
  self._process_req(req)
179
+ req.add_latency(RequestStage.PREFILL_PREPARE)
171
180
  self.queue.append(req)
172
181
 
173
182
  def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -254,8 +263,11 @@ class PrefillBootstrapQueue:
254
263
 
255
264
  num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
256
265
  req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
266
+
257
267
  bootstrapped_reqs.append(req)
258
268
  indices_to_remove.add(i)
269
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
270
+ req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
259
271
 
260
272
  self.queue = [
261
273
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -309,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
309
321
  self.result_queue = deque()
310
322
 
311
323
  while True:
324
+ self.launch_last_batch_sample_if_needed()
325
+
312
326
  recv_reqs = self.recv_requests()
313
327
  self.process_input_requests(recv_reqs)
314
328
  self.waiting_queue.extend(
@@ -324,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
324
338
  result = self.run_batch(batch)
325
339
  self.result_queue.append((batch.copy(), result))
326
340
 
327
- if self.last_batch is None:
328
- # Create a dummy first batch to start the pipeline for overlap schedule.
329
- # It is now used for triggering the sampling_info_done event.
330
- tmp_batch = ScheduleBatch(
331
- reqs=None,
332
- forward_mode=ForwardMode.DUMMY_FIRST,
333
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
334
- )
335
- self.set_next_batch_sampling_info_done(tmp_batch)
336
-
337
341
  if self.last_batch:
338
342
  tmp_batch, tmp_result = self.result_queue.popleft()
339
- tmp_batch.next_batch_sampling_info = (
340
- self.tp_worker.cur_sampling_info if batch else None
341
- )
342
343
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
343
344
 
344
345
  if len(self.disagg_prefill_inflight_queue) > 0:
@@ -356,7 +357,6 @@ class SchedulerDisaggregationPrefillMixin:
356
357
  self: Scheduler,
357
358
  batch: ScheduleBatch,
358
359
  result: GenerationBatchResult,
359
- launch_done: Optional[threading.Event] = None,
360
360
  ) -> None:
361
361
  """
362
362
  Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -367,41 +367,40 @@ class SchedulerDisaggregationPrefillMixin:
367
367
  next_token_ids,
368
368
  extend_input_len_per_req,
369
369
  extend_logprob_start_len_per_req,
370
+ copy_done,
370
371
  ) = (
371
372
  result.logits_output,
372
373
  result.next_token_ids,
373
374
  result.extend_input_len_per_req,
374
375
  result.extend_logprob_start_len_per_req,
376
+ result.copy_done,
375
377
  )
376
378
 
379
+ if copy_done is not None:
380
+ copy_done.synchronize()
381
+
377
382
  logprob_pt = 0
378
383
  # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
379
- if self.enable_overlap:
380
- # wait
381
- logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
382
- launch_done
383
- )
384
- else:
385
- next_token_ids = result.next_token_ids.tolist()
386
- if batch.return_logprob:
387
- if logits_output.next_token_logprobs is not None:
388
- logits_output.next_token_logprobs = (
389
- logits_output.next_token_logprobs.tolist()
390
- )
391
- if logits_output.input_token_logprobs is not None:
392
- logits_output.input_token_logprobs = tuple(
393
- logits_output.input_token_logprobs.tolist()
394
- )
384
+ next_token_ids = result.next_token_ids.tolist()
385
+ if batch.return_logprob:
386
+ if logits_output.next_token_logprobs is not None:
387
+ logits_output.next_token_logprobs = (
388
+ logits_output.next_token_logprobs.tolist()
389
+ )
390
+ if logits_output.input_token_logprobs is not None:
391
+ logits_output.input_token_logprobs = tuple(
392
+ logits_output.input_token_logprobs.tolist()
393
+ )
395
394
 
396
395
  hidden_state_offset = 0
397
396
  for i, (req, next_token_id) in enumerate(
398
397
  zip(batch.reqs, next_token_ids, strict=True)
399
398
  ):
400
- req: Req
401
399
  if req.is_chunked <= 0:
402
400
  # There is no output_ids for prefill
403
401
  req.output_ids.append(next_token_id)
404
402
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
403
+ req.add_latency(RequestStage.PREFILL_FORWARD)
405
404
  self.disagg_prefill_inflight_queue.append(req)
406
405
  if (
407
406
  logits_output is not None
@@ -410,9 +409,16 @@ class SchedulerDisaggregationPrefillMixin:
410
409
  last_hidden_index = (
411
410
  hidden_state_offset + extend_input_len_per_req[i] - 1
412
411
  )
413
- req.hidden_states_tensor = (
414
- logits_output.hidden_states[last_hidden_index].cpu().clone()
415
- )
412
+ req.output_topk_p = batch.spec_info.topk_p[i]
413
+ req.output_topk_index = batch.spec_info.topk_index[i]
414
+ if self.spec_algorithm.is_eagle3():
415
+ req.hidden_states_tensor = (
416
+ batch.spec_info.hidden_states[i].cpu().clone()
417
+ )
418
+ else:
419
+ req.hidden_states_tensor = (
420
+ logits_output.hidden_states[last_hidden_index].cpu().clone()
421
+ )
416
422
  hidden_state_offset += extend_input_len_per_req[i]
417
423
  else:
418
424
  req.hidden_states_tensor = None
@@ -432,6 +438,7 @@ class SchedulerDisaggregationPrefillMixin:
432
438
  )
433
439
  logprob_pt += num_input_logprobs
434
440
  self.send_kv_chunk(req, last_chunk=True)
441
+ req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
435
442
 
436
443
  if req.grammar is not None:
437
444
  # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -471,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
471
478
  if self.enable_overlap:
472
479
  self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
473
480
 
474
- # We need to remove the sync in the following function for overlap schedule.
475
- self.set_next_batch_sampling_info_done(batch)
476
481
  self.maybe_send_health_check_signal()
477
482
 
478
483
  def process_disagg_prefill_inflight_queue(
@@ -529,6 +534,9 @@ class SchedulerDisaggregationPrefillMixin:
529
534
  else:
530
535
  assert False, f"Unexpected polling state {poll=}"
531
536
 
537
+ for req in done_reqs:
538
+ req.time_stats.completion_time = time.perf_counter()
539
+
532
540
  # Stream requests which have finished transfer
533
541
  self.stream_output(
534
542
  done_reqs,
@@ -537,6 +545,7 @@ class SchedulerDisaggregationPrefillMixin:
537
545
  )
538
546
  for req in done_reqs:
539
547
  req: Req
548
+ req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
540
549
  self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
541
550
  req.metadata_buffer_index = -1
542
551
 
@@ -665,7 +674,6 @@ class SchedulerDisaggregationPrefillMixin:
665
674
  self.running_mbs = [
666
675
  ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
667
676
  ]
668
- bids = [None] * self.pp_size
669
677
  pp_outputs: Optional[PPProxyTensors] = None
670
678
 
671
679
  # Either success or failed
@@ -737,10 +745,7 @@ class SchedulerDisaggregationPrefillMixin:
737
745
  # send the outputs to the next step
738
746
  if self.pp_group.is_last_rank:
739
747
  if self.cur_batch:
740
- next_token_ids, bids[mb_id] = (
741
- result.next_token_ids,
742
- result.bid,
743
- )
748
+ next_token_ids = result.next_token_ids
744
749
  pp_outputs = PPProxyTensors(
745
750
  {
746
751
  "next_token_ids": next_token_ids,
@@ -777,7 +782,6 @@ class SchedulerDisaggregationPrefillMixin:
777
782
  next_token_ids=next_pp_outputs["next_token_ids"],
778
783
  extend_input_len_per_req=None,
779
784
  extend_logprob_start_len_per_req=None,
780
- bid=bids[next_mb_id],
781
785
  can_run_cuda_graph=result.can_run_cuda_graph,
782
786
  )
783
787
  self.process_batch_result_disagg_prefill(
@@ -794,8 +798,6 @@ class SchedulerDisaggregationPrefillMixin:
794
798
 
795
799
  # carry the outputs to the next stage
796
800
  if not self.pp_group.is_last_rank:
797
- if self.cur_batch:
798
- bids[mb_id] = result.bid
799
801
  if pp_outputs:
800
802
  # send the outputs from the last round to let the next stage worker run post processing
801
803
  self.pp_group.send_tensor_dict(
@@ -814,8 +816,10 @@ class SchedulerDisaggregationPrefillMixin:
814
816
 
815
817
  # send out proxy tensors to the next stage
816
818
  if self.cur_batch:
819
+ # FIXME(lsyin): remove this assert
820
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
817
821
  self.pp_group.send_tensor_dict(
818
- result.pp_hidden_states_proxy_tensors,
822
+ result.pp_hidden_states_proxy_tensors.tensors,
819
823
  all_gather_group=self.attn_tp_group,
820
824
  )
821
825
 
@@ -1,21 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import dataclasses
4
3
  import os
5
4
  import random
6
- import threading
7
- import warnings
8
5
  from collections import deque
9
6
  from contextlib import nullcontext
10
7
  from enum import Enum
11
- from typing import TYPE_CHECKING, List, Optional
8
+ from typing import TYPE_CHECKING, Optional, Type
12
9
 
13
10
  import numpy as np
14
- import requests
15
11
  import torch
16
12
  import torch.distributed as dist
17
13
 
18
- from sglang.srt.utils import get_ip, is_npu
14
+ from sglang.srt.utils import is_npu
19
15
 
20
16
  if TYPE_CHECKING:
21
17
  from sglang.srt.managers.schedule_batch import Req
@@ -89,7 +85,7 @@ class MetadataBuffers:
89
85
  self,
90
86
  size: int,
91
87
  hidden_size: int,
92
- dtype: torch.dtype,
88
+ hidden_states_dtype: torch.dtype,
93
89
  max_top_logprobs_num: int = 128,
94
90
  custom_mem_pool: torch.cuda.MemPool = None,
95
91
  ):
@@ -111,7 +107,9 @@ class MetadataBuffers:
111
107
  # We transfer the metadata of first output token to decode
112
108
  # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
113
109
  self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
114
-
110
+ self.cached_tokens = torch.zeros(
111
+ (size, 16), dtype=torch.int32, device=device
112
+ )
115
113
  self.output_token_logprobs_val = torch.zeros(
116
114
  (size, 16), dtype=torch.float32, device=device
117
115
  )
@@ -124,33 +122,49 @@ class MetadataBuffers:
124
122
  self.output_top_logprobs_idx = torch.zeros(
125
123
  (size, max_top_logprobs_num), dtype=torch.int32, device=device
126
124
  )
125
+ # For PD + spec decode
126
+ self.output_topk_p = torch.zeros(
127
+ (size, 16), dtype=torch.float32, device=device
128
+ )
129
+ self.output_topk_index = torch.zeros(
130
+ (size, 16), dtype=torch.int64, device=device
131
+ )
127
132
  self.output_hidden_states = torch.zeros(
128
- (size, hidden_size), dtype=dtype, device=device
133
+ (size, hidden_size), dtype=hidden_states_dtype, device=device
129
134
  )
130
135
 
131
136
  def get_buf_infos(self):
132
137
  ptrs = [
133
138
  self.output_ids.data_ptr(),
139
+ self.cached_tokens.data_ptr(),
134
140
  self.output_token_logprobs_val.data_ptr(),
135
141
  self.output_token_logprobs_idx.data_ptr(),
136
142
  self.output_top_logprobs_val.data_ptr(),
137
143
  self.output_top_logprobs_idx.data_ptr(),
144
+ self.output_topk_p.data_ptr(),
145
+ self.output_topk_index.data_ptr(),
138
146
  self.output_hidden_states.data_ptr(),
139
147
  ]
140
148
  data_lens = [
141
149
  self.output_ids.nbytes,
150
+ self.cached_tokens.nbytes,
142
151
  self.output_token_logprobs_val.nbytes,
143
152
  self.output_token_logprobs_idx.nbytes,
144
153
  self.output_top_logprobs_val.nbytes,
145
154
  self.output_top_logprobs_idx.nbytes,
155
+ self.output_topk_p.nbytes,
156
+ self.output_topk_index.nbytes,
146
157
  self.output_hidden_states.nbytes,
147
158
  ]
148
159
  item_lens = [
149
160
  self.output_ids[0].nbytes,
161
+ self.cached_tokens[0].nbytes,
150
162
  self.output_token_logprobs_val[0].nbytes,
151
163
  self.output_token_logprobs_idx[0].nbytes,
152
164
  self.output_top_logprobs_val[0].nbytes,
153
165
  self.output_top_logprobs_idx[0].nbytes,
166
+ self.output_topk_p[0].nbytes,
167
+ self.output_topk_index[0].nbytes,
154
168
  self.output_hidden_states[0].nbytes,
155
169
  ]
156
170
  return ptrs, data_lens, item_lens
@@ -158,16 +172,20 @@ class MetadataBuffers:
158
172
  def get_buf(self, idx: int):
159
173
  return (
160
174
  self.output_ids[idx],
175
+ self.cached_tokens[idx],
161
176
  self.output_token_logprobs_val[idx],
162
177
  self.output_token_logprobs_idx[idx],
163
178
  self.output_top_logprobs_val[idx],
164
179
  self.output_top_logprobs_idx[idx],
180
+ self.output_topk_p[idx],
181
+ self.output_topk_index[idx],
165
182
  self.output_hidden_states[idx],
166
183
  )
167
184
 
168
185
  def set_buf(self, req: Req):
169
186
 
170
187
  self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
188
+ self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
171
189
  if req.return_logprob:
172
190
  if req.output_token_logprobs_val: # not none or empty list
173
191
  self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -190,8 +208,17 @@ class MetadataBuffers:
190
208
  ] = torch.tensor(
191
209
  req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
192
210
  )
193
- # for PD + spec decode
211
+ # For PD + spec decode
194
212
  if req.hidden_states_tensor is not None:
213
+ # speculative_eagle_topk should not be greater than 16 currently
214
+ topk = req.output_topk_p.size(0)
215
+
216
+ self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
217
+ req.output_topk_p
218
+ )
219
+ self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
220
+ req.output_topk_index
221
+ )
195
222
  self.output_hidden_states[req.metadata_buffer_index].copy_(
196
223
  req.hidden_states_tensor
197
224
  )
@@ -217,7 +244,9 @@ class KVClassType(Enum):
217
244
  BOOTSTRAP_SERVER = "bootstrap_server"
218
245
 
219
246
 
220
- def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
247
+ def get_kv_class(
248
+ transfer_backend: TransferBackend, class_type: KVClassType
249
+ ) -> Optional[Type]:
221
250
  from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
222
251
 
223
252
  if transfer_backend == TransferBackend.MOONCAKE:
@@ -305,49 +334,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
305
334
  return (num_kv_indices + page_size - 1) // page_size
306
335
 
307
336
 
308
- #########################
309
- # PDLB Registry
310
- #########################
311
-
312
-
313
- @dataclasses.dataclass
314
- class PDRegistryRequest:
315
- """A request to register a machine itself to the LB."""
316
-
317
- mode: str
318
- registry_url: str
319
- bootstrap_port: Optional[int] = None
320
-
321
- def __post_init__(self):
322
- if self.mode == "prefill" and self.bootstrap_port is None:
323
- raise ValueError("Bootstrap port must be set in PREFILL mode.")
324
- elif self.mode == "decode" and self.bootstrap_port is not None:
325
- raise ValueError("Bootstrap port must not be set in DECODE mode.")
326
- elif self.mode not in ["prefill", "decode"]:
327
- raise ValueError(
328
- f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
329
- )
330
-
331
-
332
- def register_disaggregation_server(
333
- mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
334
- ):
335
- boostrap_port = bootstrap_port if mode == "prefill" else None
336
- registry_request = PDRegistryRequest(
337
- mode=mode,
338
- registry_url=f"http://{get_ip()}:{server_port}",
339
- bootstrap_port=boostrap_port,
340
- )
341
- res = requests.post(
342
- f"{pdlb_url}/register",
343
- json=dataclasses.asdict(registry_request),
344
- )
345
- if res.status_code != 200:
346
- warnings.warn(
347
- f"Failed to register disaggregation server: {res.status_code} {res.text}"
348
- )
349
-
350
-
351
337
  #########################
352
338
  # Misc
353
339
  #########################
@@ -0,0 +1,16 @@
1
+ MiB = 1024 * 1024
2
+
3
+ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
4
+ 9: {
5
+ 2: 64 * MiB, # 64 MB
6
+ 4: 32 * MiB, # 32 MB
7
+ 6: 64 * MiB, # 64 MB
8
+ 8: 64 * MiB, # 64 MB
9
+ },
10
+ 10: {
11
+ 2: 64 * MiB, # 64 MB
12
+ 4: 32 * MiB, # 32 MB
13
+ 6: 128 * MiB, # 128 MB
14
+ 8: 128 * MiB, # 128 MB
15
+ },
16
+ }
@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
18
18
 
19
19
  from sglang.srt.utils import (
20
20
  format_tcp_address,
21
- get_ip,
21
+ get_local_ip_auto,
22
22
  get_open_port,
23
23
  is_valid_ipv6_address,
24
24
  )
@@ -191,7 +191,9 @@ class MessageQueue:
191
191
  self.n_remote_reader = n_remote_reader
192
192
 
193
193
  if connect_ip is None:
194
- connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
194
+ connect_ip = (
195
+ get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
196
+ )
195
197
 
196
198
  context = Context()
197
199
 
@@ -0,0 +1,164 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
2
+ import logging
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch.distributed import ProcessGroup
8
+
9
+ from sglang.srt.distributed.device_communicators.all_reduce_utils import (
10
+ SYMM_MEM_ALL_REDUCE_MAX_SIZES,
11
+ )
12
+ from sglang.srt.utils import get_device_capability, is_cuda, is_hip
13
+
14
+ try:
15
+ import torch.distributed._symmetric_memory as torch_symm_mem
16
+
17
+ symm_mem_available = True
18
+ except ImportError:
19
+ symm_mem_available = False
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ _is_cuda = is_cuda()
25
+ _is_hip = is_hip()
26
+
27
+ symm_mem_is_available = False
28
+ if _is_hip:
29
+ symm_mem_is_available = False
30
+ if _is_cuda:
31
+ symm_mem_is_available = True
32
+
33
+
34
+ class SymmMemCommunicator:
35
+ """
36
+ Thin wrapper around symmetric-memory collectives.
37
+
38
+ This communicator:
39
+ - Validates device capability and world size.
40
+ - Allocates a shared symmetric buffer.
41
+ - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
42
+ - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
43
+
44
+ If any prerequisite is not met, the instance remains disabled and will
45
+ decline to perform symmetric-memory all-reduce.
46
+ """
47
+
48
+ # Mapping: compute capability major -> supported world sizes for multimem
49
+ # If the current (cc_major, world_size) is not listed, we fall back
50
+ # to the two-shot path.
51
+ _WORLD_SIZES_MULTIMEM = {
52
+ 9: [4, 6, 8],
53
+ 10: [6, 8],
54
+ }
55
+
56
+ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
57
+ """
58
+ Args:
59
+ group: Torch process group used for rendezvous and naming.
60
+ device: Target CUDA device (index, 'cuda:X', or torch.device).
61
+ """
62
+
63
+ self.disabled = True
64
+
65
+ if not symm_mem_available:
66
+ return
67
+
68
+ if isinstance(device, int):
69
+ device = torch.device(f"cuda:{device}")
70
+ elif isinstance(device, str):
71
+ device = torch.device(device)
72
+ torch.cuda.set_device(device)
73
+ self.dtype = torch.bfloat16
74
+ self.device = device
75
+ self.group = group
76
+ self.world_size = dist.get_world_size(self.group)
77
+ self.device_capability = torch.cuda.get_device_capability(device)[0]
78
+ if self.device_capability < 9:
79
+ logger.warning(
80
+ "SymmMemCommunicator: Device capability %s not supported, "
81
+ "communicator is not available.",
82
+ self.device_capability,
83
+ )
84
+ return
85
+ if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
86
+ logger.warning(
87
+ "SymmMemCommunicator: World size %d not supported, "
88
+ "communicator is not available.",
89
+ self.world_size,
90
+ )
91
+ return
92
+ self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
93
+ self.world_size
94
+ ]
95
+ self.buffer = torch_symm_mem.empty(
96
+ self.max_size // self.dtype.itemsize,
97
+ device=self.device,
98
+ dtype=self.dtype,
99
+ )
100
+ handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
101
+ if handle.multicast_ptr == 0:
102
+ logger.warning(
103
+ "SymmMemCommunicator: symmetric memory "
104
+ "multicast operations are not supported."
105
+ )
106
+ self.buffer = None
107
+ self.disabled = True
108
+ return
109
+ self.disabled = False
110
+
111
+ def should_symm_mem_allreduce(self, inp: torch.Tensor):
112
+ """
113
+ Fast-path eligibility check for a given tensor.
114
+
115
+ Conditions:
116
+ - Communicator must be enabled.
117
+ - dtype must be bfloat16 (matches kernel + buffer dtype).
118
+ - Total byte size must be 4-byte aligned (hardware requirement).
119
+ - Payload must be smaller than the symmetric-memory max size.
120
+
121
+ Returns:
122
+ True if the symmetric-memory path can handle this tensor.
123
+ """
124
+ if self.disabled:
125
+ return False
126
+ if inp.dtype != self.dtype:
127
+ return False
128
+ inp_size = inp.numel() * inp.element_size()
129
+ # enforce 4-byte alignment
130
+ if inp_size % 4 != 0:
131
+ return False
132
+ return inp_size < self.max_size
133
+
134
+ def all_reduce(
135
+ self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
136
+ ) -> Optional[torch.Tensor]:
137
+ """
138
+ Perform an in-place sum all-reduce via symmetric memory.
139
+
140
+ Args:
141
+ inp: Input tensor on the target CUDA device (bfloat16).
142
+ out: Optional output tensor; if omitted, a new tensor is allocated.
143
+
144
+ Returns:
145
+ The reduced tensor (same shape as inp), or None if disabled.
146
+
147
+ Implementation details:
148
+ - Stages 'inp' into the symmetric buffer.
149
+ - Selects 'multimem' or 'two_shot' kernel based on topology.
150
+ - Writes the result into 'out' and returns it.
151
+ """
152
+ if out is None:
153
+ out = torch.empty_like(inp)
154
+ self.buffer[: inp.numel()].copy_(inp.view(-1))
155
+ if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
156
+ torch.ops.symm_mem.multimem_all_reduce_(
157
+ self.buffer[: inp.numel()], "sum", self.group.group_name
158
+ )
159
+ else:
160
+ torch.ops.symm_mem.two_shot_all_reduce_(
161
+ self.buffer[: inp.numel()], "sum", self.group.group_name
162
+ )
163
+ out.copy_(self.buffer[: inp.numel()].view(out.shape))
164
+ return out