sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ Life cycle of a request in the prefill server
20
20
  from __future__ import annotations
21
21
 
22
22
  import logging
23
- import threading
23
+ import time
24
24
  from collections import deque
25
25
  from http import HTTPStatus
26
26
  from typing import TYPE_CHECKING, List, Optional, Type
@@ -42,14 +42,18 @@ from sglang.srt.disaggregation.utils import (
42
42
  poll_and_all_reduce,
43
43
  prepare_abort,
44
44
  )
45
- from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
47
- from sglang.srt.utils import (
48
- DynamicGradMode,
49
- broadcast_pyobj,
50
- point_to_point_pyobj,
51
- require_mlp_sync,
45
+ from sglang.srt.managers.schedule_batch import (
46
+ FINISH_LENGTH,
47
+ Req,
48
+ RequestStage,
49
+ ScheduleBatch,
52
50
  )
51
+ from sglang.srt.mem_cache.memory_pool import (
52
+ HybridLinearKVPool,
53
+ NSATokenToKVPool,
54
+ SWAKVPool,
55
+ )
56
+ from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
53
57
 
54
58
  if TYPE_CHECKING:
55
59
  from torch.distributed import ProcessGroup
@@ -140,6 +144,28 @@ class PrefillBootstrapQueue:
140
144
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
141
145
  kv_args.gpu_id = self.scheduler.gpu_id
142
146
 
147
+ if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
148
+ state_data_ptrs, state_data_lens, state_item_lens = (
149
+ self.token_to_kv_pool.get_state_buf_infos()
150
+ )
151
+ kv_args.state_data_ptrs = state_data_ptrs
152
+ kv_args.state_data_lens = state_data_lens
153
+ kv_args.state_item_lens = state_item_lens
154
+
155
+ if isinstance(self.token_to_kv_pool, SWAKVPool):
156
+ kv_args.state_type = "swa"
157
+ elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
158
+ kv_args.state_type = "mamba"
159
+ elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
160
+ kv_args.state_type = "nsa"
161
+ else:
162
+ kv_args.state_type = "none"
163
+ else:
164
+ kv_args.state_data_ptrs = []
165
+ kv_args.state_data_lens = []
166
+ kv_args.state_item_lens = []
167
+ kv_args.state_type = "none"
168
+
143
169
  kv_manager_class: Type[BaseKVManager] = get_kv_class(
144
170
  self.transfer_backend, KVClassType.MANAGER
145
171
  )
@@ -170,6 +196,7 @@ class PrefillBootstrapQueue:
170
196
  pp_rank=self.pp_rank,
171
197
  )
172
198
  self._process_req(req)
199
+ req.add_latency(RequestStage.PREFILL_PREPARE)
173
200
  self.queue.append(req)
174
201
 
175
202
  def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -256,8 +283,11 @@ class PrefillBootstrapQueue:
256
283
 
257
284
  num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
258
285
  req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
286
+
259
287
  bootstrapped_reqs.append(req)
260
288
  indices_to_remove.add(i)
289
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
290
+ req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
261
291
 
262
292
  self.queue = [
263
293
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -322,30 +352,21 @@ class SchedulerDisaggregationPrefillMixin:
322
352
  if require_mlp_sync(self.server_args):
323
353
  batch = self.prepare_mlp_sync_batch(batch)
324
354
  self.cur_batch = batch
355
+
356
+ batch_result = None
325
357
  if batch:
326
- result = self.run_batch(batch)
327
- self.result_queue.append((batch.copy(), result))
328
-
329
- if self.last_batch is None:
330
- # Create a dummy first batch to start the pipeline for overlap schedule.
331
- # It is now used for triggering the sampling_info_done event.
332
- tmp_batch = ScheduleBatch(
333
- reqs=None,
334
- forward_mode=ForwardMode.DUMMY_FIRST,
335
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
336
- )
337
- self.set_next_batch_sampling_info_done(tmp_batch)
358
+ batch_result = self.run_batch(batch)
359
+ self.result_queue.append((batch.copy(), batch_result))
338
360
 
339
361
  if self.last_batch:
340
362
  tmp_batch, tmp_result = self.result_queue.popleft()
341
- tmp_batch.next_batch_sampling_info = (
342
- self.tp_worker.cur_sampling_info if batch else None
343
- )
344
363
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
345
364
 
346
365
  if len(self.disagg_prefill_inflight_queue) > 0:
347
366
  self.process_disagg_prefill_inflight_queue()
348
367
 
368
+ self.launch_batch_sample_if_needed(batch_result)
369
+
349
370
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
350
371
  self.self_check_during_idle()
351
372
 
@@ -358,7 +379,6 @@ class SchedulerDisaggregationPrefillMixin:
358
379
  self: Scheduler,
359
380
  batch: ScheduleBatch,
360
381
  result: GenerationBatchResult,
361
- launch_done: Optional[threading.Event] = None,
362
382
  ) -> None:
