sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,6 @@ Life cycle of a request in the prefill server
20
20
  from __future__ import annotations
21
21
 
22
22
  import logging
23
- import threading
24
23
  import time
25
24
  from collections import deque
26
25
  from http import HTTPStatus
@@ -49,13 +48,12 @@ from sglang.srt.managers.schedule_batch import (
49
48
  RequestStage,
50
49
  ScheduleBatch,
51
50
  )
52
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
53
- from sglang.srt.utils import (
54
- DynamicGradMode,
55
- broadcast_pyobj,
56
- point_to_point_pyobj,
57
- require_mlp_sync,
51
+ from sglang.srt.mem_cache.memory_pool import (
52
+ HybridLinearKVPool,
53
+ NSATokenToKVPool,
54
+ SWAKVPool,
58
55
  )
56
+ from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
59
57
 
60
58
  if TYPE_CHECKING:
61
59
  from torch.distributed import ProcessGroup
@@ -146,6 +144,28 @@ class PrefillBootstrapQueue:
146
144
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
147
145
  kv_args.gpu_id = self.scheduler.gpu_id
148
146
 
147
+ if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
148
+ state_data_ptrs, state_data_lens, state_item_lens = (
149
+ self.token_to_kv_pool.get_state_buf_infos()
150
+ )
151
+ kv_args.state_data_ptrs = state_data_ptrs
152
+ kv_args.state_data_lens = state_data_lens
153
+ kv_args.state_item_lens = state_item_lens
154
+
155
+ if isinstance(self.token_to_kv_pool, SWAKVPool):
156
+ kv_args.state_type = "swa"
157
+ elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
158
+ kv_args.state_type = "mamba"
159
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
160
+ kv_args.state_type = "nsa"
161
+ else:
162
+ kv_args.state_type = "none"
163
+ else:
164
+ kv_args.state_data_ptrs = []
165
+ kv_args.state_data_lens = []
166
+ kv_args.state_item_lens = []
167
+ kv_args.state_type = "none"
168
+
149
169
  kv_manager_class: Type[BaseKVManager] = get_kv_class(
150
170
  self.transfer_backend, KVClassType.MANAGER
151
171
  )
@@ -332,30 +352,21 @@ class SchedulerDisaggregationPrefillMixin:
332
352
  if require_mlp_sync(self.server_args):
333
353
  batch = self.prepare_mlp_sync_batch(batch)
334
354
  self.cur_batch = batch
355
+
356
+ batch_result = None
335
357
  if batch:
336
- result = self.run_batch(batch)
337
- self.result_queue.append((batch.copy(), result))
338
-
339
- if self.last_batch is None:
340
- # Create a dummy first batch to start the pipeline for overlap schedule.
341
- # It is now used for triggering the sampling_info_done event.
342
- tmp_batch = ScheduleBatch(
343
- reqs=None,
344
- forward_mode=ForwardMode.DUMMY_FIRST,
345
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
346
- )
347
- self.set_next_batch_sampling_info_done(tmp_batch)
358
+ batch_result = self.run_batch(batch)
359
+ self.result_queue.append((batch.copy(), batch_result))
348
360
 
349
361
  if self.last_batch:
350
362
  tmp_batch, tmp_result = self.result_queue.popleft()
351
- tmp_batch.next_batch_sampling_info = (
352
- self.tp_worker.cur_sampling_info if batch else None
353
- )
354
363
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
355
364
 
356
365
  if len(self.disagg_prefill_inflight_queue) > 0:
357
366
  self.process_disagg_prefill_inflight_queue()
358
367
 
368
+ self.launch_batch_sample_if_needed(batch_result)
369
+
359
370
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
360
371
  self.self_check_during_idle()
361
372
 
@@ -368,7 +379,6 @@ class SchedulerDisaggregationPrefillMixin:
368
379
  self: Scheduler,
369
380
  batch: ScheduleBatch,
370
381
  result: GenerationBatchResult,
371
- launch_done: Optional[threading.Event] = None,
372
382
  ) -> None:
