sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
40
40
  get_moe_a2a_backend,
41
41
  should_use_flashinfer_cutlass_moe_fp4_allgather,
42
42
  )
43
- from sglang.srt.managers.schedule_batch import global_server_args_dict
44
43
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
44
+ from sglang.srt.server_args import get_global_server_args
45
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
45
46
  from sglang.srt.utils import (
46
47
  get_bool_env_var,
47
48
  is_cuda,
@@ -168,7 +169,7 @@ class LayerScatterModes:
168
169
 
169
170
 
170
171
  def enable_moe_dense_fully_dp():
171
- return global_server_args_dict["moe_dense_tp_size"] == 1
172
+ return get_global_server_args().moe_dense_tp_size == 1
172
173
 
173
174
 
174
175
  class LayerCommunicator:
@@ -211,6 +212,10 @@ class LayerCommunicator:
211
212
  )
212
213
  )
213
214
 
215
+ self._speculative_algo = SpeculativeAlgorithm.from_string(
216
+ get_global_server_args().speculative_algorithm
217
+ )
218
+
214
219
  def prepare_attn(
215
220
  self,
216
221
  hidden_states: torch.Tensor,
@@ -314,11 +319,10 @@ class LayerCommunicator:
314
319
  def should_fuse_mlp_allreduce_with_next_layer(
315
320
  self, forward_batch: ForwardBatch
316
321
  ) -> bool:
317
- speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
318
322
  if (
319
323
  is_dp_attention_enabled()
320
- and speculative_algo is not None
321
- and speculative_algo.is_eagle()
324
+ and self._speculative_algo is not None
325
+ and self._speculative_algo.is_eagle()
322
326
  ):
323
327
  return False
324
328
 
@@ -333,7 +337,7 @@ class LayerCommunicator:
333
337
  static_conditions_met = (
334
338
  (not self.is_last_layer)
335
339
  and (self._context.tp_size > 1)
336
- and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
340
+ and get_global_server_args().enable_flashinfer_allreduce_fusion
337
341
  and _is_flashinfer_available
338
342
  )
339
343
 
@@ -531,7 +535,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
531
535
  (_is_sm100_supported or _is_sm90_supported)
532
536
  and _is_flashinfer_available
533
537
  and hasattr(layernorm, "forward_with_allreduce_fusion")
534
- and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
538
+ and get_global_server_args().enable_flashinfer_allreduce_fusion
535
539
  and hidden_states.shape[0] <= 4096
536
540
  ):
537
541
  hidden_states, residual = layernorm.forward_with_allreduce_fusion(
@@ -7,11 +7,10 @@ from typing import Dict, List, Tuple
7
7
  import torch
8
8
  from tqdm import tqdm
9
9
 
10
- from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
11
- ENABLE_JIT_DEEPGEMM,
12
- )
10
+ from sglang.srt.environ import envs
11
+ from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
13
12
  from sglang.srt.server_args import ServerArgs
14
- from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
13
+ from sglang.srt.utils import ceil_div, get_bool_env_var
15
14
 
16
15
  logger = logging.getLogger(__name__)
17
16
 
@@ -20,12 +19,9 @@ if ENABLE_JIT_DEEPGEMM:
20
19
 
21
20
 
22
21
  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
23
- _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
24
- "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
25
- )
22
+ _ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
26
23
  _DO_COMPILE_ALL = True
27
24
  _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
28
- _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
29
25
  _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
30
26
 
31
27
  # Force redirect deep_gemm cache_dir
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
 
3
- from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
3
+ from sglang.srt.environ import envs
4
+ from sglang.srt.utils import get_device_sm, is_blackwell
4
5
 
5
6
  logger = logging.getLogger(__name__)
6
7
 
@@ -11,11 +12,11 @@ def _compute_enable_deep_gemm():
11
12
  return False
12
13
 
13
14
  try:
14
- import deep_gemm
15
+ import deep_gemm # noqa: F401
15
16
  except ImportError:
16
17
  return False
17
18
 
18
- return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
19
+ return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
19
20
 
20
21
 
21
22
  ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
@@ -4,8 +4,8 @@ from typing import Tuple
4
4
 
5
5
  import torch
6
6
 
7
- from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
8
- from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
7
+ from sglang.srt.layers.deep_gemm_wrapper import compile_utils
8
+ from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
9
9
  DEEPGEMM_BLACKWELL,
10
10
  DEEPGEMM_SCALE_UE8M0,
11
11
  ENABLE_JIT_DEEPGEMM,
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
17
17
 
