sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
17
17
  get_tp_group,
18
18
  tensor_model_parallel_all_reduce,
19
19
  )
20
+ from sglang.srt.utils import get_bool_env_var, is_hip
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
36
37
  _LOCAL_ATTN_DP_RANK: Optional[int] = None
37
38
  _ENABLE_DP_ATTENTION_FLAG: bool = False
38
39
 
40
+ _is_hip = is_hip()
41
+ _USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
42
+
39
43
 
40
44
  class DpPaddingMode(IntEnum):
41
45
 
@@ -51,7 +55,12 @@ class DpPaddingMode(IntEnum):
51
55
  return self == DpPaddingMode.SUM_LEN
52
56
 
53
57
  @classmethod
54
- def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
58
+ def get_dp_padding_mode(
59
+ cls, is_extend_in_batch, global_num_tokens: List[int]
60
+ ) -> DpPaddingMode:
61
+ if is_extend_in_batch:
62
+ return DpPaddingMode.SUM_LEN
63
+
55
64
  # we choose the mode that minimizes the communication cost
56
65
  max_len = max(global_num_tokens)
57
66
  sum_len = sum(global_num_tokens)
@@ -62,7 +71,12 @@ class DpPaddingMode(IntEnum):
62
71
 
63
72
  @classmethod
64
73
  def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
65
- return cls.MAX_LEN
74
+ # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
75
+ # it can be safely removed later, once RCCL fixed
76
+ if _USE_ROCM700A_WA:
77
+ return cls.SUM_LEN
78
+ else:
79
+ return cls.MAX_LEN
66
80
 
67
81
 
68
82
  class _DpGatheredBufferWrapper:
@@ -119,6 +133,18 @@ class _DpGatheredBufferWrapper:
119
133
  def get_dp_global_num_tokens(cls) -> List[int]:
120
134
  return cls._global_num_tokens
121
135
 
136
+ @classmethod
137
+ def get_dp_hidden_size(cls) -> int:
138
+ return cls._hidden_size
139
+
140
+ @classmethod
141
+ def get_dp_dtype(cls) -> torch.dtype:
142
+ return cls._dtype
143
+
144
+ @classmethod
145
+ def get_dp_device(cls) -> torch.device:
146
+ return cls._device
147
+
122
148
 
