sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
40
40
  get_moe_a2a_backend,
41
41
  should_use_flashinfer_cutlass_moe_fp4_allgather,
42
42
  )
43
- from sglang.srt.managers.schedule_batch import global_server_args_dict
44
43
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
44
+ from sglang.srt.server_args import get_global_server_args
45
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
45
46
  from sglang.srt.utils import (
46
47
  get_bool_env_var,
47
48
  is_cuda,
@@ -168,7 +169,7 @@ class LayerScatterModes:
168
169
 
169
170
 
170
171
  def enable_moe_dense_fully_dp():
171
- return global_server_args_dict["moe_dense_tp_size"] == 1
172
+ return get_global_server_args().moe_dense_tp_size == 1
172
173
 
173
174
 
174
175
  class LayerCommunicator:
@@ -211,6 +212,10 @@ class LayerCommunicator:
211
212
  )
212
213
  )
213
214
 
215
+ self._speculative_algo = SpeculativeAlgorithm.from_string(
216
+ get_global_server_args().speculative_algorithm
217
+ )
218
+
214
219
  def prepare_attn(
215
220
  self,
216
221
  hidden_states: torch.Tensor,
@@ -314,11 +319,10 @@ class LayerCommunicator:
314
319
  def should_fuse_mlp_allreduce_with_next_layer(
315
320
  self, forward_batch: ForwardBatch
316
321
  ) -> bool:
317
- speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
318
322
  if (
319
323
  is_dp_attention_enabled()
320
- and speculative_algo is not None
321
- and speculative_algo.is_eagle()
324
+ and self._speculative_algo is not None
325
+ and self._speculative_algo.is_eagle()
322
326
  ):
323
327
  return False
324
328
 
@@ -333,7 +337,8 @@ class LayerCommunicator:
333
337
  static_conditions_met = (
334
338
  (not self.is_last_layer)
335
339
  and (self._context.tp_size > 1)
336
- and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
340
+ and not is_dp_attention_enabled()
341
+ and get_global_server_args().enable_flashinfer_allreduce_fusion
337
342
  and _is_flashinfer_available
338
343
  )
339
344
 
@@ -531,7 +536,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
531
536
  (_is_sm100_supported or _is_sm90_supported)
532
537
  and _is_flashinfer_available
533
538
  and hasattr(layernorm, "forward_with_allreduce_fusion")
534
- and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
539
+ and get_global_server_args().enable_flashinfer_allreduce_fusion
535
540
  and hidden_states.shape[0] <= 4096
536
541
  ):
537
542
  hidden_states, residual = layernorm.forward_with_allreduce_fusion(
@@ -7,11 +7,10 @@ from typing import Dict, List, Tuple
7
7
  import torch
8
8
  from tqdm import tqdm
9
9
 
10
- from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
11
- ENABLE_JIT_DEEPGEMM,
12
- )
10
+ from sglang.srt.environ import envs
11
+ from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
13
12
  from sglang.srt.server_args import ServerArgs
14
- from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
13
+ from sglang.srt.utils import ceil_div, get_bool_env_var
15
14
 
16
15
  logger = logging.getLogger(__name__)
17
16
 
@@ -20,17 +19,14 @@ if ENABLE_JIT_DEEPGEMM:
20
19
 
21
20
 
22
21
  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
23
- _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
24
- "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
25
- )
22
+ _ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
26
23
  _DO_COMPILE_ALL = True
27
24
  _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
28
- _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
29
25
  _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
30
26
 
31
27
  # Force redirect deep_gemm cache_dir
32
28
  os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
33
- "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
29
+ "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
34
30
  )
35
31
 
36
32
  # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
 
3
- from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
3
+ from sglang.srt.environ import envs
4
+ from sglang.srt.utils import get_device_sm, is_blackwell
4
5
 
5
6
  logger = logging.getLogger(__name__)
6
7
 
@@ -11,11 +12,11 @@ def _compute_enable_deep_gemm():
11
12
  return False
12
13
 
13
14
  try:
14
- import deep_gemm
15
+ import deep_gemm # noqa: F401
15
16
  except ImportError:
16
17
  return False
17
18
 
18
- return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
19
+ return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
19
20
 
20
21
 
21
22
  ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