373
383
  """
374
384
  Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -379,31 +389,30 @@ class SchedulerDisaggregationPrefillMixin:
379
389
  next_token_ids,
380
390
  extend_input_len_per_req,
381
391
  extend_logprob_start_len_per_req,
392
+ copy_done,
382
393
  ) = (
383
394
  result.logits_output,
384
395
  result.next_token_ids,
385
396
  result.extend_input_len_per_req,
386
397
  result.extend_logprob_start_len_per_req,
398
+ result.copy_done,
387
399
  )
388
400
 
401
+ if copy_done is not None:
402
+ copy_done.synchronize()
403
+
389
404
  logprob_pt = 0
390
405
  # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
391
- if self.enable_overlap:
392
- # wait
393
- logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
394
- launch_done
395
- )
396
- else:
397
- next_token_ids = result.next_token_ids.tolist()
398
- if batch.return_logprob:
399
- if logits_output.next_token_logprobs is not None:
400
- logits_output.next_token_logprobs = (
401
- logits_output.next_token_logprobs.tolist()
402
- )
403
- if logits_output.input_token_logprobs is not None:
404
- logits_output.input_token_logprobs = tuple(
405
- logits_output.input_token_logprobs.tolist()
406
- )
406
+ next_token_ids = result.next_token_ids.tolist()
407
+ if batch.return_logprob:
408
+ if logits_output.next_token_logprobs is not None:
409
+ logits_output.next_token_logprobs = (
410
+ logits_output.next_token_logprobs.tolist()
411
+ )
412
+ if logits_output.input_token_logprobs is not None:
413
+ logits_output.input_token_logprobs = tuple(
414
+ logits_output.input_token_logprobs.tolist()
415
+ )
407
416
 
408
417
  hidden_state_offset = 0
409
418
  for i, (req, next_token_id) in enumerate(
@@ -415,24 +424,12 @@ class SchedulerDisaggregationPrefillMixin:
415
424
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
416
425
  req.add_latency(RequestStage.PREFILL_FORWARD)
417
426
  self.disagg_prefill_inflight_queue.append(req)
418
- if (
419
- logits_output is not None
420
- and logits_output.hidden_states is not None
421
- ):
422
- last_hidden_index = (
423
- hidden_state_offset + extend_input_len_per_req[i] - 1
424
- )
427
+ if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
425
428
  req.output_topk_p = batch.spec_info.topk_p[i]
426
429
  req.output_topk_index = batch.spec_info.topk_index[i]
427
- if self.spec_algorithm.is_eagle3():
428
- req.hidden_states_tensor = (
429
- batch.spec_info.hidden_states[i].cpu().clone()
430
- )
431
- else:
432
- req.hidden_states_tensor = (
433
- logits_output.hidden_states[last_hidden_index].cpu().clone()
434
- )
435
- hidden_state_offset += extend_input_len_per_req[i]
430
+ req.hidden_states_tensor = (
431
+ batch.spec_info.hidden_states[i].cpu().clone()
432
+ )
436
433
  else:
437
434
  req.hidden_states_tensor = None
438
435
  if req.return_logprob:
@@ -491,8 +488,6 @@ class SchedulerDisaggregationPrefillMixin:
491
488
  if self.enable_overlap:
492
489
  self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
493
490
 
494
- # We need to remove the sync in the following function for overlap schedule.
495
- self.set_next_batch_sampling_info_done(batch)
496
491
  self.maybe_send_health_check_signal()
497
492
 
498
493
  def process_disagg_prefill_inflight_queue(
@@ -631,227 +626,58 @@ class SchedulerDisaggregationPrefillMixin:
631
626
  .numpy()
632
627
  )
633
628
  req.start_send_idx = end_idx
629
+ state_indices = None
634
630
  if last_chunk:
635
631
  self.disagg_metadata_buffers.set_buf(req)
632
+
633
+ # Prepare extra pool indices for hybrid models
634
+ if isinstance(
635
+ self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
636
+ ):
637
+ # Mamba hybrid model: send single mamba state index
638
+ state_indices = [
639
+ self.req_to_token_pool.req_index_to_mamba_index_mapping[
640
+ req.req_pool_idx
641
+ ]
642
+ .cpu()
643
+ .numpy()
644
+ ]
645
+ elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
646
+ # SWA hybrid model: send last window KV indices
647
+ seq_len = len(req.fill_ids)
648
+ window_size = self.sliding_window_size
649
+ window_start = max(0, seq_len - window_size)
650
+ window_start = (window_start // page_size) * page_size
651
+
652
+ window_kv_indices_full = self.req_to_token_pool.req_to_token[
653
+ req.req_pool_idx, window_start:seq_len
654
+ ]
655
+
656
+ # Translate to SWA pool indices
657
+ window_kv_indices_swa = (
658
+ self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
659
+ window_kv_indices_full
660
+ )
661
+ )
662
+ state_indices = window_kv_indices_swa.cpu().numpy()
663
+ state_indices = kv_to_page_indices(state_indices, page_size)
664
+ elif isinstance(
665
+ self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
666
+ ):
667
+ seq_len = len(req.fill_ids)
668
+ kv_indices_full = self.req_to_token_pool.req_to_token[
669
+ req.req_pool_idx, :seq_len
670
+ ]
671
+ state_indices = kv_indices_full.cpu().numpy()
672
+ state_indices = kv_to_page_indices(state_indices, page_size)
673
+
636
674
  page_indices = kv_to_page_indices(kv_indices, page_size)
637
675
  if len(page_indices) == 0:
638
676
  logger.info(
639
677
  f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
640
678
  )
641
679
  return
642
- req.disagg_kv_sender.send(page_indices)
643
-
644
- # PP
645
- @DynamicGradMode()
646
- def event_loop_pp_disagg_prefill(self: Scheduler):
647
- """
648
- An event loop for the prefill server in pipeline parallelism.
649
-
650
- Rules:
651
- 1. Each stage runs in the same order and is notified by the previous stage.
652
- 2. Each send/recv operation is blocking and matched by the neighboring stage.
653
-
654
- Regular Schedule:
655
- ====================================================================
656
- Stage i | Stage i+1
657
- send ith req | recv ith req
658
- send ith proxy | recv ith proxy
659
- send prev (i+1)th carry | recv prev (i+1)th carry
660
- ====================================================================
661
-
662
- Prefill Server Schedule:
663
- ====================================================================
664
- Stage i | Stage i+1
665
- send ith req | recv ith req
666
- send ith bootstrap req | recv ith bootstrap req
667
- send ith transferred req | recv ith transferred req
668
- send ith proxy | recv ith proxy
669
- send prev (i+1)th carry | recv prev (i+1)th carry
670
- send prev (i+1)th release req | recv prev (i+1)th release req
671
- ====================================================================
672
-
673
- There are two additional elements compared to the regular schedule:
674
-
675
- 1. Bootstrap Requests:
676
- a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
677
- b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
678
- c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
679
-
680
- 2. Transferred Requests + Release Requests:
681
- a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
682
- b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
683
- c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
684
- """
685
- from sglang.srt.managers.scheduler import GenerationBatchResult
686
-
687
- mbs = [None] * self.pp_size
688
- last_mbs = [None] * self.pp_size
689
- self.running_mbs = [
690
- ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
691
- ]
692
- pp_outputs: Optional[PPProxyTensors] = None
693
-
694
- # Either success or failed
695
- bootstrapped_rids: List[str] = []
696
- transferred_rids: List[str] = []
697
- release_rids: Optional[List[str]] = None
698
-
699
- # transferred microbatch
700
- tmbs = [None] * self.pp_size
701
-
702
- ENABLE_RELEASE = True # For debug
703
-
704
- while True:
705
- server_is_idle = True
706
-
707
- for mb_id in range(self.pp_size):
708
- self.running_batch = self.running_mbs[mb_id]
709
- self.last_batch = last_mbs[mb_id]
710
-
711
- recv_reqs = self.recv_requests()
712
-
713
- self.process_input_requests(recv_reqs)
714
-
715
- if self.pp_group.is_first_rank:
716
- # First rank, pop the bootstrap reqs from the bootstrap queue
717
- bootstrapped_reqs, failed_reqs = (
718
- self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
719
- return_failed_reqs=True
720
- )
721
- )
722
- bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
723
- req.rid for req in failed_reqs
724
- ]
725
- self.waiting_queue.extend(bootstrapped_reqs)
726
- else:
727
- # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
728
- bootstrapped_rids = self.recv_pyobj_from_prev_stage()
729
- bootstrapped_reqs = (
730
- self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
731
- rids_to_check=bootstrapped_rids
732
- )
733
- )
734
- self.waiting_queue.extend(bootstrapped_reqs)
735
-
736
- if self.pp_group.is_first_rank:
737
- transferred_rids = self.get_transferred_rids()
738
- # if other ranks,
739
- else:
740
- # 1. recv previous stage's transferred reqs info
741
- prev_transferred_rids = self.recv_pyobj_from_prev_stage()
742
- # 2. get the current stage's transferred reqs info
743
- curr_transferred_rids = self.get_transferred_rids()
744
- # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
745
- transferred_rids = list(
746
- set(prev_transferred_rids) & set(curr_transferred_rids)
747
- )
748
-
749
- tmbs[mb_id] = transferred_rids
750
-
751
- self.process_prefill_chunk()
752
- mbs[mb_id] = self.get_new_batch_prefill()
753
- self.running_mbs[mb_id] = self.running_batch
754
-
755
- self.cur_batch = mbs[mb_id]
756
- if self.cur_batch:
757
- server_is_idle = False
758
- result = self.run_batch(self.cur_batch)
759
-
760
- # send the outputs to the next step
761
- if self.pp_group.is_last_rank:
762
- if self.cur_batch:
763
- next_token_ids = result.next_token_ids
764
- pp_outputs = PPProxyTensors(
765
- {
766
- "next_token_ids": next_token_ids,
767
- }
768
- )
769
- # send the output from the last round to let the next stage worker run post processing
770
- self.pp_group.send_tensor_dict(
771
- pp_outputs.tensors,
772
- all_gather_group=self.attn_tp_group,
773
- )
774
-
775
- if ENABLE_RELEASE:
776
- if self.pp_group.is_last_rank:
777
- # At the last stage, all stages has reached the consensus to release memory for transferred_rids
778
- release_rids = transferred_rids
779
- # send to the first rank
780
- self.send_pyobj_to_next_stage(release_rids)
781
-
782
- # receive outputs and post-process (filter finished reqs) the coming microbatch
783
- next_mb_id = (mb_id + 1) % self.pp_size
784
- next_pp_outputs = None
785
- next_release_rids = None
786
-
787
- if mbs[next_mb_id] is not None:
788
- next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
789
- self.pp_group.recv_tensor_dict(
790
- all_gather_group=self.attn_tp_group
791
- )
792
- )
793
- mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
794
- output_result = GenerationBatchResult(
795
- logits_output=None,
796
- pp_hidden_states_proxy_tensors=None,
797
- next_token_ids=next_pp_outputs["next_token_ids"],
798
- extend_input_len_per_req=None,
799
- extend_logprob_start_len_per_req=None,
800
- can_run_cuda_graph=result.can_run_cuda_graph,
801
- )
802
- self.process_batch_result_disagg_prefill(
803
- mbs[next_mb_id], output_result
804
- )
805
-
806
- last_mbs[next_mb_id] = mbs[next_mb_id]
807
-
808
- if ENABLE_RELEASE:
809
- if tmbs[next_mb_id] is not None:
810
- # recv consensus rids from the previous rank
811
- next_release_rids = self.recv_pyobj_from_prev_stage()
812
- self.process_disagg_prefill_inflight_queue(next_release_rids)
813
-
814
- # carry the outputs to the next stage
815
- if not self.pp_group.is_last_rank:
816
- if pp_outputs:
817
- # send the outputs from the last round to let the next stage worker run post processing
818
- self.pp_group.send_tensor_dict(
819
- pp_outputs.tensors,
820
- all_gather_group=self.attn_tp_group,
821
- )
822
- if ENABLE_RELEASE:
823
- if release_rids is not None:
824
- self.send_pyobj_to_next_stage(release_rids)
825
-
826
- if not self.pp_group.is_last_rank:
827
- # send out reqs to the next stage
828
- self.send_pyobj_to_next_stage(recv_reqs)
829
- self.send_pyobj_to_next_stage(bootstrapped_rids)
830
- self.send_pyobj_to_next_stage(transferred_rids)
831
-
832
- # send out proxy tensors to the next stage
833
- if self.cur_batch:
834
- # FIXME(lsyin): remove this assert
835
- assert result.pp_hidden_states_proxy_tensors.tensors is not None
836
- self.pp_group.send_tensor_dict(
837
- result.pp_hidden_states_proxy_tensors.tensors,
838
- all_gather_group=self.attn_tp_group,
839
- )
840
-
841
- pp_outputs = next_pp_outputs
842
- release_rids = next_release_rids
843
-
844
- self.running_batch.batch_is_full = False
845
-
846
- if not ENABLE_RELEASE:
847
- if len(self.disagg_prefill_inflight_queue) > 0:
848
- self.process_disagg_prefill_inflight_queue()
849
-
850
- # When the server is idle, self-check and re-init some states
851
- if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
852
- self.check_memory()
853
- self.check_tree_cache()
854
- self.new_token_ratio = self.init_new_token_ratio
680
+ req.disagg_kv_sender.send(page_indices, state_indices)
855
681
 
856
682
  def send_pyobj_to_next_stage(self, data):
857
683
  if self.attn_tp_rank == 0:
@@ -3,13 +3,13 @@ MiB = 1024 * 1024
3
3
  SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
4
4
  9: {
5
5
  2: 64 * MiB, # 64 MB
6
- 4: 32 * MiB, # 32 MB
7
- 6: 64 * MiB, # 64 MB
8
- 8: 64 * MiB, # 64 MB
6
+ 4: 64 * MiB, # 64 MB
7
+ 6: 128 * MiB, # 128 MB
8
+ 8: 128 * MiB, # 128 MB
9
9
  },
10
10
  10: {
11
11
  2: 64 * MiB, # 64 MB
12
- 4: 32 * MiB, # 32 MB
12
+ 4: 64 * MiB, # 64 MB
13
13
  6: 128 * MiB, # 128 MB
14
14
  8: 128 * MiB, # 128 MB
15
15
  },
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
18
18
  is_weak_contiguous,
19
19
  )
20
20
  from sglang.srt.distributed.parallel_state import in_the_same_node_as
21
- from sglang.srt.utils import is_cuda, is_hip
21
+ from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
@@ -32,7 +32,7 @@ try:
32
32
  ops.meta_size()
33
33
  else:
34
34
  # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
35
- import sgl_kernel
35
+ import sgl_kernel # noqa: F401
36
36
  custom_ar = True
37
37
  except Exception:
38
38
  # For CPUs
@@ -185,7 +185,7 @@ class CustomAllreduce:
185
185
  # is enough for 131072 such tuples. The largest model I've seen only
186
186
  # needs less than 10000 of registered tuples.
187
187
  self.rank_data = torch.empty(
188
- 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
188
+ max_size, dtype=torch.uint8, device=self.device
189
189
  )
190
190
  self._ptr = ops.init_custom_ar(
191
191
  self.meta_ptrs, self.rank_data, rank, self.full_nvlink
@@ -202,7 +202,7 @@ class CustomAllreduce:
202
202
  )
203
203
  handles, offsets = self._gather_ipc_meta(shard_data)
204
204
  self.rank_data = torch.empty(
205
- 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
205
+ max_size, dtype=torch.uint8, device=self.device
206
206
  )
207
207
  self._ptr = ops.init_custom_ar(
208
208
  self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
@@ -301,11 +301,11 @@ class CustomAllreduce:
301
301
  if _is_hip:
302
302
  handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
303
303
  handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
304
- logger.info("Registering %d cuda graph addresses", len(offset))
304
+ log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
305
305
  ops.register_graph_buffers(self._ptr, handles, offsets)
306
306
  else:
307
307
  handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
308
- logger.info("Registering %d cuda graph addresses", len(offset))
308
+ log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
309
309
  # We cannot directly use `dist.all_gather_object` here
310
310
  # because it is incompatible with `gloo` backend under inference mode.
311
311
  # see https://github.com/pytorch/pytorch/issues/126032 for details.
@@ -4,7 +4,7 @@ import math
4
4
  import os
5
5
  from contextlib import contextmanager
6
6
  from enum import IntEnum
7
- from typing import Any, Callable, List, Optional, TypeVar, Union
7
+ from typing import Optional, Union
8
8
 
9
9
  import torch
10
10
  import torch.distributed as dist
@@ -24,7 +24,7 @@ if _is_hip:
24
24
  mscclpp_is_available = False
25
25
  if _is_cuda:
26
26
  try:
27
- import sgl_kernel
27
+ import sgl_kernel # noqa: F401
28
28
 
29
29
  mscclpp_is_available = True
30
30
  except:
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
30
30
  group: Union[ProcessGroup, StatelessProcessGroup],
31
31
  device: Union[int, str, torch.device],
32
32
  library_path: Optional[str] = None,
33
+ use_current_stream: bool = False,
33
34
  ):
34
35
  """