363
383
  """
364
384
  Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -369,53 +389,47 @@ class SchedulerDisaggregationPrefillMixin:
369
389
  next_token_ids,
370
390
  extend_input_len_per_req,
371
391
  extend_logprob_start_len_per_req,
392
+ copy_done,
372
393
  ) = (
373
394
  result.logits_output,
374
395
  result.next_token_ids,
375
396
  result.extend_input_len_per_req,
376
397
  result.extend_logprob_start_len_per_req,
398
+ result.copy_done,
377
399
  )
378
400
 
401
+ if copy_done is not None:
402
+ copy_done.synchronize()
403
+
379
404
  logprob_pt = 0
380
405
  # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
381
- if self.enable_overlap:
382
- # wait
383
- logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
384
- launch_done
385
- )
386
- else:
387
- next_token_ids = result.next_token_ids.tolist()
388
- if batch.return_logprob:
389
- if logits_output.next_token_logprobs is not None:
390
- logits_output.next_token_logprobs = (
391
- logits_output.next_token_logprobs.tolist()
392
- )
393
- if logits_output.input_token_logprobs is not None:
394
- logits_output.input_token_logprobs = tuple(
395
- logits_output.input_token_logprobs.tolist()
396
- )
406
+ next_token_ids = result.next_token_ids.tolist()
407
+ if batch.return_logprob:
408
+ if logits_output.next_token_logprobs is not None:
409
+ logits_output.next_token_logprobs = (
410
+ logits_output.next_token_logprobs.tolist()
411
+ )
412
+ if logits_output.input_token_logprobs is not None:
413
+ logits_output.input_token_logprobs = tuple(
414
+ logits_output.input_token_logprobs.tolist()
415
+ )
397
416
 
398
417
  hidden_state_offset = 0
399
418
  for i, (req, next_token_id) in enumerate(
400
419
  zip(batch.reqs, next_token_ids, strict=True)
401
420
  ):
402
- req: Req
403
421
  if req.is_chunked <= 0:
404
422
  # There is no output_ids for prefill
405
423
  req.output_ids.append(next_token_id)
406
424
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
425
+ req.add_latency(RequestStage.PREFILL_FORWARD)
407
426
  self.disagg_prefill_inflight_queue.append(req)
408
- if (
409
- logits_output is not None
410
- and logits_output.hidden_states is not None
411
- ):
412
- last_hidden_index = (
413
- hidden_state_offset + extend_input_len_per_req[i] - 1
414
- )
427
+ if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
428
+ req.output_topk_p = batch.spec_info.topk_p[i]
429
+ req.output_topk_index = batch.spec_info.topk_index[i]
415
430
  req.hidden_states_tensor = (
416
- logits_output.hidden_states[last_hidden_index].cpu().clone()
431
+ batch.spec_info.hidden_states[i].cpu().clone()
417
432
  )
418
- hidden_state_offset += extend_input_len_per_req[i]
419
433
  else:
420
434
  req.hidden_states_tensor = None
421
435
  if req.return_logprob:
@@ -434,6 +448,7 @@ class SchedulerDisaggregationPrefillMixin:
434
448
  )
435
449
  logprob_pt += num_input_logprobs
436
450
  self.send_kv_chunk(req, last_chunk=True)
451
+ req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
437
452
 
438
453
  if req.grammar is not None:
439
454
  # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -473,8 +488,6 @@ class SchedulerDisaggregationPrefillMixin:
473
488
  if self.enable_overlap:
474
489
  self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
475
490
 
476
- # We need to remove the sync in the following function for overlap schedule.
477
- self.set_next_batch_sampling_info_done(batch)
478
491
  self.maybe_send_health_check_signal()
479
492
 
480
493
  def process_disagg_prefill_inflight_queue(
@@ -531,6 +544,9 @@ class SchedulerDisaggregationPrefillMixin:
531
544
  else:
532
545
  assert False, f"Unexpected polling state {poll=}"
533
546
 
547
+ for req in done_reqs:
548
+ req.time_stats.completion_time = time.perf_counter()
549
+
534
550
  # Stream requests which have finished transfer
535
551
  self.stream_output(
536
552
  done_reqs,
@@ -539,6 +555,7 @@ class SchedulerDisaggregationPrefillMixin:
539
555
  )
540
556
  for req in done_reqs:
541
557
  req: Req
558
+ req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
542
559
  self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
543
560
  req.metadata_buffer_index = -1
544
561
 
@@ -609,232 +626,58 @@ class SchedulerDisaggregationPrefillMixin:
609
626
  .numpy()
610
627
  )
611
628
  req.start_send_idx = end_idx
629
+ state_indices = None
612
630
  if last_chunk:
613
631
  self.disagg_metadata_buffers.set_buf(req)
632
+
633
+ # Prepare extra pool indices for hybrid models
634
+ if isinstance(
635
+ self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
636
+ ):
637
+ # Mamba hybrid model: send single mamba state index
638
+ state_indices = [
639
+ self.req_to_token_pool.req_index_to_mamba_index_mapping[
640
+ req.req_pool_idx
641
+ ]
642
+ .cpu()
643
+ .numpy()
644
+ ]
645
+ elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
646
+ # SWA hybrid model: send last window KV indices
647
+ seq_len = len(req.fill_ids)
648
+ window_size = self.sliding_window_size
649
+ window_start = max(0, seq_len - window_size)
650
+ window_start = (window_start // page_size) * page_size
651
+
652
+ window_kv_indices_full = self.req_to_token_pool.req_to_token[
653
+ req.req_pool_idx, window_start:seq_len
654
+ ]
655
+
656
+ # Translate to SWA pool indices
657
+ window_kv_indices_swa = (
658
+ self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
659
+ window_kv_indices_full
660
+ )
661
+ )
662
+ state_indices = window_kv_indices_swa.cpu().numpy()
663
+ state_indices = kv_to_page_indices(state_indices, page_size)
664
+ elif isinstance(
665
+ self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
666
+ ):
667
+ seq_len = len(req.fill_ids)
668
+ kv_indices_full = self.req_to_token_pool.req_to_token[
669
+ req.req_pool_idx, :seq_len
670
+ ]
671
+ state_indices = kv_indices_full.cpu().numpy()
672
+ state_indices = kv_to_page_indices(state_indices, page_size)
673
+
614
674
  page_indices = kv_to_page_indices(kv_indices, page_size)
615
675
  if len(page_indices) == 0:
616
676
  logger.info(
617
677
  f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
618
678
  )
619
679
  return
620
- req.disagg_kv_sender.send(page_indices)
621
-
622
- # PP
623
- @DynamicGradMode()
624
- def event_loop_pp_disagg_prefill(self: Scheduler):
625
- """
626
- An event loop for the prefill server in pipeline parallelism.
627
-
628
- Rules:
629
- 1. Each stage runs in the same order and is notified by the previous stage.
630
- 2. Each send/recv operation is blocking and matched by the neighboring stage.
631
-
632
- Regular Schedule:
633
- ====================================================================
634
- Stage i | Stage i+1
635
- send ith req | recv ith req
636
- send ith proxy | recv ith proxy
637
- send prev (i+1)th carry | recv prev (i+1)th carry
638
- ====================================================================
639
-
640
- Prefill Server Schedule:
641
- ====================================================================
642
- Stage i | Stage i+1
643
- send ith req | recv ith req
644
- send ith bootstrap req | recv ith bootstrap req
645
- send ith transferred req | recv ith transferred req
646
- send ith proxy | recv ith proxy
647
- send prev (i+1)th carry | recv prev (i+1)th carry
648
- send prev (i+1)th release req | recv prev (i+1)th release req
649
- ====================================================================
650
-
651
- There are two additional elements compared to the regular schedule:
652
-
653
- 1. Bootstrap Requests:
654
- a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
655
- b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
656
- c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
657
-
658
- 2. Transferred Requests + Release Requests:
659
- a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
660
- b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
661
- c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
662
- """
663
- from sglang.srt.managers.scheduler import GenerationBatchResult
664
-
665
- mbs = [None] * self.pp_size
666
- last_mbs = [None] * self.pp_size
667
- self.running_mbs = [
668
- ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
669
- ]
670
- bids = [None] * self.pp_size
671
- pp_outputs: Optional[PPProxyTensors] = None
672
-
673
- # Either success or failed
674
- bootstrapped_rids: List[str] = []
675
- transferred_rids: List[str] = []
676
- release_rids: Optional[List[str]] = None
677
-
678
- # transferred microbatch
679
- tmbs = [None] * self.pp_size
680
-
681
- ENABLE_RELEASE = True # For debug
682
-
683
- while True:
684
- server_is_idle = True
685
-
686
- for mb_id in range(self.pp_size):
687
- self.running_batch = self.running_mbs[mb_id]
688
- self.last_batch = last_mbs[mb_id]
689
-
690
- recv_reqs = self.recv_requests()
691
-
692
- self.process_input_requests(recv_reqs)
693
-
694
- if self.pp_group.is_first_rank:
695
- # First rank, pop the bootstrap reqs from the bootstrap queue
696
- bootstrapped_reqs, failed_reqs = (
697
- self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
698
- return_failed_reqs=True
699
- )
700
- )
701
- bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
702
- req.rid for req in failed_reqs
703
- ]
704
- self.waiting_queue.extend(bootstrapped_reqs)
705
- else:
706
- # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
707
- bootstrapped_rids = self.recv_pyobj_from_prev_stage()
708
- bootstrapped_reqs = (
709
- self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
710
- rids_to_check=bootstrapped_rids
711
- )
712
- )
713
- self.waiting_queue.extend(bootstrapped_reqs)
714
-
715
- if self.pp_group.is_first_rank:
716
- transferred_rids = self.get_transferred_rids()
717
- # if other ranks,
718
- else:
719
- # 1. recv previous stage's transferred reqs info
720
- prev_transferred_rids = self.recv_pyobj_from_prev_stage()
721
- # 2. get the current stage's transferred reqs info
722
- curr_transferred_rids = self.get_transferred_rids()
723
- # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
724
- transferred_rids = list(
725
- set(prev_transferred_rids) & set(curr_transferred_rids)
726
- )
727
-
728
- tmbs[mb_id] = transferred_rids
729
-
730
- self.process_prefill_chunk()
731
- mbs[mb_id] = self.get_new_batch_prefill()
732
- self.running_mbs[mb_id] = self.running_batch
733
-
734
- self.cur_batch = mbs[mb_id]
735
- if self.cur_batch:
736
- server_is_idle = False
737
- result = self.run_batch(self.cur_batch)
738
-
739
- # send the outputs to the next step
740
- if self.pp_group.is_last_rank:
741
- if self.cur_batch:
742
- next_token_ids, bids[mb_id] = (
743
- result.next_token_ids,
744
- result.bid,
745
- )
746
- pp_outputs = PPProxyTensors(
747
- {
748
- "next_token_ids": next_token_ids,
749
- }
750
- )
751
- # send the output from the last round to let the next stage worker run post processing
752
- self.pp_group.send_tensor_dict(
753
- pp_outputs.tensors,
754
- all_gather_group=self.attn_tp_group,
755
- )
756
-
757
- if ENABLE_RELEASE:
758
- if self.pp_group.is_last_rank:
759
- # At the last stage, all stages has reached the consensus to release memory for transferred_rids
760
- release_rids = transferred_rids
761
- # send to the first rank
762
- self.send_pyobj_to_next_stage(release_rids)
763
-
764
- # receive outputs and post-process (filter finished reqs) the coming microbatch
765
- next_mb_id = (mb_id + 1) % self.pp_size
766
- next_pp_outputs = None
767
- next_release_rids = None
768
-
769
- if mbs[next_mb_id] is not None:
770
- next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
771
- self.pp_group.recv_tensor_dict(
772
- all_gather_group=self.attn_tp_group
773
- )
774
- )
775
- mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
776
- output_result = GenerationBatchResult(
777
- logits_output=None,
778
- pp_hidden_states_proxy_tensors=None,
779
- next_token_ids=next_pp_outputs["next_token_ids"],
780
- extend_input_len_per_req=None,
781
- extend_logprob_start_len_per_req=None,
782
- bid=bids[next_mb_id],
783
- can_run_cuda_graph=result.can_run_cuda_graph,
784
- )
785
- self.process_batch_result_disagg_prefill(
786
- mbs[next_mb_id], output_result
787
- )
788
-
789
- last_mbs[next_mb_id] = mbs[next_mb_id]
790
-
791
- if ENABLE_RELEASE:
792
- if tmbs[next_mb_id] is not None:
793
- # recv consensus rids from the previous rank
794
- next_release_rids = self.recv_pyobj_from_prev_stage()
795
- self.process_disagg_prefill_inflight_queue(next_release_rids)
796
-
797
- # carry the outputs to the next stage
798
- if not self.pp_group.is_last_rank:
799
- if self.cur_batch:
800
- bids[mb_id] = result.bid
801
- if pp_outputs:
802
- # send the outputs from the last round to let the next stage worker run post processing
803
- self.pp_group.send_tensor_dict(
804
- pp_outputs.tensors,
805
- all_gather_group=self.attn_tp_group,
806
- )
807
- if ENABLE_RELEASE:
808
- if release_rids is not None:
809
- self.send_pyobj_to_next_stage(release_rids)
810
-
811
- if not self.pp_group.is_last_rank:
812
- # send out reqs to the next stage
813
- self.send_pyobj_to_next_stage(recv_reqs)
814
- self.send_pyobj_to_next_stage(bootstrapped_rids)
815
- self.send_pyobj_to_next_stage(transferred_rids)
816
-
817
- # send out proxy tensors to the next stage
818
- if self.cur_batch:
819
- self.pp_group.send_tensor_dict(
820
- result.pp_hidden_states_proxy_tensors,
821
- all_gather_group=self.attn_tp_group,
822
- )
823
-
824
- pp_outputs = next_pp_outputs
825
- release_rids = next_release_rids
826
-
827
- self.running_batch.batch_is_full = False
828
-
829
- if not ENABLE_RELEASE:
830
- if len(self.disagg_prefill_inflight_queue) > 0:
831
- self.process_disagg_prefill_inflight_queue()
832
-
833
- # When the server is idle, self-check and re-init some states
834
- if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
835
- self.check_memory()
836
- self.check_tree_cache()
837
- self.new_token_ratio = self.init_new_token_ratio
680
+ req.disagg_kv_sender.send(page_indices, state_indices)
838
681
 
839
682
  def send_pyobj_to_next_stage(self, data):
840
683
  if self.attn_tp_rank == 0:
@@ -5,7 +5,7 @@ import random
5
5
  from collections import deque
6
6
  from contextlib import nullcontext
7
7
  from enum import Enum
8
- from typing import TYPE_CHECKING, List, Optional, Type, Union
8
+ from typing import TYPE_CHECKING, Optional, Type
9
9
 
10
10
  import numpy as np
11
11
  import torch
@@ -85,7 +85,7 @@ class MetadataBuffers:
85
85
  self,
86
86
  size: int,
87
87
  hidden_size: int,
88
- dtype: torch.dtype,
88
+ hidden_states_dtype: torch.dtype,
89
89
  max_top_logprobs_num: int = 128,
90
90
  custom_mem_pool: torch.cuda.MemPool = None,
91
91
  ):
@@ -107,7 +107,9 @@ class MetadataBuffers:
107
107
  # We transfer the metadata of first output token to decode
108
108
  # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
109
109
  self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
110
-
110
+ self.cached_tokens = torch.zeros(
111
+ (size, 16), dtype=torch.int32, device=device
112
+ )
111
113
  self.output_token_logprobs_val = torch.zeros(
112
114
  (size, 16), dtype=torch.float32, device=device
113
115
  )
@@ -120,33 +122,49 @@ class MetadataBuffers:
120
122
  self.output_top_logprobs_idx = torch.zeros(
121
123
  (size, max_top_logprobs_num), dtype=torch.int32, device=device
122
124
  )
125
+ # For PD + spec decode
126
+ self.output_topk_p = torch.zeros(
127
+ (size, 16), dtype=torch.float32, device=device
128
+ )
129
+ self.output_topk_index = torch.zeros(
130
+ (size, 16), dtype=torch.int64, device=device
131
+ )
123
132
  self.output_hidden_states = torch.zeros(
124
- (size, hidden_size), dtype=dtype, device=device
133
+ (size, hidden_size), dtype=hidden_states_dtype, device=device
125
134
  )
126
135
 
127
136
  def get_buf_infos(self):
128
137
  ptrs = [
129
138
  self.output_ids.data_ptr(),
139
+ self.cached_tokens.data_ptr(),
130
140
  self.output_token_logprobs_val.data_ptr(),
131
141
  self.output_token_logprobs_idx.data_ptr(),
132
142
  self.output_top_logprobs_val.data_ptr(),
133
143
  self.output_top_logprobs_idx.data_ptr(),
144
+ self.output_topk_p.data_ptr(),
145
+ self.output_topk_index.data_ptr(),
134
146
  self.output_hidden_states.data_ptr(),
135
147
  ]
136
148
  data_lens = [
137
149
  self.output_ids.nbytes,
150
+ self.cached_tokens.nbytes,
138
151
  self.output_token_logprobs_val.nbytes,
139
152
  self.output_token_logprobs_idx.nbytes,
140
153
  self.output_top_logprobs_val.nbytes,
141
154
  self.output_top_logprobs_idx.nbytes,
155
+ self.output_topk_p.nbytes,
156
+ self.output_topk_index.nbytes,
142
157
  self.output_hidden_states.nbytes,
143
158
  ]
144
159
  item_lens = [
145
160
  self.output_ids[0].nbytes,
161
+ self.cached_tokens[0].nbytes,
146
162
  self.output_token_logprobs_val[0].nbytes,
147
163
  self.output_token_logprobs_idx[0].nbytes,
148
164
  self.output_top_logprobs_val[0].nbytes,
149
165
  self.output_top_logprobs_idx[0].nbytes,
166
+ self.output_topk_p[0].nbytes,
167
+ self.output_topk_index[0].nbytes,
150
168
  self.output_hidden_states[0].nbytes,
151
169
  ]
152
170
  return ptrs, data_lens, item_lens
@@ -154,16 +172,20 @@ class MetadataBuffers:
154
172
  def get_buf(self, idx: int):
155
173
  return (
156
174
  self.output_ids[idx],
175
+ self.cached_tokens[idx],
157
176
  self.output_token_logprobs_val[idx],
158
177
  self.output_token_logprobs_idx[idx],
159
178
  self.output_top_logprobs_val[idx],
160
179
  self.output_top_logprobs_idx[idx],
180
+ self.output_topk_p[idx],
181
+ self.output_topk_index[idx],
161
182
  self.output_hidden_states[idx],
162
183
  )
163
184
 
164
185
  def set_buf(self, req: Req):
165
186
 
166
187
  self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
188
+ self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
167
189
  if req.return_logprob:
168
190
  if req.output_token_logprobs_val: # not none or empty list
169
191
  self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -186,8 +208,17 @@ class MetadataBuffers:
186
208
  ] = torch.tensor(
187
209
  req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
188
210
  )
189
- # for PD + spec decode
211
+ # For PD + spec decode
190
212
  if req.hidden_states_tensor is not None:
213
+ # speculative_eagle_topk should not be greater than 16 currently
214
+ topk = req.output_topk_p.size(0)
215
+
216
+ self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
217
+ req.output_topk_p
218
+ )
219
+ self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
220
+ req.output_topk_index
221
+ )
191
222
  self.output_hidden_states[req.metadata_buffer_index].copy_(
192
223
  req.hidden_states_tensor
193
224
  )
@@ -0,0 +1,16 @@
1
+ MiB = 1024 * 1024
2
+
3
+ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
4
+ 9: {
5
+ 2: 64 * MiB, # 64 MB
6
+ 4: 64 * MiB, # 64 MB
7
+ 6: 128 * MiB, # 128 MB
8
+ 8: 128 * MiB, # 128 MB
9
+ },
10
+ 10: {
11
+ 2: 64 * MiB, # 64 MB
12
+ 4: 64 * MiB, # 64 MB
13
+ 6: 128 * MiB, # 128 MB
14
+ 8: 128 * MiB, # 128 MB
15
+ },
16
+ }