123
149
  def set_dp_buffer_len(
124
150
  global_dp_buffer_len: int,
@@ -150,6 +176,18 @@ def get_dp_global_num_tokens() -> List[int]:
150
176
  return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
151
177
 
152
178
 
179
+ def get_dp_hidden_size() -> int:
180
+ return _DpGatheredBufferWrapper.get_dp_hidden_size()
181
+
182
+
183
+ def get_dp_dtype() -> torch.dtype:
184
+ return _DpGatheredBufferWrapper.get_dp_dtype()
185
+
186
+
187
+ def get_dp_device() -> torch.device:
188
+ return _DpGatheredBufferWrapper.get_dp_device()
189
+
190
+
153
191
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
154
192
  if not enable_dp_attention:
155
193
  return tp_rank, tp_size, 0
@@ -225,6 +263,7 @@ def initialize_dp_attention(
225
263
  use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
226
264
  use_pymscclpp=False,
227
265
  use_custom_allreduce=False,
266
+ use_torch_symm_mem=False,
228
267
  use_hpu_communicator=False,
229
268
  use_xpu_communicator=False,
230
269
  use_npu_communicator=False,
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
187
187
 
188
188
  def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
189
189
  assert len(x.shape) == 2
190
- assert x.shape == residual.shape and x.dtype == residual.dtype
190
+ assert (
191
+ x.shape == residual.shape and x.dtype == residual.dtype
192
+ ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
191
193
  output, mid = torch.empty_like(x), torch.empty_like(x)
192
194
  bs, hidden_dim = x.shape
193
195
  if autotune:
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
18
18
 
19
19
  import torch
20
20
  import torch.nn as nn
21
+ from packaging.version import Version
21
22
 
22
23
  from sglang.srt.custom_op import CustomOp
23
24
  from sglang.srt.utils import (
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
25
26
  get_bool_env_var,
26
27
  is_cpu,
27
28
  is_cuda,
29
+ is_flashinfer_available,
28
30
  is_hip,
29
31
  is_npu,
32
+ is_xpu,
30
33
  supports_custom_op,
31
34
  )
32
35
 
33
36
  _is_cuda = is_cuda()
37
+ _is_flashinfer_available = is_flashinfer_available()
34
38
  _is_hip = is_hip()
35
39
  _is_npu = is_npu()
36
40
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
37
41
  _is_cpu_amx_available = cpu_has_amx_support()
38
42
  _is_cpu = is_cpu()
43
+ _is_xpu = is_xpu()
39
44
 
40
45
  if _is_cuda:
41
- from sgl_kernel import (
42
- fused_add_rmsnorm,
43
- gemma_fused_add_rmsnorm,
44
- gemma_rmsnorm,
45
- rmsnorm,
46
- )
46
+ if _is_flashinfer_available:
47
+ from flashinfer.norm import fused_add_rmsnorm
48
+ else:
49
+ from sgl_kernel import fused_add_rmsnorm
50
+ from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
47
51
 
48
52
  if _use_aiter:
49
53
  from aiter import rmsnorm2d_fwd as rms_norm
50
54
  from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
51
55
  elif _is_hip:
56
+ import vllm
52
57
  from vllm._custom_ops import fused_add_rms_norm, rms_norm
53
58
 
59
+ _vllm_version = Version(vllm.__version__)
60
+
54
61
  logger = logging.getLogger(__name__)
55
62
 
56
63
  if _is_npu:
@@ -73,6 +80,8 @@ class RMSNorm(CustomOp):
73
80
  )
74
81
  if _use_aiter:
75
82
  self._forward_method = self.forward_aiter
83
+ if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
84
+ self._forward_method = self.forward_native
76
85
 
77
86
  def forward_cuda(
78
87
  self,
@@ -127,8 +136,21 @@ class RMSNorm(CustomOp):
127
136
  # NOTE: Remove this if aiter kernel supports discontinuous input
128
137
  x = x.contiguous()
129
138
  if residual is not None:
130
- fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
131
- return x, residual
139
+ if _vllm_version < Version("0.9"):
140
+ fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
141
+ return x, residual
142
+ else:
143
+ residual_out = torch.empty_like(x)
144
+ output = torch.empty_like(x)
145
+ fused_add_rms_norm(
146
+ output,
147
+ x,
148
+ residual_out,
149
+ residual,
150
+ self.weight.data,
151
+ self.variance_epsilon,
152
+ )
153
+ return output, residual_out
132
154
  out = torch.empty_like(x)
133
155
  rms_norm(out, x, self.weight.data, self.variance_epsilon)
134
156
  return out
@@ -271,16 +293,11 @@ class GemmaRMSNorm(CustomOp):
271
293
  x: torch.Tensor,
272
294
  residual: Optional[torch.Tensor] = None,
273
295
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
274
- orig_dtype = x.dtype
275
296
  if residual is not None:
276
297
  x = x + residual
277
298
  residual = x
278
299
 
279
- x = x.float()
280
- variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
281
- x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
282
- x = x * (1.0 + self.weight.float())
283
- x = x.to(orig_dtype)
300
+ x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
284
301
  return x if residual is None else (x, residual)
285
302
 
286
303
 
@@ -312,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
312
329
  return f"{tuple(self.weight.shape)}, eps={self.eps}"
313
330
 
314
331
 
315
- if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
332
+ if not (
333
+ _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
334
+ ):
316
335
  logger.info(
317
336
  "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
318
337
  )
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
31
31
  _ColumnvLLMParameter,
32
32
  )
33
33
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
34
+ from sglang.srt.layers.utils import pad_or_narrow_weight
34
35
  from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
35
36
 
36
37
  if TYPE_CHECKING:
@@ -235,9 +236,8 @@ class ReplicatedLinear(LinearBase):
235
236
  loaded_weight = loaded_weight[:1]
236
237
  else:
237
238
  raise ValueError(f"{loaded_weight} are not all equal")
238
- assert (
239
- param.size() == loaded_weight.size()
240
- ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
239
+
240
+ assert param.size() == loaded_weight.size()
241
241
  param.data.copy_(loaded_weight)
242
242
 
243
243
  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -626,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
626
626
  # bitsandbytes loads the weights of the specific portion
627
627
  # no need to narrow here
628
628
  if not use_bitsandbytes_4bit and not self.use_presharded_weights:
629
- loaded_weight = loaded_weight.narrow(
630
- output_dim, start_idx, shard_size
631
- )
629
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
630
+ end_idx = start_idx + shard_size
631
+ if end_idx > loaded_weight.shape[output_dim]:
632
+ loaded_weight = pad_or_narrow_weight(
633
+ loaded_weight, output_dim, start_idx, shard_size
634
+ )
635
+ else:
636
+ loaded_weight = loaded_weight.narrow(
637
+ output_dim, start_idx, shard_size
638
+ )
632
639
 
633
640
  # Special case for AQLM codebooks.
634
641
  elif is_metadata:
@@ -894,6 +901,35 @@ class QKVParallelLinear(ColumnParallelLinear):
894
901
  )
895
902
  self.weight_loader_v2(param, loaded_weight_shard, shard_id)
896
903
 
904
+ def _load_qkv_block_scale(
905
+ self, param: BasevLLMParameter, loaded_weight: torch.Tensor
906
+ ):
907
+ block_n, _ = self.quant_method.quant_config.weight_block_size
908
+ q_size = self.total_num_heads * self.head_size // block_n
909
+ k_size = self.total_num_kv_heads * self.head_size // block_n
910
+ v_size = self.total_num_kv_heads * self.head_size // block_n
911
+ shard_offsets = [
912
+ # (shard_id, shard_offset, shard_size)
913
+ ("q", 0, q_size),
914
+ ("k", q_size, k_size),
915
+ ("v", q_size + k_size, v_size),
916
+ ]
917
+ for shard_id, shard_offset, shard_size in shard_offsets:
918
+ loaded_weight_shard = loaded_weight.narrow(
919
+ param.output_dim, shard_offset, shard_size
920
+ )
921
+ rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
922
+ rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
923
+ param.load_qkv_weight(
924
+ loaded_weight=loaded_weight_shard,
925
+ num_heads=self.num_kv_head_replicas,
926
+ shard_id=shard_id,
927
+ shard_offset=rank_shard_offset,
928
+ shard_size=rank_shard_size,
929
+ tp_rank=self.tp_rank,
930
+ use_presharded_weights=self.use_presharded_weights,
931
+ )
932
+
897
933
  def weight_loader_v2(
898
934
  self,
899
935
  param: BasevLLMParameter,
@@ -907,6 +943,9 @@ class QKVParallelLinear(ColumnParallelLinear):
907
943
  elif type(param) in (RowvLLMParameter, BasevLLMParameter):
908
944
  param.load_qkv_weight(loaded_weight=loaded_weight)
909
945
  return
946
+ elif isinstance(param, BlockQuantScaleParameter):
947
+ self._load_qkv_block_scale(param, loaded_weight)
948
+ return
910
949
  # TODO: @dsikka - move to parameter.py
911
950
  self._load_fused_module_from_checkpoint(param, loaded_weight)
912
951
  return
@@ -1271,7 +1310,16 @@ class RowParallelLinear(LinearBase):
1271
1310
  shard_size,
1272
1311
  )
1273
1312
  else:
1274
- loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
1313
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
1314
+ end_idx = start_idx + shard_size
1315
+ if end_idx > loaded_weight.shape[input_dim]:
1316
+ loaded_weight = pad_or_narrow_weight(
1317
+ loaded_weight, input_dim, start_idx, shard_size
1318
+ )
1319
+ else:
1320
+ loaded_weight = loaded_weight.narrow(
1321
+ input_dim, start_idx, shard_size
1322
+ )
1275
1323
 
1276
1324
  # Special case for loading scales off disk, which often do not
1277
1325
  # have a shape (such as in the case of AutoFP8).
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
35
35
  get_attention_dp_rank,
36
36
  get_attention_dp_size,
37
37
  get_attention_tp_size,
38
+ get_dp_device,
39
+ get_dp_dtype,
40
+ get_dp_hidden_size,
38
41
  get_global_dp_buffer,
39
42
  get_local_attention_dp_size,
40
43
  set_dp_buffer_len,
@@ -46,16 +49,19 @@ from sglang.srt.model_executor.forward_batch_info import (
46
49
  ForwardBatch,
47
50
  ForwardMode,
48
51
  )
49
- from sglang.srt.utils import dump_to_file, use_intel_amx_backend
52
+ from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
50
53
 
51
54
  logger = logging.getLogger(__name__)
52
55
 
56
+ _is_npu = is_npu()
57
+
53
58
 
54
59
  @dataclasses.dataclass
55
60
  class LogitsProcessorOutput:
56
61
  ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
57
62
  # The logits of the next tokens. shape: [#seq, vocab_size]
58
- next_token_logits: torch.Tensor
63
+ # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
64
+ next_token_logits: Optional[torch.Tensor]
59
65
  # Used by speculative decoding (EAGLE)
60
66
  # The last hidden layers
61
67
  hidden_states: Optional[torch.Tensor] = None
@@ -67,7 +73,10 @@ class LogitsProcessorOutput:
67
73
  next_token_top_logprobs_val: Optional[List] = None
68
74
  next_token_top_logprobs_idx: Optional[List] = None
69
75
  # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
70
- next_token_token_ids_logprobs_val: Optional[List] = None
76
+ # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
77
+ next_token_token_ids_logprobs_val: Optional[
78
+ List[Union[List[float], torch.Tensor]]
79
+ ] = None
71
80
  next_token_token_ids_logprobs_idx: Optional[List] = None
72
81
 
73
82
  ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -77,7 +86,10 @@ class LogitsProcessorOutput:
77
86
  input_top_logprobs_val: List = None
78
87
  input_top_logprobs_idx: List = None
79
88
  # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
80
- input_token_ids_logprobs_val: Optional[List] = None
89
+ # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
90
+ input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
91
+ None
92
+ )
81
93
  input_token_ids_logprobs_idx: Optional[List] = None
82
94
 
83
95
 
@@ -119,6 +131,9 @@ class LogitsMetadata:
119
131
  # for padding
120
132
  padded_static_len: int = -1
121
133
 
134
+ # Whether this batch is prefill-only (no token generation needed)
135
+ is_prefill_only: bool = False
136
+
122
137
  @classmethod
123
138
  def from_forward_batch(cls, forward_batch: ForwardBatch):
124
139
  if (
@@ -161,6 +176,7 @@ class LogitsMetadata:
161
176
  token_ids_logprobs=forward_batch.token_ids_logprobs,
162
177
  extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
163
178
  padded_static_len=forward_batch.padded_static_len,
179
+ is_prefill_only=forward_batch.is_prefill_only,
164
180
  global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
165
181
  dp_local_start_pos=forward_batch.dp_local_start_pos,
166
182
  dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -180,10 +196,13 @@ class LogitsMetadata:
180
196
  )
181
197
  else:
182
198
  dp_local_start_pos = cumtokens[dp_rank - 1]
183
- dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
184
199
 
185
200
  self.dp_local_start_pos = dp_local_start_pos
186
- self.dp_local_num_tokens = dp_local_num_tokens
201
+ self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
202
+
203
+ hidden_size = get_dp_hidden_size()
204
+ dtype = get_dp_dtype()
205
+ device = get_dp_device()
187
206
 
188
207
  if self.global_num_tokens_for_logprob_cpu is not None:
189
208
  # create a smaller buffer to reduce peak memory usage
@@ -191,10 +210,13 @@ class LogitsMetadata:
191
210
  else:
192
211
  self.global_dp_buffer_len = self.global_dp_buffer_len
193
212
 
194
- set_dp_buffer_len(
195
- self.global_dp_buffer_len,
196
- self.dp_local_num_tokens,
197
- self.global_num_tokens_for_logprob_cpu,
213
+ self.gathered_buffer = torch.empty(
214
+ (
215
+ self.global_dp_buffer_len,
216
+ hidden_size,
217
+ ),
218
+ dtype=dtype,
219
+ device=device,
198
220
  )
199
221
 
200
222
 
@@ -206,6 +228,7 @@ class LogitsProcessor(nn.Module):
206
228
  self.config = config
207
229
  self.logit_scale = logit_scale
208
230
  self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
231
+ self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
209
232
  if self.use_attn_tp_group:
210
233
  self.attn_tp_size = get_attention_tp_size()
211
234
  self.do_tensor_parallel_all_gather = (
@@ -232,6 +255,108 @@ class LogitsProcessor(nn.Module):
232
255
  "debug_tensor_dump_output_folder", None
233
256
  )
234
257
 
258
    def compute_logprobs_for_multi_item_scoring(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        logits_metadata: Union[LogitsMetadata, ForwardBatch],
        delimiter_token: int,
    ):
        """
        Compute logprobs for multi-item scoring using delimiter-based token extraction.

        This method is designed for scenarios where you want to score multiple items/candidates
        against a single query by combining them into one sequence separated by delimiters.

        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
        Scoring positions: Extracts logprobs at positions before each <delimiter>

        Note: this method mutates ``logits_metadata`` in place — it overwrites
        ``extend_logprob_pruned_lens_cpu`` with per-request delimiter counts so the
        downstream logprob helpers index into the sliced logprobs correctly.

        Args:
            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
            hidden_states (torch.Tensor): Hidden states from the model.
                Shape: [sequence_length, hidden_dim].
            lm_head (VocabParallelEmbedding): Language model head for computing logits.
            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
                and token ID specifications for logprob extraction.
            delimiter_token (int): Token ID used as delimiter between query and items.

        Returns:
            LogitsProcessorOutput: Contains:
                - next_token_logits: None (not needed for scoring-only requests)
                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
        """
        # Score at the position *preceding* each delimiter: the hidden state there
        # is what predicts the delimiter token.
        # NOTE(review): a delimiter at position 0 would yield index -1 (wraps to the
        # last position) — presumably sequences always start with query tokens; confirm.
        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
            0
        ] - 1
        # Extract hidden states at delimiter positions for multi-item scoring
        sliced_hidden = hidden_states[multi_item_indices]

        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)

        # Initialize return values
        input_token_ids_logprobs_val = []
        input_token_ids_logprobs_idx = []
        input_top_logprobs_val = None
        input_top_logprobs_idx = None

        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
        if (
            logits_metadata.token_ids_logprobs
            or logits_metadata.extend_return_top_logprob
        ):
            logits_metadata.extend_logprob_pruned_lens_cpu = []

            if logits_metadata.extend_seq_lens_cpu is not None:
                # Multi-request batch: count delimiters per request
                input_pt = 0
                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
                        delimiter_count
                    )
                    input_pt += req_seq_len
            else:
                # Single request case: one request gets all delimiters
                total_delimiters = (input_ids == delimiter_token).sum().item()
                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]

        # Get the logprobs of specified token ids
        if logits_metadata.extend_token_ids_logprob:
            (
                input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx,
            ) = self.get_token_ids_logprobs(
                # delay_cpu_copy keeps results as GPU tensors; the GPU-to-CPU
                # transfer is deferred to the consumer.
                sliced_logprobs, logits_metadata, delay_cpu_copy=True
            )

        # Get the logprob of top-k tokens
        if logits_metadata.extend_return_top_logprob:
            (
                input_top_logprobs_val,
                input_top_logprobs_idx,
            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)

        # For input_token_logprobs, use delimiter token logprobs
        input_token_logprobs = sliced_logprobs[:, delimiter_token]

        return LogitsProcessorOutput(
            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
            input_token_logprobs=input_token_logprobs,
            input_top_logprobs_val=input_top_logprobs_val,
            input_top_logprobs_idx=input_top_logprobs_idx,
            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
        )