35
36
  Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
74
75
 
75
76
  self.available = True
76
77
  self.disabled = False
78
+ self.use_current_stream = use_current_stream
77
79
 
78
80
  self.nccl_version = self.nccl.ncclGetRawVersion()
79
81
  if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
123
125
  # when we are using CUDA graph.
124
126
  self.disabled = True
125
127
 
128
+ def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
129
+ """Return the stream to use for NCCL calls.
130
+
131
+ Behavior mirrors the previous inline logic:
132
+ - if an explicit stream is provided, return it
133
+ - if stream is None and self.use_current_stream is True, return
134
+ torch.cuda.current_stream()
135
+ - otherwise return the communicator's default stream (self.stream)
136
+ """
137
+ if stream is not None:
138
+ return stream
139
+ if self.use_current_stream:
140
+ return torch.cuda.current_stream()
141
+ return self.stream
142
+
126
143
  def all_reduce(
127
144
  self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
128
145
  ):
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
135
152
  f"this nccl communicator is created to work on {self.device}, "
136
153
  f"but the input tensor is on {tensor.device}"
137
154
  )
138
- if stream is None:
139
- stream = self.stream
155
+ stream = self._resolve_stream(stream)
140
156
  self.nccl.ncclAllReduce(
141
157
  buffer_type(tensor.data_ptr()),
142
158
  buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
163
179
  f"this nccl communicator is created to work on {self.device}, "
164
180
  f"but the input tensor is on {input_tensor.device}"
165
181
  )