@@ -4,8 +4,8 @@ from typing import Tuple
4
4
 
5
5
  import torch
6
6
 
7
- from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
8
- from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
7
+ from sglang.srt.layers.deep_gemm_wrapper import compile_utils
8
+ from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
9
9
  DEEPGEMM_BLACKWELL,
10
10
  DEEPGEMM_SCALE_UE8M0,
11
11
  ENABLE_JIT_DEEPGEMM,
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
17
17
 
18
18
  if ENABLE_JIT_DEEPGEMM:
19
19
  import deep_gemm
20
- from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
20
+ from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401
21
21
 
22
22
  _SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
23
23
 
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
87
87
  _global_dp_buffer_len: int
88
88
  _local_dp_buffer_len: int
89
89
  _global_num_tokens: Optional[List[int]]
90
+ _is_extend_in_batch: bool
90
91
 
91
92
  @classmethod
92
93
  def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
145
146
  def get_dp_device(cls) -> torch.device:
146
147
  return cls._device
147
148
 
149
+ @classmethod
150
+ def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
151
+ cls._is_extend_in_batch = is_extend_in_batch
152
+
153
+ @classmethod
154
+ def get_is_extend_in_batch(cls) -> bool:
155
+ return cls._is_extend_in_batch
156
+
148
157
 
149
158
  def set_dp_buffer_len(
150
159
  global_dp_buffer_len: int,
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
188
197
  return _DpGatheredBufferWrapper.get_dp_device()
189
198
 
190
199
 
200
+ def set_is_extend_in_batch(is_extend_in_batch: bool):
201
+ _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
202
+
203
+
204
+ def get_is_extend_in_batch() -> bool:
205
+ return _DpGatheredBufferWrapper.get_is_extend_in_batch()
206
+
207
+
191
208
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
192
209
  if not enable_dp_attention:
193
210
  return tp_rank, tp_size, 0
@@ -42,13 +42,16 @@ _is_cpu_amx_available = cpu_has_amx_support()
42
42
  _is_cpu = is_cpu()
43
43
  _is_xpu = is_xpu()
44
44
 
45
- if _is_cuda:
46
- if _is_flashinfer_available:
47
- from flashinfer.norm import fused_add_rmsnorm
48
- else:
49
- from sgl_kernel import fused_add_rmsnorm
50
- from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
51
-
45
+ if _is_cuda or _is_xpu:
46
+ # if _is_flashinfer_available:
47
+ # from flashinfer.norm import fused_add_rmsnorm
48
+ # else:
49
+ from sgl_kernel import (
50
+ fused_add_rmsnorm,
51
+ gemma_fused_add_rmsnorm,
52
+ gemma_rmsnorm,
53
+ rmsnorm,
54
+ )
52
55
  if _use_aiter:
53
56
  from aiter import rmsnorm2d_fwd as rms_norm
54
57
  from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -70,9 +73,16 @@ class RMSNorm(CustomOp):
70
73
  hidden_size: int,
71
74
  eps: float = 1e-6,
72
75
  var_hidden_size: Optional[int] = None,
76
+ cast_x_before_out_mul: bool = False,
77
+ fp32_residual: bool = False,
78
+ weight_dtype: Optional = None,
79
+ override_orig_dtype: Optional = None,
73
80
  ) -> None:
74
81
  super().__init__()
75
- self.weight = nn.Parameter(torch.ones(hidden_size))
82
+ self.cast_x_before_out_mul = cast_x_before_out_mul
83
+ self.fp32_residual = fp32_residual
84
+ self.override_orig_dtype = override_orig_dtype
85
+ self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
76
86
  self.variance_epsilon = eps
77
87
  self.hidden_size = hidden_size
