sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,10 @@ import inspect
 import json
 import logging
 import os
+import socket
+import threading
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -27,17 +30,24 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -53,6 +63,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -83,16 +97,23 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.offloader import (
@@ -100,7 +121,6 @@ from sglang.srt.offloader import (
     get_offloader,
     set_offloader,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -121,15 +141,38 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
+    slow_rank_detector,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )
 
+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
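
Note: the hunk above replaces the former hard-coded MLA backend check with a module-level MLA_ATTENTION_BACKENDS list plus an add_mla_attention_backend() helper. A minimal usage sketch, illustrative only; the backend name below is made up and not part of this release:

from sglang.srt.model_executor import model_runner

# Register an extra MLA-capable backend name before the server picks a backend.
# "my_custom_mla" is a hypothetical name used only for illustration.
model_runner.add_mla_attention_backend("my_custom_mla")
assert "my_custom_mla" in model_runner.MLA_ATTENTION_BACKENDS
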
@@ -143,6 +186,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
 
@@ -237,6 +287,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
 
+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -251,6 +304,7 @@ class ModelRunner:
 
         # For weight updates
         self._model_update_group = {}
+        self._weights_send_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,6 +354,25 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
+        if self.is_hybrid_gdn:
+            logger.warning("Hybrid GDN model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            self.server_args.max_mamba_cache_size = (
+                self.server_args.max_mamba_cache_size
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
@@ -341,6 +414,20 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()
 
+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -351,12 +438,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()
 
         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -452,9 +539,7 @@ class ModelRunner:
                 elif _is_hip:
                     head_num = self.model_config.get_num_kv_heads(self.tp_size)
                     # TODO current aiter only support head number 16 or 128 head number
-                    if (
-                        head_num == 128 or head_num == 16
-                    ) and self.spec_algorithm.is_none():
+                    if head_num == 128 or head_num == 16:
                         server_args.attention_backend = "aiter"
                     else:
                         server_args.attention_backend = "triton"
@@ -467,16 +552,7 @@ class ModelRunner:
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -506,11 +582,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -548,7 +619,7 @@ class ModelRunner:
                 server_args.hicache_io_backend = "direct"
                 logger.warning(
                     "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                 )
 
     def init_torch_distributed(self):
@@ -583,6 +654,7 @@ class ModelRunner:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
 
         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -593,6 +665,11 @@ class ModelRunner:
                 # Set local size to hint SGLang to use shared memory based AllReduce
                 os.environ["LOCAL_SIZE"] = str(self.tp_size)
                 torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+                @torch.library.register_fake("sgl_kernel::shm_allgather")
+                def _(data, dim):
+                    return torch.cat([data] * self.tp_size, dim=dim)
+
             else:
                 logger.warning(
                     "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -625,6 +702,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
@@ -673,6 +751,10 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -681,16 +763,33 @@ class ModelRunner:
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
 
+        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+            if self.tp_rank == 0:
+                instance_ip = socket.gethostbyname(socket.gethostname())
+                t = threading.Thread(
+                    target=trigger_init_weights_send_group_for_remote_instance_request,
+                    args=(
+                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                        instance_ip,
+                    ),
+                )
+                t.start()
+
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
+                device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -822,6 +921,103 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
 
+    def init_weights_send_group_for_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_rank,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        logger.info(
+            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            self._weights_send_group[group_name] = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{group_port}",
+                world_size=world_size,
+                rank=group_rank,
+                group_name=group_name,
+                device_id=torch.device("cuda", self.gpu_id),
+            )
+            dist.barrier(group=self._weights_send_group[group_name])
+            success = True
+            message = (
+                f"Succeeded to init group through {master_address}:{group_port} group."
+            )
+        except Exception as e:
+            message = f"Failed to init group: {e}."
+            logger.error(message)
+
+        torch.cuda.empty_cache()
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_name,
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        if self._weights_send_group[group_name] is not None:
+            send_group = self._weights_send_group[group_name]
+        else:
+            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+            logger.error(message)
+            return False, message
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            for _, weights in self.model.named_parameters():
+                torch.distributed.broadcast(
+                    weights,
+                    src=0,
+                    group=send_group,
+                )
+            success = True
+            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+        except Exception as e:
+            message = f"Failed to send weights: {e}."
+            logger.error(message)
+
+        # destroy the process group after sending weights
+        del self._weights_send_group[group_name]
+        torch.distributed.distributed_c10d.destroy_process_group(send_group)
+        torch.cuda.empty_cache()
+        return success, message
+
     def init_weights_update_group(
         self,
         master_address,
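
Note: the two methods added above implement only the sender (seed-instance) half of the remote-instance weight transfer; rank 0 broadcasts every parameter over a temporary process group and then destroys the group. The receiving side is not part of this file. A minimal sketch of the mirror-image loop a receiver might run, assuming it has already joined the same group at a non-zero rank (the function name is hypothetical):

import torch.distributed as dist

def recv_weights_from_seed_instance(model, recv_group):
    # Mirrors send_weights_to_remote_instance: parameters arrive in the same
    # named_parameters() order, with the seed instance acting as src=0.
    for _, weights in model.named_parameters():
        dist.broadcast(weights, src=0, group=recv_group)
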
@@ -867,6 +1063,19 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
@@ -904,7 +1113,7 @@ class ModelRunner:
                 handle.wait()
 
             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
@@ -1008,6 +1217,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
         )
 
     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1057,6 +1267,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif self.is_hybrid_gdn:
+            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
         else:
            num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1076,9 +1288,23 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if self.is_hybrid_gdn:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * self.model_config.hf_config.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
+    @property
+    def is_hybrid_gdn(self):
+        return self.model_config.hf_config.architectures[0] in [
+            "Qwen3NextForCausalLM",
+            "Qwen3NextForCausalLMMTP",
+            "FalconH1ForCausalLM",
+        ]
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1169,7 +1395,18 @@ class ModelRunner:
     ):
         # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            quant_config = getattr(self.model, "quant_config", None)
+            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+            if (
+                isinstance(kv_cache_quant_algo, str)
+                and kv_cache_quant_algo.upper() == "FP8"
+            ):
+                if _is_hip:
+                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                else:
+                    self.kv_cache_dtype = torch.float8_e4m3fn
+            else:
+                self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1185,6 +1422,8 @@ class ModelRunner:
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
             )
 
+        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if SGLANG_CI_SMALL_KV_SIZE:
             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1199,8 +1438,10 @@ class ModelRunner:
                 ),
                 4096,
             )
+        if self.is_hybrid_gdn:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
-        if not self.spec_algorithm.is_none():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
@@ -1237,13 +1478,24 @@ class ModelRunner:
                 // self.server_args.page_size
                 * self.server_args.page_size
             )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()
 
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enough memory. Please try to increase --mem-fraction-static."
+                f"Not enough memory. Please try to increase --mem-fraction-static. "
+                f"Current value: {self.server_args.mem_fraction_static=}"
             )
 
         # Initialize req_to_token_pool
@@ -1267,6 +1519,28 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
+        elif self.is_hybrid_gdn:
+            config = self.model_config.hf_config
+            (
+                conv_state_shape,
+                temporal_state_shape,
+                conv_dtype,
+                ssm_dtype,
+                mamba_layers,
+            ) = config.hybrid_gdn_params
+            self.req_to_token_pool = HybridReqToTokenPool(
+                size=max_num_reqs,
+                max_context_len=self.model_config.context_len
+                + extra_max_context_len,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                conv_state_shape=conv_state_shape,
+                temporal_state_shape=temporal_state_shape,
+                conv_dtype=conv_dtype,
+                ssm_dtype=ssm_dtype,
+                mamba_layers=mamba_layers,
+                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            )
         else:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs,
@@ -1280,6 +1554,7 @@ class ModelRunner:
             assert self.is_draft_worker
 
         # Initialize token_to_kv_pool
+        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1288,6 +1563,7 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     kv_lora_rank=self.model_config.kv_lora_rank,
                     qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    index_head_dim=self.model_config.index_head_dim,
                     layer_num=self.num_effective_layers,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1307,7 +1583,22 @@ class ModelRunner:
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
+        elif self.use_mla_backend and is_nsa_model:
+            self.token_to_kv_pool = NSATokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
         elif self.use_mla_backend:
+            assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -1349,6 +1640,24 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
+        elif self.is_hybrid_gdn:
+            self.token_to_kv_pool = HybridLinearKVPool(
+                page_size=self.page_size,
+                size=self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                # if draft worker, we only need 1 attention layer's kv pool
+                full_attention_layer_ids=(
+                    [0]
+                    if self.is_draft_worker
+                    else self.model_config.hf_config.full_attention_layer_ids
+                ),
+                enable_kvcache_transpose=False,
+                device=self.device,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1368,7 +1677,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if self.server_args.attention_backend == "ascend":
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1462,8 +1773,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
-                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
@@ -1479,111 +1790,10 @@ class ModelRunner:
         return attn_backend
 
     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str == "flashinfer":
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError(
-                    "trtllm_mha backend can only be used with non-MLA models."
-                )
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
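
Note: _get_attention_backend_from_str() above collapses the former if/elif chain into a lookup in ATTENTION_BACKENDS (imported from the new attention_registry module) followed by attn_backend_wrapper(). A minimal sketch of the registry pattern this implies, assuming a decorator-based design; the real attention_registry.py added in this release may differ:

ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    # Store a factory under its backend name so selection becomes a dict lookup.
    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator

@register_attention_backend("triton")
def create_triton_backend(runner):
    from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
    return TritonAttnBackend(runner)

# ModelRunner then only needs: ATTENTION_BACKENDS[backend_str](self)
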
@@ -1603,38 +1813,46 @@ class ModelRunner:
         )
 
     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0
 
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
            return
 
         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1651,11 +1869,22 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
-        from sglang.srt.model_parallel import tensor_parallel
+        from sglang.srt.layers.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
@@ -1771,18 +2000,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph
 
         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1811,10 +2046,13 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)
 
-        return ret, can_run_cuda_graph
+        return ret, can_run_graph
 
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
@@ -1852,7 +2090,6 @@ class ModelRunner:
             )
 
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
-
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -1860,9 +2097,47 @@ class ModelRunner:
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.