166
- if stream is None:
167
- stream = self.stream
182
+ stream = self._resolve_stream(stream)
168
183
 
169
184
  if sizes is not None:
170
185
  split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
210
225
  f"this nccl communicator is created to work on {self.device}, "
211
226
  f"but the input tensor is on {input_tensor.device}"
212
227
  )
213
- if stream is None:
214
- stream = self.stream
228
+ stream = self._resolve_stream(stream)
215
229
 
216
230
  if sizes is not None:
217
231
  split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
249
263
  f"this nccl communicator is created to work on {self.device}, "
250
264
  f"but the input tensor is on {tensor.device}"
251
265
  )
252
- if stream is None:
253
- stream = self.stream
266
+ stream = self._resolve_stream(stream)
254
267
  self.nccl.ncclSend(
255
268
  buffer_type(tensor.data_ptr()),
256
269
  tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
267
280
  f"this nccl communicator is created to work on {self.device}, "
268
281
  f"but the input tensor is on {tensor.device}"
269
282
  )
270
- if stream is None:
271
- stream = self.stream
283
+ stream = self._resolve_stream(stream)
272
284
  self.nccl.ncclRecv(
273
285
  buffer_type(tensor.data_ptr()),
274
286
  tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
285
297
  f"this nccl communicator is created to work on {self.device}, "
286
298
  f"but the input tensor is on {tensor.device}"
287
299
  )