78
88
  self.variance_size_override = (
@@ -162,11 +172,14 @@ class RMSNorm(CustomOp):
162
172
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
163
173
  if not x.is_contiguous():
164
174
  x = x.contiguous()
165
- orig_dtype = x.dtype
175
+ orig_dtype = self.override_orig_dtype or x.dtype
166
176
  x = x.to(torch.float32)
167
177
  if residual is not None:
168
178
  x = x + residual.to(torch.float32)
169
- residual = x.to(orig_dtype)
179
+ if self.fp32_residual:
180
+ residual = x.clone()
181
+ else:
182
+ residual = x.to(orig_dtype)
170
183
 
171
184
  hidden_size = x.shape[-1]
172
185
  if hidden_size != self.hidden_size:
@@ -188,7 +201,12 @@ class RMSNorm(CustomOp):
188
201
 
189
202
  variance = x_var.pow(2).mean(dim=-1, keepdim=True)
190
203
  x = x * torch.rsqrt(variance + self.variance_epsilon)
191
- x = (x * self.weight).to(orig_dtype)
204
+
205
+ if self.cast_x_before_out_mul:
206
+ x = self.weight * x.to(orig_dtype)
207
+ else:
208
+ x = (x * self.weight).to(orig_dtype)
209
+
192
210
  if residual is None:
193
211
  return x
194
212
  else:
@@ -211,6 +229,19 @@ class RMSNorm(CustomOp):
211
229
  else:
212
230
  return self.forward_native(x, residual)
213
231
 
232
+ def forward_xpu(
233
+ self,
234
+ x: torch.Tensor,
235
+ residual: Optional[torch.Tensor] = None,
236
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
237
+ if self.variance_size_override is not None:
238
+ return self.forward_native(x, residual)
239
+ if residual is not None:
240
+ fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
241
+ return x, residual
242
+ out = rmsnorm(x, self.weight.data, self.variance_epsilon)
243
+ return out
244
+
214
245
  def forward_with_allreduce_fusion(
215
246
  self,
216
247
  x: torch.Tensor,
@@ -258,6 +289,19 @@ class GemmaRMSNorm(CustomOp):
258
289
  if _is_hip:
259
290
  self._forward_method = self.forward_native
260
291
 
292
+ def _forward_impl(
293
+ self,
294
+ x: torch.Tensor,
295
+ residual: Optional[torch.Tensor] = None,
296
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
297
+ if residual is not None:
298
+ gemma_fused_add_rmsnorm(
299
+ x, residual, self.weight.data, self.variance_epsilon
300
+ )
301
+ return x, residual
302
+ out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
303
+ return out
304
+
261
305
  def forward_native(
262
306
  self,
263
307
  x: torch.Tensor,
@@ -280,13 +324,7 @@ class GemmaRMSNorm(CustomOp):
280
324
  x: torch.Tensor,
281
325
  residual: Optional[torch.Tensor] = None,
282
326
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
283
- if residual is not None:
284
- gemma_fused_add_rmsnorm(
285
- x, residual, self.weight.data, self.variance_epsilon
286
- )
287
- return x, residual
288
- out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
289
- return out
327
+ return self._forward_impl(x, residual)
290
328
 
291
329
  def forward_npu(
292
330
  self,
@@ -300,6 +338,13 @@ class GemmaRMSNorm(CustomOp):
300
338
  x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
301
339
  return x if residual is None else (x, residual)
302
340
 
341
+ def forward_xpu(
342
+ self,
343
+ x: torch.Tensor,
344
+ residual: Optional[torch.Tensor] = None,
345
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
346
+ return self._forward_impl(x, residual)
347
+
303
348
 
304
349
  class Gemma3RMSNorm(CustomOp):
305
350
  def __init__(self, dim: int, eps: float = 1e-6):
@@ -335,4 +380,4 @@ if not (
335
380
  logger.info(
336
381
  "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
337
382
  )
338
- from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
383
+ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm # noqa: F401
@@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import (
32
32
  )
33
33
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
34
34
  from sglang.srt.layers.utils import pad_or_narrow_weight
35
- from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
35
+ from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs
36
36
 
37
37
  if TYPE_CHECKING:
38
38
  from sglang.srt.layers.quantization.base_config import (
@@ -40,12 +40,18 @@ if TYPE_CHECKING:
40
40
  QuantizeMethodBase,
41
41
  )
42
42
 
43
+ _is_hip = is_hip()
44
+ _disable_hip_linear_quant = _is_hip and get_bool_env_var(
45
+ "SGLANG_ROCM_DISABLE_LINEARQUANT"
46
+ )
47
+
43
48
  logger = logging.getLogger(__name__)
44
49
 
45
50
  WEIGHT_LOADER_V2_SUPPORTED = [
46
51
  "CompressedTensorsLinearMethod",
47
52
  "AWQMarlinLinearMethod",
48
53
  "AWQLinearMethod",
54
+ "AWQLinearAscendMethod",
49
55
  "GPTQMarlinLinearMethod",
50
56
  "Fp8LinearMethod",
51
57
  "BlockInt8LinearMethod",
@@ -824,6 +830,7 @@ class QKVParallelLinear(ColumnParallelLinear):
824
830
  self.num_kv_heads * self.head_size * tp_size, # v_proj
825
831
  ]
826
832
  self.use_presharded_weights = load_presharded_attn
833
+ quant_config = None if _disable_hip_linear_quant else quant_config
827
834
 
828
835
  super().__init__(
829
836
  input_size=input_size,
@@ -1225,6 +1232,7 @@ class RowParallelLinear(LinearBase):
1225
1232
  tp_size: Optional[int] = None,
1226
1233
  use_presharded_weights: bool = False,
1227
1234
  ):
1235
+ quant_config = None if _disable_hip_linear_quant else quant_config
1228
1236
  super().__init__(
1229
1237
  input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
1230
1238
  )
@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
38
38
  get_dp_device,
39
39
  get_dp_dtype,
40
40
  get_dp_hidden_size,
41
- get_global_dp_buffer,
42
41
  get_local_attention_dp_size,
43
- set_dp_buffer_len,
44
42
  )
45
43
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
46
- from sglang.srt.managers.schedule_batch import global_server_args_dict
47
44
  from sglang.srt.model_executor.forward_batch_info import (
48
45
  CaptureHiddenMode,
49
46
  ForwardBatch,
50
47
  ForwardMode,
51
48
  )
49
+ from sglang.srt.server_args import get_global_server_args
52
50
  from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
53
51
 
54
52
  logger = logging.getLogger(__name__)
@@ -60,13 +58,14 @@ _is_npu = is_npu()
60
58
  class LogitsProcessorOutput:
61
59
  ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
62
60
  # The logits of the next tokens. shape: [#seq, vocab_size]
63
- next_token_logits: torch.Tensor
61
+ # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
62
+ next_token_logits: Optional[torch.Tensor]
64
63
  # Used by speculative decoding (EAGLE)
65
64
  # The last hidden layers
66
65
  hidden_states: Optional[torch.Tensor] = None
67
66
 
68
67
  ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
69
- # he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
68
+ # he log probs of output tokens, if SGLANG_RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
70
69
  next_token_logprobs: Optional[torch.Tensor] = None
71
70
  # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
72
71
  next_token_top_logprobs_val: Optional[List] = None
@@ -85,7 +84,10 @@ class LogitsProcessorOutput:
85
84
  input_top_logprobs_val: List = None
86
85
  input_top_logprobs_idx: List = None
87
86
  # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
88
- input_token_ids_logprobs_val: Optional[List] = None
87
+ # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
88
+ input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
89
+ None
90
+ )
89
91
  input_token_ids_logprobs_idx: Optional[List] = None
90
92
 
91
93
 
@@ -127,10 +129,16 @@ class LogitsMetadata:
127
129
  # for padding
128
130
  padded_static_len: int = -1
129
131
 
132
+ # Whether this batch is prefill-only (no token generation needed)
133
+ is_prefill_only: bool = False
134
+
130
135
  @classmethod
131
136
  def from_forward_batch(cls, forward_batch: ForwardBatch):
132
137
  if (
133
- forward_batch.forward_mode.is_extend()
138
+ (
139
+ forward_batch.forward_mode.is_extend()
140
+ or forward_batch.forward_mode.is_split_prefill()
141
+ )
134
142
  and forward_batch.return_logprob
135
143
  and not forward_batch.forward_mode.is_target_verify()
136
144
  ):
@@ -169,6 +177,7 @@ class LogitsMetadata:
169
177
  token_ids_logprobs=forward_batch.token_ids_logprobs,
170
178
  extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
171
179
  padded_static_len=forward_batch.padded_static_len,
180
+ is_prefill_only=forward_batch.is_prefill_only,
172
181
  global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
173
182
  dp_local_start_pos=forward_batch.dp_local_start_pos,
174
183
  dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -219,8 +228,8 @@ class LogitsProcessor(nn.Module):
219
228
  super().__init__()
220
229
  self.config = config
221
230
  self.logit_scale = logit_scale
222
- self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
223
- self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
231
+ self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
232
+ self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
224
233
  if self.use_attn_tp_group:
225
234
  self.attn_tp_size = get_attention_tp_size()
226
235
  self.do_tensor_parallel_all_gather = (
@@ -243,8 +252,110 @@ class LogitsProcessor(nn.Module):
243
252
  ):
244
253
  self.final_logit_softcapping = None
245
254
 
246
- self.debug_tensor_dump_output_folder = global_server_args_dict.get(
247
- "debug_tensor_dump_output_folder", None
255
+ self.debug_tensor_dump_output_folder = (
256
+ get_global_server_args().debug_tensor_dump_output_folder
257
+ )
258
+
259
+ def compute_logprobs_for_multi_item_scoring(
260
+ self,
261
+ input_ids,
262
+ hidden_states,
263
+ lm_head: VocabParallelEmbedding,
264
+ logits_metadata: Union[LogitsMetadata, ForwardBatch],
265
+ delimiter_token: int,
266
+ ):
267
+ """
268
+ Compute logprobs for multi-item scoring using delimiter-based token extraction.
269
+
270
+ This method is designed for scenarios where you want to score multiple items/candidates
271
+ against a single query by combining them into one sequence separated by delimiters.
272
+
273
+ Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
274
+ Scoring positions: Extracts logprobs at positions before each <delimiter>
275
+
276
+ Args:
277
+ input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
278
+ Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
279
+ hidden_states (torch.Tensor): Hidden states from the model.
280
+ Shape: [sequence_length, hidden_dim].
281
+ lm_head (VocabParallelEmbedding): Language model head for computing logits.
282
+ logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
283
+ and token ID specifications for logprob extraction.
284
+ delimiter_token (int): Token ID used as delimiter between query and items.
285
+
286
+ Returns:
287
+ LogitsProcessorOutput: Contains:
288
+ - next_token_logits: None (not needed for scoring-only requests)
289
+ - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
290
+ - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
291
+ - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
292
+ - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
293
+ - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
294
+ """
295
+ multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
296
+ 0
297
+ ] - 1
298
+ # Extract hidden states at delimiter positions for multi-item scoring
299
+ sliced_hidden = hidden_states[multi_item_indices]
300
+
301
+ sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
302
+ sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
303
+
304
+ # Initialize return values
305
+ input_token_ids_logprobs_val = []
306
+ input_token_ids_logprobs_idx = []
307
+ input_top_logprobs_val = None
308
+ input_top_logprobs_idx = None
309
+
310
+ # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
311
+ # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
312
+ if (
313
+ logits_metadata.token_ids_logprobs
314
+ or logits_metadata.extend_return_top_logprob
315
+ ):
316
+ logits_metadata.extend_logprob_pruned_lens_cpu = []
317
+
318
+ if logits_metadata.extend_seq_lens_cpu is not None:
319
+ # Multi-request batch: count delimiters per request
320
+ input_pt = 0
321
+ for req_seq_len in logits_metadata.extend_seq_lens_cpu:
322
+ req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
323
+ delimiter_count = (req_input_ids == delimiter_token).sum().item()
324
+ logits_metadata.extend_logprob_pruned_lens_cpu.append(
325
+ delimiter_count
326
+ )
327
+ input_pt += req_seq_len
328
+ else:
329
+ # Single request case: one request gets all delimiters
330
+ total_delimiters = (input_ids == delimiter_token).sum().item()
331
+ logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
332
+
333
+ # Get the logprobs of specified token ids
334
+ if logits_metadata.extend_token_ids_logprob:
335
+ (
336
+ input_token_ids_logprobs_val,
337
+ input_token_ids_logprobs_idx,
338
+ ) = self.get_token_ids_logprobs(
339
+ sliced_logprobs, logits_metadata, delay_cpu_copy=True
340
+ )
341
+
342
+ # Get the logprob of top-k tokens
343
+ if logits_metadata.extend_return_top_logprob:
344
+ (
345
+ input_top_logprobs_val,
346
+ input_top_logprobs_idx,
347
+ ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
348
+
349
+ # For input_token_logprobs, use delimiter token logprobs
350
+ input_token_logprobs = sliced_logprobs[:, delimiter_token]
351
+
352
+ return LogitsProcessorOutput(
353
+ next_token_logits=None, # Multi-item scoring doesn't need next token logits
354
+ input_token_logprobs=input_token_logprobs,
355
+ input_top_logprobs_val=input_top_logprobs_val,
356
+ input_top_logprobs_idx=input_top_logprobs_idx,
357
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
358
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
248
359
  )
249
360
 
250
361
  def forward(
@@ -257,10 +368,19 @@ class LogitsProcessor(nn.Module):
257
368
  ) -> LogitsProcessorOutput:
258
369
  if isinstance(logits_metadata, ForwardBatch):
259
370
  logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
371
+
372
+ # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
373
+ multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
374
+ if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
375
+ return self.compute_logprobs_for_multi_item_scoring(
376
+ input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
377
+ )
378
+
260
379
  # Get the last hidden states and last logits for the next token prediction
261
380
  if (
262
381
  logits_metadata.forward_mode.is_decode_or_idle()
263
382
  or logits_metadata.forward_mode.is_target_verify()
383
+ or logits_metadata.forward_mode.is_draft_extend_v2()
264
384
  ):
265
385
  pruned_states = hidden_states
266
386
  if aux_hidden_states is not None:
@@ -269,8 +389,8 @@ class LogitsProcessor(nn.Module):
269
389
  input_logprob_indices = None
270
390
  elif (
271
391
  logits_metadata.forward_mode.is_extend()
272
- and not logits_metadata.extend_return_logprob
273
- ):
392
+ or logits_metadata.forward_mode.is_split_prefill()
393
+ ) and not logits_metadata.extend_return_logprob:
274
394
  # Prefill without input logprobs.
275
395
  if logits_metadata.padded_static_len < 0:
276
396
  last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -473,6 +593,11 @@ class LogitsProcessor(nn.Module):
473
593
  None, # bias
474
594
  True, # is_vnni
475
595
  )
596
+ elif get_global_server_args().rl_on_policy_target == "fsdp":
597
+ # Due to tie-weight, we may not be able to change lm_head's weight dtype
598
+ logits = torch.matmul(
599
+ hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
600
+ )
476
601
  else:
477
602
  logits = torch.matmul(
478
603
  hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
@@ -584,7 +709,9 @@ class LogitsProcessor(nn.Module):
584
709
 
585
710
  @staticmethod
586
711
  def get_token_ids_logprobs(
587
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
712
+ all_logprobs: torch.Tensor,
713
+ logits_metadata: LogitsMetadata,
714
+ delay_cpu_copy: bool = False,
588
715
  ):
589
716
  input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
590
717
  pt = 0
@@ -597,9 +724,17 @@ class LogitsProcessor(nn.Module):
597
724
  input_token_ids_logprobs_idx.append([])
598
725
  continue
599
726
 
600
- input_token_ids_logprobs_val.append(
601
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
602
- )
727
+ position_logprobs = all_logprobs[
728
+ pt : pt + pruned_len, token_ids
729
+ ] # Shape: [pruned_len, num_tokens]
730
+
731
+ if delay_cpu_copy:
732
+ # Keep as tensor to delay GPU-to-CPU transfer
733
+ input_token_ids_logprobs_val.append(position_logprobs)
734
+ else:
735
+ # Convert to list immediately (default behavior)
736
+ input_token_ids_logprobs_val.append(position_logprobs.tolist())
737
+
603
738
  input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
604
739
  pt += pruned_len
605
740
 
@@ -0,0 +1,11 @@
1
+ """
2
+ ModelOpt related constants
3
+ """
4
+
5
+ QUANT_CFG_CHOICES = {
6
+ "fp8": "FP8_DEFAULT_CFG",
7
+ "int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
8
+ "w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
9
+ "nvfp4": "NVFP4_DEFAULT_CFG",
10
+ "nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
11
+ }
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
116
116
 
117
117
  if is_cuda:
118
118
  from sglang.srt.layers.quantization.fp8_kernel import (
119
- per_group_transpose,
120
- per_token_group_quant_fp8_hopper_moe_mn_major,
121
119
  sglang_per_token_group_quant_fp8,
122
120
  )
123
121