sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/logits_processor.py
+++ b/sglang/srt/layers/logits_processor.py
@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_global_dp_buffer,
     get_local_attention_dp_size,
-    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
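This import change tracks a refactor visible throughout the 0.5.4 diff: dictionary reads from global_server_args_dict (formerly exported by sglang.srt.managers.schedule_batch) become attribute reads on get_global_server_args() from sglang.srt.server_args. A minimal sketch of the accessor pattern, using a stand-in dataclass rather than the real ServerArgs:

from dataclasses import dataclass

# Stand-in for sglang.srt.server_args.ServerArgs; only these two fields
# (which do appear elsewhere in this diff) are modeled here.
@dataclass
class _ServerArgs:
    enable_dp_lm_head: bool = False
    enable_fp32_lm_head: bool = False

_global_server_args = _ServerArgs()

def get_global_server_args() -> _ServerArgs:
    # Attribute access replaces global_server_args_dict["enable_dp_lm_head"],
    # so a mistyped name fails loudly as an AttributeError instead of a
    # KeyError buried in an untyped dict.
    return _global_server_args

print(get_global_server_args().enable_dp_lm_head)  # False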
@@ -60,13 +58,14 @@ _is_npu = is_npu()
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
 
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
+    # he log probs of output tokens, if SGLANG_RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
@@ -85,7 +84,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
 
@@ -127,10 +129,16 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            forward_batch.forward_mode.is_extend()
+            (
+                forward_batch.forward_mode.is_extend()
+                or forward_batch.forward_mode.is_split_prefill()
+            )
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -169,6 +177,7 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -219,7 +228,8 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
+        self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -242,8 +252,110 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
-        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
-            "debug_tensor_dump_output_folder", None
+        self.debug_tensor_dump_output_folder = (
+            get_global_server_args().debug_tensor_dump_output_folder
+        )
+
+    def compute_logprobs_for_multi_item_scoring(
+        self,
+        input_ids,
+        hidden_states,
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        delimiter_token: int,
+    ):
+        """
+        Compute logprobs for multi-item scoring using delimiter-based token extraction.
+
+        This method is designed for scenarios where you want to score multiple items/candidates
+        against a single query by combining them into one sequence separated by delimiters.
+
+        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+        Scoring positions: Extracts logprobs at positions before each <delimiter>
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
+                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
+            hidden_states (torch.Tensor): Hidden states from the model.
+                Shape: [sequence_length, hidden_dim].
+            lm_head (VocabParallelEmbedding): Language model head for computing logits.
+            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
+                and token ID specifications for logprob extraction.
+            delimiter_token (int): Token ID used as delimiter between query and items.
+
+        Returns:
+            LogitsProcessorOutput: Contains:
+                - next_token_logits: None (not needed for scoring-only requests)
+                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
+                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
+                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
+                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
+                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
+        """
+        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
+            0
+        ] - 1
+        # Extract hidden states at delimiter positions for multi-item scoring
+        sliced_hidden = hidden_states[multi_item_indices]
+
+        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
+        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
+
+        # Initialize return values
+        input_token_ids_logprobs_val = []
+        input_token_ids_logprobs_idx = []
+        input_top_logprobs_val = None
+        input_top_logprobs_idx = None
+
+        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
+        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
+        if (
+            logits_metadata.token_ids_logprobs
+            or logits_metadata.extend_return_top_logprob
+        ):
+            logits_metadata.extend_logprob_pruned_lens_cpu = []
+
+            if logits_metadata.extend_seq_lens_cpu is not None:
+                # Multi-request batch: count delimiters per request
+                input_pt = 0
+                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
+                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
+                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
+                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
+                        delimiter_count
+                    )
+                    input_pt += req_seq_len
+            else:
+                # Single request case: one request gets all delimiters
+                total_delimiters = (input_ids == delimiter_token).sum().item()
+                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
+
+        # Get the logprobs of specified token ids
+        if logits_metadata.extend_token_ids_logprob:
+            (
+                input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx,
+            ) = self.get_token_ids_logprobs(
+                sliced_logprobs, logits_metadata, delay_cpu_copy=True
+            )
+
+        # Get the logprob of top-k tokens
+        if logits_metadata.extend_return_top_logprob:
+            (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
+
+        # For input_token_logprobs, use delimiter token logprobs
+        input_token_logprobs = sliced_logprobs[:, delimiter_token]
+
+        return LogitsProcessorOutput(
+            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
         )
 
     def forward(
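The core of the new multi-item scoring path is the index arithmetic at the top of the method: the sequence is scored at the position immediately before each delimiter, i.e. at the hidden state whose output distribution assigns a probability to the delimiter token itself. A standalone illustration (token ids invented for the example):

import torch

DELIM = 9  # plays the role of multi_item_scoring_delimiter
# Query = [5, 6], Item1 = [7], Item2 = [8]:
# Query<delimiter>Item1<delimiter>Item2<delimiter>
input_ids = torch.tensor([5, 6, DELIM, 7, DELIM, 8, DELIM])

# Same expression as in compute_logprobs_for_multi_item_scoring: the rows of
# hidden_states selected here are the ones that predict each delimiter.
multi_item_indices = (input_ids == DELIM).nonzero(as_tuple=True)[0] - 1
print(multi_item_indices.tolist())  # [1, 3, 5]

Only those rows are pushed through the lm_head, so a request with three items costs three vocabulary projections instead of one per input token.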
@@ -256,10 +368,19 @@ class LogitsProcessor(nn.Module):
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
+            or logits_metadata.forward_mode.is_draft_extend_v2()
         ):
             pruned_states = hidden_states
             if aux_hidden_states is not None:
@@ -268,8 +389,8 @@
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            and not logits_metadata.extend_return_logprob
-        ):
+            or logits_metadata.forward_mode.is_split_prefill()
+        ) and not logits_metadata.extend_return_logprob:
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -461,7 +582,11 @@
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if use_intel_amx_backend(lm_head):
+            if self.use_fp32_lm_head:
+                logits = torch.matmul(
+                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
+                )
+            elif use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -475,7 +600,15 @@
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
+            if self.use_fp32_lm_head:
+                with torch.cuda.amp.autocast(enabled=False):
+                    logits = lm_head.quant_method.apply(
+                        lm_head, hidden_states.to(torch.float32), embedding_bias
+                    )
+            else:
+                logits = lm_head.quant_method.apply(
+                    lm_head, hidden_states, embedding_bias
+                )
 
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
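The new use_fp32_lm_head branch (gated by the enable_fp32_lm_head server argument set earlier in this file) promotes only the final vocabulary projection to fp32. A small sketch of the numerical motivation; the shapes and seed are arbitrary:

import torch

torch.manual_seed(0)
hidden = torch.randn(1, 4096, dtype=torch.bfloat16)
w = torch.randn(8, 4096, dtype=torch.bfloat16)  # toy lm_head with a vocab of 8

logits_bf16 = hidden @ w.T
logits_fp32 = hidden.to(torch.float32) @ w.to(torch.float32).T

# bf16 keeps only ~8 bits of mantissa, so logits that are close together can
# swap order after rounding; the fp32 path keeps logprobs and argmax stable.
print((logits_bf16.float() - logits_fp32).abs().max())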
@@ -571,7 +704,9 @@
 
     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -584,9 +719,17 @@
                 input_token_ids_logprobs_idx.append([])
                 continue
 
-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as tensor to delay GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
 
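delay_cpu_copy exists because .tolist() on a CUDA tensor forces a blocking device-to-host synchronization. The multi-item scoring path passes delay_cpu_copy=True so each request's slice stays on the GPU and the transfer happens once, later. A sketch of the cost model (CPU tensors stand in for the GPU slices):

import torch

def gather_logprobs(chunks, delay_cpu_copy: bool = False):
    # With delay_cpu_copy=False, every iteration would trigger a separate
    # GPU-to-CPU sync on real CUDA tensors; with True, tensors are collected
    # as-is and converted in one pass at the consumer.
    return [t if delay_cpu_copy else t.tolist() for t in chunks]

chunks = [torch.randn(3, 2) for _ in range(4)]  # stand-ins for sliced logprobs
delayed = gather_logprobs(chunks, delay_cpu_copy=True)
as_lists = [t.tolist() for t in delayed]  # single conversion point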
--- /dev/null
+++ b/sglang/srt/layers/modelopt_utils.py
@@ -0,0 +1,11 @@
+"""
+ModelOpt related constants
+"""
+
+QUANT_CFG_CHOICES = {
+    "fp8": "FP8_DEFAULT_CFG",
+    "int4_awq": "INT4_AWQ_CFG",  # TODO: add support for int4_awq
+    "w4a8_awq": "W4A8_AWQ_BETA_CFG",  # TODO: add support for w4a8_awq
+    "nvfp4": "NVFP4_DEFAULT_CFG",
+    "nvfp4_awq": "NVFP4_AWQ_LITE_CFG",  # TODO: add support for nvfp4_awq
+}
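This new module maps SGLang quantization choices to NVIDIA TensorRT Model Optimizer config names. A hedged sketch of how such a table is typically consumed; the getattr lookup on modelopt.torch.quantization is an assumption about usage, not code from this diff:

QUANT_CFG_CHOICES = {
    "fp8": "FP8_DEFAULT_CFG",
    "nvfp4": "NVFP4_DEFAULT_CFG",
}

def resolve_quant_cfg(choice: str):
    # Resolve the string name to the config object exported by modelopt
    # (assumed API), failing softly when the optional dependency is absent.
    try:
        import modelopt.torch.quantization as mtq
    except ImportError:
        return None
    return getattr(mtq, QUANT_CFG_CHOICES[choice], None)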
--- a/sglang/srt/layers/moe/cutlass_moe.py
+++ b/sglang/srt/layers/moe/cutlass_moe.py
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
 
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
-            per_group_transpose,
-            per_token_group_quant_fp8_hopper_moe_mn_major,
            sglang_per_token_group_quant_fp8,
        )
 
--- a/sglang/srt/layers/moe/cutlass_w4a8_moe.py
+++ b/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,24 +11,23 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_permute_triton_kernel,
+    deepep_post_reorder_triton_kernel,
+    deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
+    run_moe_ep_preproess,
 )
 
 
 def cutlass_w4a8_moe(
-    start_expert_id: int,
-    end_expert_id: int,
-    total_num_experts: int,
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
-    local_topk_ids: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
@@ -64,6 +63,7 @@ def cutlass_w4a8_moe(
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts, N // 512, K * 4]
     - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The ids of each token->expert mapping.
     - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
     - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
     - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
@@ -83,7 +83,7 @@
     Returns:
     - torch.Tensor: The fp8 output tensor after applying the MoE layer.
     """
-    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert w1_q.dtype == torch.int8
     assert w2_q.dtype == torch.int8
     assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
@@ -96,20 +96,21 @@
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
     assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
-    num_experts = w1_q.size(0)
+    num_local_experts = w1_q.size(0)
     m = a.size(0)
     k = w1_q.size(2) * 2  # w1_q is transposed and packed
     n = w2_q.size(2) * 2  # w2_q is transposed and packed
-    topk = topk_ids_.size(1)
+    topk = topk_ids.size(1)
 
     if apply_router_weight_on_input:
         assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
 
     device = a.device
+    topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
 
-    _, src2dst, _ = run_cutlass_moe_ep_preproess(
-        local_topk_ids,
-        num_experts,
+    _, src2dst, _ = run_moe_ep_preproess(
+        topk_ids,
+        num_local_experts,
     )
 
     gateup_input = torch.empty(
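The signature change removes the separate local_topk_ids argument: the caller now passes a single topk_ids tensor in which experts not hosted on this rank are marked -1, and the line added after device = a.device remaps those entries to num_local_experts, a dummy expert slot one past the last real local expert that the downstream kernels can skip. The remap in isolation:

import torch

num_local_experts = 4
# -1 marks a token->expert assignment that lives on another EP rank.
topk_ids = torch.tensor([[0, 3], [-1, 2], [1, -1]])
remapped = torch.where(topk_ids == -1, num_local_experts, topk_ids)
print(remapped.tolist())  # [[0, 3], [4, 2], [1, 4]]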
@@ -122,9 +123,9 @@
         a,
         gateup_input,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         a1_scale,
-        total_num_experts,
+        num_local_experts,
         topk,
         k,
         BLOCK_SIZE=512,
@@ -133,16 +134,16 @@
     # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
     # they are kept to allow for a quick switch of the permutation logic
     # from the current triton kernel implementation to the cutlass-based one if needed.
-    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     get_cutlass_w4a8_moe_mm_data(
-        local_topk_ids,
+        topk_ids,
         expert_offsets,
         problem_sizes1,
         problem_sizes2,
         a_map,
         c_map,
-        num_experts,
+        num_local_experts,
         n,
         k,
     )
@@ -195,12 +196,203 @@
         c2,
         output,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         topk_weights,
-        num_experts,
         topk,
+        num_local_experts,
         k,
-        0,
         BLOCK_SIZE=512,
     )
     return output
+
+
+def cutlass_w4a8_moe_deepep_normal(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2
+    n = w2_q.size(2) * 2
+    topk = topk_ids_.size(1)
+    device = a.device
+
+    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
+        topk_ids_, num_experts
+    )
+    num_total_tokens = reorder_topk_ids.numel()
+    gateup_input_pre_reorder = torch.empty(
+        (int(num_total_tokens), a.shape[1]),
+        device=device,
+        dtype=a.dtype,
+    )
+    deepep_permute_triton_kernel[(a.shape[0],)](
+        a,
+        gateup_input_pre_reorder,
+        src2dst,
+        topk_ids_.to(torch.int64),
+        None,
+        topk,
+        a.shape[1],
+        BLOCK_SIZE=512,
+    )
+    gateup_input = torch.empty(
+        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(
+        gateup_input_pre_reorder, gateup_input, a1_scale.float(), True
+    )
+    del gateup_input_pre_reorder
+    local_topk_ids = topk_ids_
+    local_topk_ids = (
+        torch.where(local_topk_ids == -1, num_experts, topk_ids_).to(torch.int32)
+    ).contiguous()
+
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+    num_tokens = src2dst.shape[0] // topk
+    output = torch.empty(
+        (num_tokens, c2.shape[1]),
+        device=c2.device,
+        dtype=torch.bfloat16,
+    )
+    deepep_post_reorder_triton_kernel[(num_tokens,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        topk,
+        c2.shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    return output
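The new cutlass_w4a8_moe_deepep_normal follows the same pipeline as cutlass_w4a8_moe — permute tokens by expert, quantize to fp8, grouped GEMM, SiLU-and-mul, quantize again, grouped GEMM, then un-permute — but takes its routing from the DeepEP normal dispatch. A pure-PyTorch reference for the final un-permute step, as a readable stand-in for deepep_post_reorder_triton_kernel; the src2dst layout is inferred from the calls above, so treat it as illustrative:

import torch

def post_reorder_reference(c2, src2dst, topk_ids, topk_weights, topk):
    # Gather each token's top-k expert outputs from the permuted buffer and
    # reduce them with the router weights. Assumes src2dst[i * topk + j] is
    # the permuted row holding token i's j-th expert output.
    num_tokens = src2dst.shape[0] // topk
    out = torch.zeros(num_tokens, c2.shape[1], dtype=c2.dtype)
    for i in range(num_tokens):
        for j in range(topk):
            if topk_ids[i, j] < 0:
                continue  # choice routed to a non-local expert
            out[i] += topk_weights[i, j] * c2[src2dst[i * topk + j]]
    return out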