18
18
  if ENABLE_JIT_DEEPGEMM:
19
19
  import deep_gemm
20
- from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
20
+ from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401
21
21
 
22
22
  _SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
23
23
 
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
87
87
  _global_dp_buffer_len: int
88
88
  _local_dp_buffer_len: int
89
89
  _global_num_tokens: Optional[List[int]]
90
+ _is_extend_in_batch: bool
90
91
 
91
92
  @classmethod
92
93
  def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
145
146
  def get_dp_device(cls) -> torch.device:
146
147
  return cls._device
147
148
 
149
+ @classmethod
150
+ def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
151
+ cls._is_extend_in_batch = is_extend_in_batch
152
+
153
+ @classmethod
154
+ def get_is_extend_in_batch(cls) -> bool:
155
+ return cls._is_extend_in_batch
156
+
148
157
 
149
158
  def set_dp_buffer_len(
150
159
  global_dp_buffer_len: int,
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
188
197
  return _DpGatheredBufferWrapper.get_dp_device()
189
198
 
190
199
 
200
+ def set_is_extend_in_batch(is_extend_in_batch: bool):
201
+ _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
202
+
203
+
204
+ def get_is_extend_in_batch() -> bool:
205
+ return _DpGatheredBufferWrapper.get_is_extend_in_batch()
206
+
207
+
191
208
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
192
209
  if not enable_dp_attention:
193
210
  return tp_rank, tp_size, 0
@@ -42,13 +42,16 @@ _is_cpu_amx_available = cpu_has_amx_support()
42
42
  _is_cpu = is_cpu()
43
43
  _is_xpu = is_xpu()
44
44
 
45
- if _is_cuda:
46
- if _is_flashinfer_available:
47
- from flashinfer.norm import fused_add_rmsnorm
48
- else:
49
- from sgl_kernel import fused_add_rmsnorm
50
- from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
51
-
45
+ if _is_cuda or _is_xpu:
46
+ # if _is_flashinfer_available:
47
+ # from flashinfer.norm import fused_add_rmsnorm
48
+ # else:
49
+ from sgl_kernel import (
50
+ fused_add_rmsnorm,
51
+ gemma_fused_add_rmsnorm,
52
+ gemma_rmsnorm,
53
+ rmsnorm,
54
+ )
52
55
  if _use_aiter:
53
56
  from aiter import rmsnorm2d_fwd as rms_norm
54
57
  from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -211,6 +214,19 @@ class RMSNorm(CustomOp):
211
214
  else:
212
215
  return self.forward_native(x, residual)
213
216
 
217
+ def forward_xpu(
218
+ self,
219
+ x: torch.Tensor,
220
+ residual: Optional[torch.Tensor] = None,
221
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
222
+ if self.variance_size_override is not None:
223
+ return self.forward_native(x, residual)
224
+ if residual is not None:
225
+ fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
226
+ return x, residual
227
+ out = rmsnorm(x, self.weight.data, self.variance_epsilon)
228
+ return out
229
+
214
230
  def forward_with_allreduce_fusion(
215
231
  self,
216
232
  x: torch.Tensor,
@@ -258,6 +274,19 @@ class GemmaRMSNorm(CustomOp):
258
274
  if _is_hip:
259
275
  self._forward_method = self.forward_native
260
276
 
277
+ def _forward_impl(
278
+ self,
279
+ x: torch.Tensor,
280
+ residual: Optional[torch.Tensor] = None,
281
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
282
+ if residual is not None:
283
+ gemma_fused_add_rmsnorm(
284
+ x, residual, self.weight.data, self.variance_epsilon
285
+ )
286
+ return x, residual
287
+ out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
288
+ return out
289
+
261
290
  def forward_native(
262
291
  self,
263
292
  x: torch.Tensor,
@@ -280,13 +309,7 @@ class GemmaRMSNorm(CustomOp):
280
309
  x: torch.Tensor,
281
310
  residual: Optional[torch.Tensor] = None,
282
311
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
283
- if residual is not None:
284
- gemma_fused_add_rmsnorm(
285
- x, residual, self.weight.data, self.variance_epsilon
286
- )
287
- return x, residual
288
- out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
289
- return out
312
+ return self._forward_impl(x, residual)
290
313
 
291
314
  def forward_npu(
292
315
  self,
@@ -300,6 +323,13 @@ class GemmaRMSNorm(CustomOp):
300
323
  x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
301
324
  return x if residual is None else (x, residual)
302
325
 
326
+ def forward_xpu(
327
+ self,
328
+ x: torch.Tensor,
329
+ residual: Optional[torch.Tensor] = None,
330
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
331
+ return self._forward_impl(x, residual)
332
+
303
333
 
304
334
  class Gemma3RMSNorm(CustomOp):
305
335
  def __init__(self, dim: int, eps: float = 1e-6):
@@ -335,4 +365,4 @@ if not (
335
365
  logger.info(
336
366
  "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
337
367
  )
338
- from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
368
+ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm # noqa: F401
@@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import (
32
32
  )
33
33
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
34
34
  from sglang.srt.layers.utils import pad_or_narrow_weight
35
- from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
35
+ from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs
36
36
 
37
37
  if TYPE_CHECKING:
38
38
  from sglang.srt.layers.quantization.base_config import (
@@ -40,12 +40,18 @@ if TYPE_CHECKING:
40
40
  QuantizeMethodBase,
41
41
  )
42
42
 
43
+ _is_hip = is_hip()
44
+ _disable_hip_linear_quant = _is_hip and get_bool_env_var(
45
+ "SGLANG_ROCM_DISABLE_LINEARQUANT"
46
+ )
47
+
43
48
  logger = logging.getLogger(__name__)
44
49
 
45
50
  WEIGHT_LOADER_V2_SUPPORTED = [
46
51
  "CompressedTensorsLinearMethod",
47
52
  "AWQMarlinLinearMethod",
48
53
  "AWQLinearMethod",
54
+ "AWQLinearAscendMethod",
49
55
  "GPTQMarlinLinearMethod",
50
56
  "Fp8LinearMethod",
51
57
  "BlockInt8LinearMethod",
@@ -824,6 +830,7 @@ class QKVParallelLinear(ColumnParallelLinear):
824
830
  self.num_kv_heads * self.head_size * tp_size, # v_proj
825
831
  ]
826
832
  self.use_presharded_weights = load_presharded_attn
833
+ quant_config = None if _disable_hip_linear_quant else quant_config
827
834
 
828
835
  super().__init__(
829
836
  input_size=input_size,
@@ -1225,6 +1232,7 @@ class RowParallelLinear(LinearBase):
1225
1232
  tp_size: Optional[int] = None,
1226
1233
  use_presharded_weights: bool = False,
1227
1234
  ):
1235
+ quant_config = None if _disable_hip_linear_quant else quant_config
1228
1236
  super().__init__(
1229
1237
  input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
1230
1238
  )
@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
38
38
  get_dp_device,
39
39
  get_dp_dtype,
40
40
  get_dp_hidden_size,
41
- get_global_dp_buffer,
42
41
  get_local_attention_dp_size,
43
- set_dp_buffer_len,
44
42
  )
45
43
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
46
- from sglang.srt.managers.schedule_batch import global_server_args_dict
47
44
  from sglang.srt.model_executor.forward_batch_info import (
48
45
  CaptureHiddenMode,
49
46
  ForwardBatch,
50
47
  ForwardMode,
51
48
  )
49
+ from sglang.srt.server_args import get_global_server_args
52
50
  from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
53
51
 
54
52
  logger = logging.getLogger(__name__)
@@ -60,13 +58,14 @@ _is_npu = is_npu()
60
58
  class LogitsProcessorOutput:
61
59
  ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
62
60
  # The logits of the next tokens. shape: [#seq, vocab_size]
63
- next_token_logits: torch.Tensor
61
+ # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
62
+ next_token_logits: Optional[torch.Tensor]
64
63
  # Used by speculative decoding (EAGLE)
65
64
  # The last hidden layers
66
65
  hidden_states: Optional[torch.Tensor] = None
67
66
 
68
67
  ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
69
- # he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
68
+ # he log probs of output tokens, if SGLANG_RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
70
69
  next_token_logprobs: Optional[torch.Tensor] = None
71
70
  # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
72
71
  next_token_top_logprobs_val: Optional[List] = None
@@ -85,7 +84,10 @@ class LogitsProcessorOutput:
85
84
  input_top_logprobs_val: List = None
86
85
  input_top_logprobs_idx: List = None
87
86
  # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
88
- input_token_ids_logprobs_val: Optional[List] = None
87
+ # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
88
+ input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
89
+ None
90
+ )
89
91
  input_token_ids_logprobs_idx: Optional[List] = None
90
92
 
91
93
 
@@ -127,10 +129,16 @@ class LogitsMetadata:
127
129
  # for padding
128
130
  padded_static_len: int = -1
129
131
 
132
+ # Whether this batch is prefill-only (no token generation needed)
133
+ is_prefill_only: bool = False
134
+
130
135
  @classmethod
131
136
  def from_forward_batch(cls, forward_batch: ForwardBatch):
132
137
  if (
133
- forward_batch.forward_mode.is_extend()
138
+ (
139
+ forward_batch.forward_mode.is_extend()
140
+ or forward_batch.forward_mode.is_split_prefill()
141
+ )
134
142
  and forward_batch.return_logprob
135
143
  and not forward_batch.forward_mode.is_target_verify()
136
144
  ):
@@ -169,6 +177,7 @@ class LogitsMetadata:
169
177
  token_ids_logprobs=forward_batch.token_ids_logprobs,
170
178
  extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
171
179
  padded_static_len=forward_batch.padded_static_len,
180
+ is_prefill_only=forward_batch.is_prefill_only,
172
181
  global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
173
182
  dp_local_start_pos=forward_batch.dp_local_start_pos,
174
183
  dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -219,8 +228,8 @@ class LogitsProcessor(nn.Module):
219
228
  super().__init__()
220
229
  self.config = config
221
230
  self.logit_scale = logit_scale
222
- self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
223
- self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
231
+ self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
232
+ self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
224
233
  if self.use_attn_tp_group:
225
234
  self.attn_tp_size = get_attention_tp_size()
226
235
  self.do_tensor_parallel_all_gather = (
@@ -243,8 +252,110 @@ class LogitsProcessor(nn.Module):
243
252
  ):
244
253
  self.final_logit_softcapping = None
245
254
 
246
- self.debug_tensor_dump_output_folder = global_server_args_dict.get(
247
- "debug_tensor_dump_output_folder", None
255
+ self.debug_tensor_dump_output_folder = (
256
+ get_global_server_args().debug_tensor_dump_output_folder
257
+ )
258
+
259
+ def compute_logprobs_for_multi_item_scoring(
260
+ self,
261
+ input_ids,
262
+ hidden_states,
263
+ lm_head: VocabParallelEmbedding,
264
+ logits_metadata: Union[LogitsMetadata, ForwardBatch],
265
+ delimiter_token: int,
266
+ ):
267
+ """
268
+ Compute logprobs for multi-item scoring using delimiter-based token extraction.
269
+
270
+ This method is designed for scenarios where you want to score multiple items/candidates
271
+ against a single query by combining them into one sequence separated by delimiters.
272
+
273
+ Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
274
+ Scoring positions: Extracts logprobs at positions before each <delimiter>
275
+
276
+ Args:
277
+ input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
278
+ Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
279
+ hidden_states (torch.Tensor): Hidden states from the model.
280
+ Shape: [sequence_length, hidden_dim].
281
+ lm_head (VocabParallelEmbedding): Language model head for computing logits.
282
+ logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
283
+ and token ID specifications for logprob extraction.
284
+ delimiter_token (int): Token ID used as delimiter between query and items.
285
+
286
+ Returns:
287
+ LogitsProcessorOutput: Contains:
288
+ - next_token_logits: None (not needed for scoring-only requests)
289
+ - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
290
+ - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
291
+ - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
292
+ - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
293
+ - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
294
+ """
295
+ multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
296
+ 0
297
+ ] - 1
298
+ # Extract hidden states at delimiter positions for multi-item scoring
299
+ sliced_hidden = hidden_states[multi_item_indices]
300
+
301
+ sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
302
+ sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
303
+
304
+ # Initialize return values
305
+ input_token_ids_logprobs_val = []
306
+ input_token_ids_logprobs_idx = []
307
+ input_top_logprobs_val = None
308
+ input_top_logprobs_idx = None
309
+
310
+ # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
311
+ # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
312
+ if (
313
+ logits_metadata.token_ids_logprobs
314
+ or logits_metadata.extend_return_top_logprob
315
+ ):
316
+ logits_metadata.extend_logprob_pruned_lens_cpu = []
317
+
318
+ if logits_metadata.extend_seq_lens_cpu is not None:
319
+ # Multi-request batch: count delimiters per request
320
+ input_pt = 0
321
+ for req_seq_len in logits_metadata.extend_seq_lens_cpu:
322
+ req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
323
+ delimiter_count = (req_input_ids == delimiter_token).sum().item()
324
+ logits_metadata.extend_logprob_pruned_lens_cpu.append(
325
+ delimiter_count
326
+ )
327
+ input_pt += req_seq_len
328
+ else:
329
+ # Single request case: one request gets all delimiters
330
+ total_delimiters = (input_ids == delimiter_token).sum().item()
331
+ logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
332
+
333
+ # Get the logprobs of specified token ids
334
+ if logits_metadata.extend_token_ids_logprob:
335
+ (
336
+ input_token_ids_logprobs_val,
337
+ input_token_ids_logprobs_idx,
338
+ ) = self.get_token_ids_logprobs(
339
+ sliced_logprobs, logits_metadata, delay_cpu_copy=True
340
+ )
341
+
342
+ # Get the logprob of top-k tokens
343
+ if logits_metadata.extend_return_top_logprob:
344
+ (
345
+ input_top_logprobs_val,
346
+ input_top_logprobs_idx,
347
+ ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
348
+
349
+ # For input_token_logprobs, use delimiter token logprobs
350
+ input_token_logprobs = sliced_logprobs[:, delimiter_token]
351
+
352
+ return LogitsProcessorOutput(
353
+ next_token_logits=None, # Multi-item scoring doesn't need next token logits
354
+ input_token_logprobs=input_token_logprobs,
355
+ input_top_logprobs_val=input_top_logprobs_val,
356
+ input_top_logprobs_idx=input_top_logprobs_idx,
357
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
358
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
248
359
  )
249
360
 
250
361
  def forward(
@@ -257,10 +368,19 @@ class LogitsProcessor(nn.Module):
257
368
  ) -> LogitsProcessorOutput:
258
369
  if isinstance(logits_metadata, ForwardBatch):
259
370
  logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
371
+
372
+ # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
373
+ multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
374
+ if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
375
+ return self.compute_logprobs_for_multi_item_scoring(
376
+ input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
377
+ )
378
+
260
379
  # Get the last hidden states and last logits for the next token prediction
261
380
  if (
262
381
  logits_metadata.forward_mode.is_decode_or_idle()
263
382
  or logits_metadata.forward_mode.is_target_verify()
383
+ or logits_metadata.forward_mode.is_draft_extend_v2()
264
384
  ):
265
385
  pruned_states = hidden_states
266
386
  if aux_hidden_states is not None:
@@ -269,8 +389,8 @@ class LogitsProcessor(nn.Module):
269
389
  input_logprob_indices = None
270
390
  elif (
271
391
  logits_metadata.forward_mode.is_extend()
272
- and not logits_metadata.extend_return_logprob
273
- ):
392
+ or logits_metadata.forward_mode.is_split_prefill()
393
+ ) and not logits_metadata.extend_return_logprob:
274
394
  # Prefill without input logprobs.
275
395
  if logits_metadata.padded_static_len < 0:
276
396
  last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
584
704
 
585
705
  @staticmethod
586
706
  def get_token_ids_logprobs(
587
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
707
+ all_logprobs: torch.Tensor,
708
+ logits_metadata: LogitsMetadata,
709
+ delay_cpu_copy: bool = False,
588
710
  ):
589
711
  input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
590
712
  pt = 0
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
597
719
  input_token_ids_logprobs_idx.append([])
598
720
  continue
599
721
 
600
- input_token_ids_logprobs_val.append(
601
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
602
- )
722
+ position_logprobs = all_logprobs[
723
+ pt : pt + pruned_len, token_ids
724
+ ] # Shape: [pruned_len, num_tokens]
725
+
726
+ if delay_cpu_copy:
727
+ # Keep as tensor to delay GPU-to-CPU transfer
728
+ input_token_ids_logprobs_val.append(position_logprobs)
729
+ else:
730
+ # Convert to list immediately (default behavior)
731
+ input_token_ids_logprobs_val.append(position_logprobs.tolist())
732
+
603
733
  input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
604
734
  pt += pruned_len
605
735
 
@@ -0,0 +1,11 @@
1
+ """
2
+ ModelOpt related constants
3
+ """
4
+
5
+ QUANT_CFG_CHOICES = {
6
+ "fp8": "FP8_DEFAULT_CFG",
7
+ "int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
8
+ "w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
9
+ "nvfp4": "NVFP4_DEFAULT_CFG",
10
+ "nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
11
+ }
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
116
116
 
117
117
  if is_cuda:
118
118
  from sglang.srt.layers.quantization.fp8_kernel import (
119
- per_group_transpose,
120
- per_token_group_quant_fp8_hopper_moe_mn_major,
121
119
  sglang_per_token_group_quant_fp8,
122
120
  )
123
121