sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -19,25 +19,36 @@ import inspect
  import json
  import logging
  import os
+ import socket
+ import threading
  import time
+ from collections import defaultdict
  from dataclasses import dataclass
  from typing import List, Optional, Tuple, Union

  import torch
  import torch.distributed as dist

+ from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
  from sglang.srt.configs.device_config import DeviceConfig
- from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+ from sglang.srt.configs.load_config import LoadConfig, LoadFormat
+ from sglang.srt.configs.model_config import (
+     AttentionArch,
+     ModelConfig,
+     get_nsa_index_head_dim,
+     is_deepseek_nsa,
+ )
  from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
  from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
  from sglang.srt.distributed import (
+     get_pp_group,
      get_tp_group,
      get_world_group,
      init_distributed_environment,
      initialize_model_parallel,
      set_custom_all_reduce,
      set_mscclpp_all_reduce,
+     set_symm_mem_all_reduce,
  )
  from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
  from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -53,6 +64,10 @@ from sglang.srt.eplb.expert_location import (
      set_global_expert_location_metadata,
  )
  from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+ from sglang.srt.layers.attention.attention_registry import (
+     ATTENTION_BACKENDS,
+     attn_backend_wrapper,
+ )
  from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
  from sglang.srt.layers.dp_attention import (
      get_attention_tp_group,
@@ -83,16 +98,23 @@ from sglang.srt.mem_cache.memory_pool import (
      AscendMLAPagedTokenToKVPool,
      AscendTokenToKVPool,
      DoubleSparseTokenToKVPool,
+     HybridLinearKVPool,
+     HybridReqToTokenPool,
      MHATokenToKVPool,
      MLATokenToKVPool,
+     NSATokenToKVPool,
      ReqToTokenPool,
      SWAKVPool,
  )
+ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
  from sglang.srt.model_loader import get_model
  from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+     trigger_init_weights_send_group_for_remote_instance_request,
+ )
  from sglang.srt.model_loader.utils import set_default_torch_dtype
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.offloader import (
@@ -100,7 +122,6 @@ from sglang.srt.offloader import (
      get_offloader,
      set_offloader,
  )
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -121,15 +142,38 @@ from sglang.srt.utils import (
      is_no_spec_infer_or_topk_one,
      is_npu,
      is_sm100_supported,
+     log_info_on_rank0,
      monkey_patch_p2p_access_check,
      monkey_patch_vllm_gguf_config,
      set_cuda_arch,
+     slow_rank_detector,
  )
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
  from sglang.srt.weight_sync.tensor_bucket import (
      FlattenedTensorBucket,
      FlattenedTensorMetadata,
  )

+ MLA_ATTENTION_BACKENDS = [
+     "aiter",
+     "flashinfer",
+     "fa3",
+     "fa4",
+     "triton",
+     "flashmla",
+     "cutlass_mla",
+     "trtllm_mla",
+     "ascend",
+     "nsa",
+ ]
+
+
+ def add_mla_attention_backend(backend_name):
+     if backend_name not in MLA_ATTENTION_BACKENDS:
+         MLA_ATTENTION_BACKENDS.append(backend_name)
+         logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
  _is_hip = is_hip()
  _is_npu = is_npu()
  _is_cpu_amx_available = cpu_has_amx_support()
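The hunk above replaces the previously hard-coded list of MLA-capable backends with a module-level `MLA_ATTENTION_BACKENDS` list plus a small registration helper. As an illustration only (assuming these names live in `sglang/srt/model_executor/model_runner.py`, the file this diff appears to modify, and using a hypothetical backend name), an out-of-tree integration could mark its backend as MLA-capable like this:

```python
# Illustrative sketch, not part of the diff: register a hypothetical
# out-of-tree backend name so the MLA auto-selection logic accepts it.
from sglang.srt.model_executor.model_runner import (
    MLA_ATTENTION_BACKENDS,
    add_mla_attention_backend,
)

add_mla_attention_backend("my_custom_mla")  # hypothetical backend name
assert "my_custom_mla" in MLA_ATTENTION_BACKENDS
```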
@@ -143,6 +187,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
  logger = logging.getLogger(__name__)


+ if _is_npu:
+     import torch_npu
+
+     torch.npu.config.allow_internal_format = True
+     torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
  class RankZeroFilter(logging.Filter):
      """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

@@ -237,6 +288,9 @@ class ModelRunner:
          # CPU offload
          set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

+         if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+             slow_rank_detector.execute()
+
          # Update deep gemm configure
          if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
              deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -251,6 +305,7 @@ class ModelRunner:

          # For weight updates
          self._model_update_group = {}
+         self._weights_send_group = {}

      def initialize(self, min_per_gpu_memory: float):
          server_args = self.server_args
@@ -300,6 +355,27 @@ class ModelRunner:
          if architectures and not any("Llama4" in arch for arch in architectures):
              self.is_hybrid = self.model_config.is_hybrid = True

+         if config := self.mambaish_config:
+             class_name = config.__class__.__name__
+             logger.warning(f"{class_name} model detected, disable radix cache")
+             self.server_args.disable_radix_cache = True
+             if self.server_args.max_mamba_cache_size is None:
+                 if self.server_args.max_running_requests is not None:
+                     self.server_args.max_mamba_cache_size = (
+                         self.server_args.max_running_requests
+                     )
+                 else:
+                     self.server_args.max_mamba_cache_size = 512
+             if self.hybrid_gdn_config is not None:
+                 self.server_args.max_mamba_cache_size = (
+                     self.server_args.max_mamba_cache_size
+                     // (
+                         self.server_args.dp_size
+                         if self.server_args.enable_dp_attention
+                         else 1
+                     )
+                 )
+
          # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
          # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
          # determine the number of layers.
@@ -341,6 +417,20 @@ class ModelRunner:
          if server_args.enable_lora:
              self.init_lora_manager()

+         # Init Double Sparsity
+         if server_args.enable_double_sparsity:
+             if server_args.ds_heavy_channel_type is None:
+                 raise ValueError(
+                     "Please specify the heavy channel type for double sparsity optimization."
+                 )
+             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+         # Enable batch invariant mode
+         if server_args.enable_deterministic_inference:
+             from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+             enable_batch_invariant_mode()
+
          # Init memory pool and attention backends
          self.init_memory_pool(
              min_per_gpu_memory,
@@ -351,12 +441,12 @@ class ModelRunner:
              self.init_cublas()
              self.init_attention_backend()
              self.init_device_graphs()
-         elif self.device == "npu":
+         elif self.device in ["npu", "cpu"]:
              self.init_attention_backend()
              self.init_device_graphs()
          else:
              self.graph_runner = None
-             self.cuda_graph_mem_usage = 0
+             self.graph_mem_usage = 0
              self.init_attention_backend()

          # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -452,9 +542,7 @@ class ModelRunner:
              elif _is_hip:
                  head_num = self.model_config.get_num_kv_heads(self.tp_size)
                  # TODO current aiter only support head number 16 or 128 head number
-                 if (
-                     head_num == 128 or head_num == 16
-                 ) and self.spec_algorithm.is_none():
+                 if head_num == 128 or head_num == 16:
                      server_args.attention_backend = "aiter"
                  else:
                      server_args.attention_backend = "triton"
@@ -467,16 +555,7 @@ class ModelRunner:
              )
          elif self.use_mla_backend:
              if server_args.device != "cpu":
-                 if server_args.attention_backend in [
-                     "aiter",
-                     "flashinfer",
-                     "fa3",
-                     "triton",
-                     "flashmla",
-                     "cutlass_mla",
-                     "trtllm_mla",
-                     "ascend",
-                 ]:
+                 if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                      logger.info(
                          f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                      )
@@ -506,11 +585,6 @@ class ModelRunner:
              )
              server_args.attention_backend = "triton"
              server_args.disable_cuda_graph = True
-             if server_args.ds_heavy_channel_type is None:
-                 raise ValueError(
-                     "Please specify the heavy channel type for double sparsity optimization."
-                 )
-             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

          if self.is_multimodal:
              if not self.is_multimodal_chunked_prefill_supported:
@@ -548,7 +622,7 @@ class ModelRunner:
                  server_args.hicache_io_backend = "direct"
                  logger.warning(
                      "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                     f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                     "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                  )

      def init_torch_distributed(self):
@@ -583,6 +657,7 @@ class ModelRunner:
              dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
          set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
          set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+         set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

          if not self.is_draft_worker:
              if self.device == "cpu":
@@ -593,6 +668,11 @@ class ModelRunner:
                      # Set local size to hint SGLang to use shared memory based AllReduce
                      os.environ["LOCAL_SIZE"] = str(self.tp_size)
                      torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+                     @torch.library.register_fake("sgl_kernel::shm_allgather")
+                     def _(data, dim):
+                         return torch.cat([data] * self.tp_size, dim=dim)
+
                  else:
                      logger.warning(
                          "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -625,6 +705,7 @@ class ModelRunner:
              cpu_group=get_world_group().cpu_group,
          )
          self.tp_group = get_tp_group()
+         self.pp_group = get_pp_group()
          self.attention_tp_group = get_attention_tp_group()

          # Check memory for tensor parallelism
@@ -673,6 +754,10 @@ class ModelRunner:
              load_format=self.server_args.load_format,
              download_dir=self.server_args.download_dir,
              model_loader_extra_config=self.server_args.model_loader_extra_config,
+             tp_rank=self.tp_rank,
+             remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+             remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+             remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
          )
          if self.device == "cpu":
              self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -681,16 +766,33 @@ class ModelRunner:
          if self.server_args.load_format == "gguf":
              monkey_patch_vllm_gguf_config()

+         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+             if self.tp_rank == 0:
+                 instance_ip = socket.gethostbyname(socket.gethostname())
+                 t = threading.Thread(
+                     target=trigger_init_weights_send_group_for_remote_instance_request,
+                     args=(
+                         self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                         self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                         self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                         instance_ip,
+                     ),
+                 )
+                 t.start()
+
          # Load the model
          # Remove monkey_patch when linear.py quant remove dependencies with vllm
          monkey_patch_vllm_parallel_state()
          monkey_patch_isinstance_for_vllm_base_layer()

-         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+         with self.memory_saver_adapter.region(
+             GPU_MEMORY_TYPE_WEIGHTS,
+             enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+         ):
              self.model = get_model(
                  model_config=self.model_config,
                  load_config=self.load_config,
-                 device_config=DeviceConfig(self.device),
+                 device_config=DeviceConfig(self.device, self.gpu_id),
              )
          monkey_patch_vllm_parallel_state(reverse=True)
          monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -781,7 +883,7 @@ class ModelRunner:
          load_config = LoadConfig(load_format=load_format)

          # Only support DefaultModelLoader for now
-         loader = get_model_loader(load_config)
+         loader = get_model_loader(load_config, self.model_config)
          if not isinstance(loader, DefaultModelLoader):
              message = f"Failed to get model loader: {loader}."
              return False, message
@@ -822,6 +924,103 @@ class ModelRunner:
          logger.info("Update weights end.")
          return True, "Succeeded to update model weights."

+     def init_weights_send_group_for_remote_instance(
+         self,
+         master_address,
+         ports,
+         group_rank,
+         world_size,
+         group_name,
+         backend="nccl",
+     ):
+         assert (
+             torch.distributed.is_initialized()
+         ), "Default torch process group must be initialized"
+         assert group_name != "", "Group name cannot be empty"
+
+         ports_list = ports.split(",")
+         assert (
+             len(ports_list) == self.tp_size
+         ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+         group_port = ports_list[self.tp_rank]
+         group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+         logger.info(
+             f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+             f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+         )
+
+         torch.cuda.empty_cache()
+         success = False
+         message = ""
+         try:
+             self._weights_send_group[group_name] = init_custom_process_group(
+                 backend=backend,
+                 init_method=f"tcp://{master_address}:{group_port}",
+                 world_size=world_size,
+                 rank=group_rank,
+                 group_name=group_name,
+                 device_id=torch.device("cuda", self.gpu_id),
+             )
+             dist.barrier(group=self._weights_send_group[group_name])
+             success = True
+             message = (
+                 f"Succeeded to init group through {master_address}:{group_port} group."
+             )
+         except Exception as e:
+             message = f"Failed to init group: {e}."
+             logger.error(message)
+
+         torch.cuda.empty_cache()
+         return success, message
+
+     def send_weights_to_remote_instance(
+         self,
+         master_address,
+         ports,
+         group_name,
+     ):
+         assert (
+             torch.distributed.is_initialized()
+         ), "Default torch process group must be initialized"
+         assert group_name != "", "Group name cannot be empty"
+
+         ports_list = ports.split(",")
+         assert (
+             len(ports_list) == self.tp_size
+         ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+         group_port = ports_list[self.tp_rank]
+         group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+         if self._weights_send_group[group_name] is not None:
+             send_group = self._weights_send_group[group_name]
+         else:
+             message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+             logger.error(message)
+             return False, message
+
+         torch.cuda.empty_cache()
+         success = False
+         message = ""
+         try:
+             for _, weights in self.model.named_parameters():
+                 torch.distributed.broadcast(
+                     weights,
+                     src=0,
+                     group=send_group,
+                 )
+             success = True
+             message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+         except Exception as e:
+             message = f"Failed to send weights: {e}."
+             logger.error(message)
+
+         # destroy the process group after sending weights
+         del self._weights_send_group[group_name]
+         torch.distributed.distributed_c10d.destroy_process_group(send_group)
+         torch.cuda.empty_cache()
+         return success, message
+
      def init_weights_update_group(
          self,
          master_address,
@@ -867,6 +1066,19 @@ class ModelRunner:
              logger.error(message)
              return False, message

+     def destroy_weights_update_group(self, group_name):
+         try:
+             if group_name in self._model_update_group:
+                 pg = self._model_update_group.pop(group_name)
+                 torch.distributed.destroy_process_group(pg)
+                 return True, "Succeeded to destroy custom process group."
+             else:
+                 return False, "The group to be destroyed does not exist."
+         except Exception as e:
+             message = f"Failed to destroy custom process group: {e}."
+             logger.error(message)
+             return False, message
+
      def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
          """
          Update specific parameter in the model weights online
@@ -904,7 +1116,7 @@ class ModelRunner:
                  handle.wait()

              self.model.load_weights(weights)
-             return True, f"Succeeded to update parameter online."
+             return True, "Succeeded to update parameter online."

          except Exception as e:
              error_msg = (
@@ -1008,6 +1220,7 @@ class ModelRunner:
              max_lora_rank=self.server_args.max_lora_rank,
              target_modules=self.server_args.lora_target_modules,
              lora_paths=self.server_args.lora_paths,
+             server_args=self.server_args,
          )

      def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1057,6 +1270,8 @@ class ModelRunner:
                  "num_nextn_predict_layers",
                  self.num_effective_layers,
              )
+         elif config := self.mambaish_config:
+             num_layers = len(config.full_attention_layer_ids)
          else:
              num_layers = self.num_effective_layers
          if self.use_mla_backend:
@@ -1065,6 +1280,17 @@ class ModelRunner:
                  * num_layers
                  * torch._utils._element_size(self.kv_cache_dtype)
              )
+             # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+             if is_deepseek_nsa(self.model_config.hf_config):
+                 index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                 indexer_size_per_token = (
+                     index_head_dim
+                     + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                 )
+                 element_size = torch._utils._element_size(
+                     NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                 )
+                 cell_size += indexer_size_per_token * num_layers * element_size
          else:
              cell_size = (
                  self.model_config.get_num_kv_heads(get_attention_tp_size())
@@ -1076,9 +1302,33 @@ class ModelRunner:
          rest_memory = available_gpu_memory - total_gpu_memory * (
              1 - self.mem_fraction_static
          )
+         if config := self.mambaish_config:
+             rest_memory -= (
+                 self.server_args.max_mamba_cache_size
+                 * config.mamba2_cache_params.mamba_cache_per_req
+                 / (1 << 30)
+             )
          max_num_token = int(rest_memory * (1 << 30) // cell_size)
          return max_num_token

+     @property
+     def hybrid_gdn_config(self):
+         config = self.model_config.hf_config
+         if isinstance(config, Qwen3NextConfig):
+             return config
+         return None
+
+     @property
+     def mamba2_config(self):
+         config = self.model_config.hf_config
+         if isinstance(config, FalconH1Config | NemotronHConfig):
+             return config
+         return None
+
+     @property
+     def mambaish_config(self):
+         return self.mamba2_config or self.hybrid_gdn_config
+
      def set_num_token_hybrid(self):
          if (
              "Llama4ForConditionalGeneration"
@@ -1169,7 +1419,18 @@ class ModelRunner:
      ):
          # Determine the kv cache dtype
          if self.server_args.kv_cache_dtype == "auto":
-             self.kv_cache_dtype = self.dtype
+             quant_config = getattr(self.model, "quant_config", None)
+             kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+             if (
+                 isinstance(kv_cache_quant_algo, str)
+                 and kv_cache_quant_algo.upper() == "FP8"
+             ):
+                 if _is_hip:
+                     self.kv_cache_dtype = torch.float8_e4m3fnuz
+                 else:
+                     self.kv_cache_dtype = torch.float8_e4m3fn
+             else:
+                 self.kv_cache_dtype = self.dtype
          elif self.server_args.kv_cache_dtype == "fp8_e5m2":
              if _is_hip:  # Using natively supported format
                  self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1185,6 +1446,8 @@ class ModelRunner:
                  f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
              )

+         log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
          self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
          if SGLANG_CI_SMALL_KV_SIZE:
              self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1199,8 +1462,10 @@ class ModelRunner:
                  ),
                  4096,
              )
+         if self.mambaish_config is not None:
+             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

-         if not self.spec_algorithm.is_none():
+         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
              if self.is_draft_worker:
                  self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                  max_num_reqs = self.server_args.max_num_reqs
@@ -1237,13 +1502,24 @@ class ModelRunner:
                  // self.server_args.page_size
                  * self.server_args.page_size
              )
+         # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+         if self.pp_size > 1:
+             tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+             torch.distributed.all_reduce(
+                 tensor,
+                 op=torch.distributed.ReduceOp.MIN,
+                 group=get_world_group().cpu_group,
+             )
+             self.max_total_num_tokens = tensor.item()
+
          # create token size for hybrid cache
          if self.is_hybrid:
              self.set_num_token_hybrid()

          if self.max_total_num_tokens <= 0:
              raise RuntimeError(
-                 "Not enough memory. Please try to increase --mem-fraction-static."
+                 f"Not enough memory. Please try to increase --mem-fraction-static. "
+                 f"Current value: {self.server_args.mem_fraction_static=}"
              )

          # Initialize req_to_token_pool
@@ -1267,6 +1543,16 @@ class ModelRunner:
                      enable_memory_saver=self.server_args.enable_memory_saver,
                      pre_alloc_size=pre_alloc_size,
                  )
+             elif config := self.mambaish_config:
+                 self.req_to_token_pool = HybridReqToTokenPool(
+                     size=max_num_reqs,
+                     max_context_len=self.model_config.context_len
+                     + extra_max_context_len,
+                     device=self.device,
+                     enable_memory_saver=self.server_args.enable_memory_saver,
+                     cache_params=config.mamba2_cache_params,
+                     speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+                 )
              else:
                  self.req_to_token_pool = ReqToTokenPool(
                      size=max_num_reqs,
@@ -1280,6 +1566,7 @@ class ModelRunner:
              assert self.is_draft_worker

          # Initialize token_to_kv_pool
+         is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
          if self.server_args.attention_backend == "ascend":
              if self.use_mla_backend:
                  self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1288,6 +1575,7 @@ class ModelRunner:
                      dtype=self.kv_cache_dtype,
                      kv_lora_rank=self.model_config.kv_lora_rank,
                      qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                     index_head_dim=self.model_config.index_head_dim,
                      layer_num=self.num_effective_layers,
                      device=self.device,
                      enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1307,7 +1595,22 @@ class ModelRunner:
                      device=self.device,
                      enable_memory_saver=self.server_args.enable_memory_saver,
                  )
+         elif self.use_mla_backend and is_nsa_model:
+             self.token_to_kv_pool = NSATokenToKVPool(
+                 self.max_total_num_tokens,
+                 page_size=self.page_size,
+                 dtype=self.kv_cache_dtype,
+                 kv_lora_rank=self.model_config.kv_lora_rank,
+                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                 layer_num=self.num_effective_layers,
+                 device=self.device,
+                 enable_memory_saver=self.server_args.enable_memory_saver,
+                 start_layer=self.start_layer,
+                 end_layer=self.end_layer,
+                 index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+             )
          elif self.use_mla_backend:
+             assert not is_nsa_model
              self.token_to_kv_pool = MLATokenToKVPool(
                  self.max_total_num_tokens,
                  page_size=self.page_size,
@@ -1349,6 +1652,22 @@ class ModelRunner:
                  enable_kvcache_transpose=False,
                  device=self.device,
              )
+         elif config := self.mambaish_config:
+             self.token_to_kv_pool = HybridLinearKVPool(
+                 page_size=self.page_size,
+                 size=self.max_total_num_tokens,
+                 dtype=self.kv_cache_dtype,
+                 head_num=self.model_config.get_num_kv_heads(
+                     get_attention_tp_size()
+                 ),
+                 head_dim=self.model_config.head_dim,
+                 # if draft worker, we only need 1 attention layer's kv pool
+                 full_attention_layer_ids=(
+                     [0] if self.is_draft_worker else config.full_attention_layer_ids
+                 ),
+                 enable_kvcache_transpose=False,
+                 device=self.device,
+             )
          else:
              self.token_to_kv_pool = MHATokenToKVPool(
                  self.max_total_num_tokens,
@@ -1363,12 +1682,18 @@ class ModelRunner:
                  enable_memory_saver=self.server_args.enable_memory_saver,
                  start_layer=self.start_layer,
                  end_layer=self.end_layer,
+                 enable_kv_cache_copy=(
+                     self.server_args.speculative_algorithm is not None
+                 ),
              )

          # Initialize token_to_kv_pool_allocator
          need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
          if self.token_to_kv_pool_allocator is None:
-             if self.server_args.attention_backend == "ascend":
+             if _is_npu and (
+                 self.server_args.attention_backend == "ascend"
+                 or self.hybrid_gdn_config is not None
+             ):
                  self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                      self.max_total_num_tokens,
                      page_size=self.page_size,
@@ -1432,16 +1757,10 @@ class ModelRunner:

      def _get_attention_backend(self):
          """Init attention kernel backend."""
-         self.decode_attention_backend_str = (
-             self.server_args.decode_attention_backend
-             if self.server_args.decode_attention_backend
-             else self.server_args.attention_backend
-         )
-         self.prefill_attention_backend_str = (
-             self.server_args.prefill_attention_backend
-             if self.server_args.prefill_attention_backend
-             else self.server_args.attention_backend
+         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+             self.server_args.get_attention_backends()
          )
+
          if self.decode_attention_backend_str != self.prefill_attention_backend_str:
              from sglang.srt.layers.attention.hybrid_attn_backend import (
                  HybridAttnBackend,
@@ -1462,8 +1781,8 @@ class ModelRunner:
                  f"prefill_backend={self.prefill_attention_backend_str}."
              )
              logger.warning(
-                 f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
-                 f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+                 "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                 "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
              )
          else:
              attn_backend = self._get_attention_backend_from_str(
@@ -1479,111 +1798,10 @@ class ModelRunner:
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str == "flashinfer":
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError(
-                    "trtllm_mha backend can only be used with non-MLA models."
-                )
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1603,38 +1821,46 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1651,7 +1877,18 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -1771,18 +2008,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
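In the `forward` hunk above, the replay gate is no longer CUDA-specific: the code first selects the right mode predicate for the device (`is_cpu_graph` on CPU, `is_cuda_graph` otherwise) and only then checks whether a graph runner exists and can handle the batch. A stripped-down sketch of that gate follows; the mode class and runner object are simplified stand-ins, not sglang's real types.

```python
# Hedged sketch of the device-aware replay gate; FakeForwardMode and `runner`
# are simplified stand-ins for sglang's ForwardMode and graph runner.
class FakeForwardMode:
    def __init__(self, decode: bool):
        self._decode = decode

    def is_cuda_graph(self) -> bool:
        return self._decode

    def is_cpu_graph(self) -> bool:
        return self._decode


def can_run_graph(device: str, mode: FakeForwardMode, runner, batch) -> bool:
    # Pick the predicate that matches the device, then require a capable runner.
    mode_check = mode.is_cpu_graph if device == "cpu" else mode.is_cuda_graph
    return bool(mode_check() and runner is not None and runner.can_run(batch))
```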
@@ -1811,23 +2054,22 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
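The `_preprocess_logits` change above drops the overlap-mode event wait and always refreshes the grammar vocab mask before applying logits bias. Conceptually, the two steps amount to masking grammar-disallowed tokens and adding any per-token bias in place; the sketch below shows that idea with plain tensors (the argument names and shapes are illustrative, not the `SamplingBatchInfo` API).

```python
# Conceptual sketch of "update vocab mask, then apply logits bias"; tensor names
# and shapes are assumptions for illustration only.
from typing import Optional

import torch


def apply_logits_bias(
    next_token_logits: torch.Tensor,            # [batch, vocab]
    vocab_mask: Optional[torch.Tensor] = None,  # bool [batch, vocab], True = disallowed
    logit_bias: Optional[torch.Tensor] = None,  # float [batch, vocab]
) -> None:
    if vocab_mask is not None:
        # Grammar-constrained decoding: disallowed tokens can never be sampled.
        next_token_logits.masked_fill_(vocab_mask, float("-inf"))
    if logit_bias is not None:
        # Request-level biases (e.g. OpenAI-style logit_bias) are added in place.
        next_token_logits.add_(logit_bias)
```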
@@ -1852,7 +2094,6 @@ class ModelRunner:
         )

         self._preprocess_logits(logits_output, forward_batch.sampling_info)
-
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -1860,9 +2101,47 @@ class ModelRunner:
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
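The new `compute_logprobs_only` path above targets prefill-only requests that ask for `token_ids_logprobs` but never need a next token: it preprocesses the logits and lets the sampler fill in the requested probabilities without sampling. A hedged usage sketch follows; the surrounding control flow, the `need_next_token` flag, and the helper name are assumptions for illustration only.

```python
# Hedged usage sketch: choose between sampling and the logprob-only path.
# `runner` is a ModelRunner and `forward_batch` a prefill ForwardBatch; the flag
# and helper name are illustrative, not part of sglang's API.
def finish_prefill(runner, logits_output, forward_batch, need_next_token: bool):
    if need_next_token:
        # Normal path: preprocess logits, then sample the next token ids.
        return runner.sample(logits_output, forward_batch)
    # Logprob-only path: populates logits_output in place; no token is generated.
    runner.compute_logprobs_only(logits_output, forward_batch)
    return None
```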
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.