288
- if stream is None:
289
- stream = self.stream
300
+ stream = self._resolve_stream(stream)
301
+
290
302
  if src == self.rank:
291
303
  sendbuff = buffer_type(tensor.data_ptr())
292
304
  # NCCL requires the sender also to have a receive buffer
@@ -5,7 +5,7 @@ from packaging import version
5
5
  from torch.cuda.memory import CUDAPluggableAllocator
6
6
 
7
7
  from sglang.srt.distributed.parallel_state import GroupCoordinator
8
- from sglang.srt.managers.schedule_batch import global_server_args_dict
8
+ from sglang.srt.server_args import get_global_server_args
9
9
 
10
10
  nccl_allocator_source = """
11
11
  #include <nccl.h>
@@ -32,7 +32,7 @@ _graph_pool_id = None
32
32
 
33
33
 
34
34
  def is_symmetric_memory_enabled():
35
- return global_server_args_dict["enable_symm_mem"]
35
+ return get_global_server_args().enable_symm_mem
36
36
 
37
37
 
38
38
  def set_graph_pool_id(graph_pool_id):
@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
9
9
  from sglang.srt.distributed.device_communicators.all_reduce_utils import (
10
10
  SYMM_MEM_ALL_REDUCE_MAX_SIZES,
11
11
  )
12
- from sglang.srt.utils import get_device_capability, is_cuda, is_hip
12
+ from sglang.srt.utils import is_cuda, is_hip
13
13
 
14
14
  try:
15
15
  import torch.distributed._symmetric_memory as torch_symm_mem