358
+ )
359
+
235
360
  def forward(
236
361
  self,
237
362
  input_ids,
@@ -242,6 +367,16 @@ class LogitsProcessor(nn.Module):
242
367
  ) -> LogitsProcessorOutput:
243
368
  if isinstance(logits_metadata, ForwardBatch):
244
369
  logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
370
+
371
+ # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
372
+ multi_item_delimiter = global_server_args_dict.get(
373
+ "multi_item_scoring_delimiter"
374
+ )
375
+ if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
376
+ return self.compute_logprobs_for_multi_item_scoring(
377
+ input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
378
+ )
379
+
245
380
  # Get the last hidden states and last logits for the next token prediction
246
381
  if (
247
382
  logits_metadata.forward_mode.is_decode_or_idle()
@@ -441,13 +576,17 @@ class LogitsProcessor(nn.Module):
441
576
  if self.do_tensor_parallel_all_gather_dp_attn:
442
577
  logits_metadata.compute_dp_attention_metadata()
443
578
  hidden_states, local_hidden_states = (
444
- get_global_dp_buffer(),
579
+ logits_metadata.gathered_buffer,
445
580
  hidden_states,
446
581
  )
447
582
  dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
448
583
 
449
584
  if hasattr(lm_head, "weight"):
450
- if use_intel_amx_backend(lm_head):
585
+ if self.use_fp32_lm_head:
586
+ logits = torch.matmul(
587
+ hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
588
+ )
589
+ elif use_intel_amx_backend(lm_head):
451
590
  logits = torch.ops.sgl_kernel.weight_packed_linear(
452
591
  hidden_states.to(lm_head.weight.dtype),
453
592
  lm_head.weight,
@@ -461,7 +600,15 @@ class LogitsProcessor(nn.Module):
461
600
  else:
462
601
  # GGUF models
463
602
  # TODO: use weight_packed_linear for GGUF models
464
- logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
603
+ if self.use_fp32_lm_head:
604
+ with torch.cuda.amp.autocast(enabled=False):
605
+ logits = lm_head.quant_method.apply(
606
+ lm_head, hidden_states.to(torch.float32), embedding_bias
607
+ )
608
+ else:
609
+ logits = lm_head.quant_method.apply(
610
+ lm_head, hidden_states, embedding_bias
611
+ )
465
612
 
466
613
  if self.logit_scale is not None:
467
614
  logits.mul_(self.logit_scale)
@@ -517,7 +664,12 @@ class LogitsProcessor(nn.Module):
517
664
  logits = logits[:, : self.config.vocab_size].float()
518
665
 
519
666
  if self.final_logit_softcapping:
520
- fused_softcap(logits, self.final_logit_softcapping)
667
+ if not _is_npu:
668
+ fused_softcap(logits, self.final_logit_softcapping)
669
+ else:
670
+ logits = self.final_logit_softcapping * torch.tanh(
671
+ logits / self.final_logit_softcapping
672
+ )
521
673
 
522
674
  return logits
523
675
 
@@ -552,7 +704,9 @@ class LogitsProcessor(nn.Module):
552
704
 
553
705
  @staticmethod
554
706
  def get_token_ids_logprobs(
555
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
707
+ all_logprobs: torch.Tensor,
708
+ logits_metadata: LogitsMetadata,
709
+ delay_cpu_copy: bool = False,
556
710
  ):
557
711
  input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
558
712
  pt = 0
@@ -565,9 +719,17 @@ class LogitsProcessor(nn.Module):
565
719
  input_token_ids_logprobs_idx.append([])
566
720
  continue
567
721
 
568
- input_token_ids_logprobs_val.append(
569
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
570
- )
722
+ position_logprobs = all_logprobs[
723
+ pt : pt + pruned_len, token_ids
724
+ ] # Shape: [pruned_len, num_tokens]
725
+
726
+ if delay_cpu_copy:
727
+ # Keep as tensor to delay GPU-to-CPU transfer
728
+ input_token_ids_logprobs_val.append(position_logprobs)
729
+ else:
730
+ # Convert to list immediately (default behavior)
731
+ input_token_ids_logprobs_val.append(position_logprobs.tolist())
732
+
571
733
  input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
572
734
  pt += pruned_len
573
735
 
@@ -0,0 +1,11 @@
1
+ """
2
+ ModelOpt related constants
3
+ """
4
+
5
+ QUANT_CFG_CHOICES = {
6
+ "fp8": "FP8_DEFAULT_CFG",
7
+ "int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
8
+ "w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
9
+ "nvfp4": "NVFP4_DEFAULT_CFG",
10
+ "nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
11
+ }
@@ -1,4 +1,4 @@
1
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
1
+ from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
2
2
  from sglang.srt.layers.moe.utils import (
3
3
  DeepEPMode,
4
4
  MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
17
17
  __all__ = [
18
18
  "DeepEPMode",
19
19
  "MoeA2ABackend",
20
+ "MoeRunner",
20
21
  "MoeRunnerConfig",
21
22
  "MoeRunnerBackend",
22
23
  "initialize_moe_config",