sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from contextlib import nullcontext
4
5
  from dataclasses import dataclass
5
6
  from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
6
7
 
7
8
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
9
+ from sglang.srt.layers import deep_gemm_wrapper
10
+ from sglang.srt.layers.dp_attention import get_is_extend_in_batch
8
11
  from sglang.srt.layers.moe.token_dispatcher.base import (
9
12
  BaseDispatcher,
10
13
  BaseDispatcherConfig,
@@ -13,8 +16,13 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
13
16
  DispatchOutput,
14
17
  DispatchOutputFormat,
15
18
  )
16
- from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
17
- from sglang.srt.layers.quantization import deep_gemm_wrapper
19
+ from sglang.srt.layers.moe.topk import TopKOutput
20
+ from sglang.srt.layers.moe.utils import (
21
+ DeepEPMode,
22
+ get_deepep_config,
23
+ get_moe_runner_backend,
24
+ is_tbo_enabled,
25
+ )
18
26
  from sglang.srt.utils import (
19
27
  get_bool_env_var,
20
28
  get_int_env_var,
@@ -25,6 +33,9 @@ from sglang.srt.utils import (
25
33
 
26
34
  _is_npu = is_npu()
27
35
 
36
+ if TYPE_CHECKING:
37
+ from sglang.srt.single_batch_overlap import CombineOverlapArgs
38
+
28
39
  try:
29
40
  from deep_ep import Buffer, Config
30
41
 
@@ -42,8 +53,6 @@ from enum import Enum, IntEnum, auto
42
53
  import torch
43
54
  import torch.distributed as dist
44
55
 
45
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
46
-
47
56
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
48
57
 
49
58
  logger = logging.getLogger(__name__)
@@ -52,9 +61,9 @@ logger = logging.getLogger(__name__)
52
61
  class DeepEPNormalOutput(NamedTuple):
53
62
  """DeepEP normal dispatch output."""
54
63
 
55
- hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
56
- # hidden_states_scale
57
- topk_idx: torch.Tensor
64
+ hidden_states: torch.Tensor
65
+ hidden_states_scale: Optional[torch.Tensor]
66
+ topk_ids: torch.Tensor
58
67
  topk_weights: torch.Tensor
59
68
  num_recv_tokens_per_expert: List[int]
60
69
 
@@ -66,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
66
75
  class DeepEPLLOutput(NamedTuple):
67
76
  """DeepEP low latency dispatch output."""
68
77
 
69
- hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
70
- topk_idx: torch.Tensor
78
+ hidden_states: torch.Tensor
79
+ hidden_states_scale: Optional[torch.Tensor]
80
+ topk_ids: torch.Tensor
71
81
  topk_weights: torch.Tensor
72
82
  masked_m: torch.Tensor
73
83
  expected_m: int
@@ -164,10 +174,19 @@ class DeepEPBuffer:
164
174
  num_rdma_bytes,
165
175
  )
166
176
 
177
+ # We should calculate num_qps_per_rank consistently with DeepEP's test script logic:
167
178
  if deepep_mode == DeepEPMode.NORMAL:
168
- num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
169
- elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
179
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
180
+ num_qps_per_rank = DeepEPConfig.get_instance().num_sms
181
+ elif deepep_mode == DeepEPMode.LOW_LATENCY:
182
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
170
183
  num_qps_per_rank = num_experts // group.size()
184
+ elif deepep_mode == DeepEPMode.AUTO:
185
+ # low-latency and normal mode all need run
186
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
187
+ num_qps_per_rank = max(
188
+ DeepEPConfig.get_instance().num_sms, num_experts // group.size()
189
+ )
171
190
  else:
172
191
  raise NotImplementedError
173
192
 
@@ -217,6 +236,15 @@ class DeepEPBuffer:
217
236
  cls.clean_buffer()
218
237
  cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
219
238
 
239
+ @classmethod
240
+ def set_dispatch_mode(cls, mode: DeepEPMode):
241
+ if mode.is_low_latency():
242
+ cls.set_dispatch_mode_as_low_latency()
243
+ elif mode.is_normal():
244
+ cls.set_dispatch_mode_as_normal()
245
+ else:
246
+ raise Exception("unsupported mode")
247
+
220
248
 
221
249
  class DeepEPConfig(BaseDispatcherConfig):
222
250
  _instance = None
@@ -287,8 +315,7 @@ class _DeepEPDispatcherImplBase:
287
315
  def dispatch_a(
288
316
  self,
289
317
  hidden_states: torch.Tensor,
290
- topk_idx: torch.Tensor,
291
- topk_weights: torch.Tensor,
318
+ topk_output: TopKOutput,
292
319
  ):
293
320
  raise NotImplementedError
294
321
 
@@ -298,8 +325,9 @@ class _DeepEPDispatcherImplBase:
298
325
  def combine_a(
299
326
  self,
300
327
  hidden_states: torch.Tensor,
301
- topk_idx: torch.Tensor,
328
+ topk_ids: torch.Tensor,
302
329
  topk_weights: torch.Tensor,
330
+ overlap_args: Optional["CombineOverlapArgs"],
303
331
  ):
304
332
  raise NotImplementedError
305
333
 
@@ -316,15 +344,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
316
344
 
317
345
  self.async_finish = async_finish
318
346
  self.src2dst = None
347
+ self.quant_config = {}
319
348
 
320
349
  def dispatch_a(
321
350
  self,
322
351
  hidden_states: torch.Tensor,
323
- topk_idx: torch.Tensor,
324
- topk_weights: torch.Tensor,
352
+ topk_output: TopKOutput,
325
353
  ):
326
- topk_idx = topk_idx.to(torch.int64)
327
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
354
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
355
+ topk_ids = topk_ids.to(torch.int64)
356
+ if (
357
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
358
+ and not get_moe_runner_backend().is_cutlass()
359
+ ):
328
360
  # TODO hard code 128 block quant,use fp8 communication
329
361
  hidden_states = sglang_per_token_group_quant_fp8(
330
362
  hidden_states,
@@ -334,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
334
366
  scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
335
367
  )
336
368
  previous_event = Buffer.capture() if self.async_finish else None
337
- return hidden_states, topk_idx, topk_weights, previous_event
369
+ return hidden_states, topk_ids, topk_weights, previous_event
338
370
 
339
- def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
371
+ def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
340
372
  (
341
373
  hidden_states,
342
- topk_idx,
374
+ topk_ids,
343
375
  topk_weights,
344
376
  num_recv_tokens_per_expert,
345
377
  event,
346
- ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
378
+ ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
347
379
  event.current_stream_wait() if self.async_finish else ()
380
+
381
+ if isinstance(hidden_states, tuple):
382
+ hidden_states, hidden_states_scale = hidden_states
383
+ else:
384
+ hidden_states_scale = None
385
+
348
386
  return DeepEPNormalOutput(
349
- hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
387
+ hidden_states,
388
+ hidden_states_scale,
389
+ topk_ids,
390
+ topk_weights,
391
+ num_recv_tokens_per_expert,
350
392
  )
351
393
 
352
394
  def _dispatch_core(
353
395
  self,
354
396
  x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
355
- topk_idx: torch.Tensor,
397
+ topk_ids: torch.Tensor,
356
398
  topk_weights: torch.Tensor,
357
399
  previous_event,
358
400
  ):
@@ -364,27 +406,26 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
364
406
  is_token_in_rank,
365
407
  previous_event,
366
408
  ) = buffer.get_dispatch_layout(
367
- topk_idx,
409
+ topk_ids,
368
410
  self.num_experts,
369
411
  previous_event=previous_event,
370
412
  async_finish=self.async_finish,
371
413
  allocate_on_comm_stream=previous_event is not None,
372
414
  )
373
-
374
415
  # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
375
416
  # However, doing this would incur an unknown synchronization error, but keeping
376
417
  # `handle` as a member variable works.
377
418
 
378
419
  (
379
420
  recv_x,
380
- recv_topk_idx,
421
+ recv_topk_ids,
381
422
  recv_topk_weights,
382
423
  num_recv_tokens_per_expert,
383
424
  self.handle,
384
425
  event,
385
426
  ) = buffer.dispatch(
386
427
  x,
387
- topk_idx=topk_idx,
428
+ topk_idx=topk_ids,
388
429
  topk_weights=topk_weights,
389
430
  num_tokens_per_rank=num_tokens_per_rank,
390
431
  num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
@@ -396,7 +437,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
396
437
  expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
397
438
  config=DeepEPConfig.get_instance().normal_dispatch_config,
398
439
  )
399
-
400
440
  get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
401
441
  num_recv_tokens_per_expert,
402
442
  num_tokens_per_rank=num_tokens_per_rank,
@@ -406,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
406
446
 
407
447
  return (
408
448
  recv_x,
409
- recv_topk_idx,
449
+ recv_topk_ids,
410
450
  recv_topk_weights,
411
451
  num_recv_tokens_per_expert,
412
452
  event,
@@ -415,39 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
415
455
  def combine_a(
416
456
  self,
417
457
  hidden_states: torch.Tensor,
418
- topk_idx: torch.Tensor,
458
+ topk_ids: torch.Tensor,
419
459
  topk_weights: torch.Tensor,
460
+ overlap_args: Optional["CombineOverlapArgs"],
420
461
  ):
421
- from sglang.srt.layers.moe.ep_moe.kernels import (
422
- deepep_post_reorder_triton_kernel,
423
- )
424
462
 
425
463
  if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
426
464
  output = hidden_states
427
465
  else:
428
- if hidden_states.shape[0] > 0:
429
- num_tokens = self.src2dst.shape[0] // self.router_topk
430
- output = torch.empty(
431
- (num_tokens, hidden_states.shape[1]),
432
- device=hidden_states.device,
433
- dtype=hidden_states.dtype,
434
- )
435
- deepep_post_reorder_triton_kernel[(num_tokens,)](
436
- hidden_states,
437
- output,
438
- self.src2dst,
439
- topk_idx,
440
- topk_weights,
441
- self.router_topk,
442
- hidden_states.shape[1],
443
- BLOCK_SIZE=512,
444
- )
445
- else:
446
- output = torch.zeros(
447
- (0, hidden_states.shape[1]),
448
- device=hidden_states.device,
449
- dtype=hidden_states.dtype,
450
- )
466
+ raise NotImplementedError() # triton runner was supported but it's temporarily disabled
467
+
451
468
  previous_event = Buffer.capture() if self.async_finish else None
452
469
  return output, previous_event
453
470
 
@@ -482,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
482
499
  self.num_experts,
483
500
  )
484
501
 
502
+ def set_quant_config(self, quant_config: dict):
503
+ self.quant_config = quant_config
504
+
485
505
 
486
506
  class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
487
507
  def __init__(self, return_recv_hook: bool, **kwargs):
@@ -492,28 +512,28 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
492
512
  https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
493
513
  """
494
514
  self.return_recv_hook = return_recv_hook
515
+ self.device_module = torch.get_device_module()
516
+ self.quant_config = {}
495
517
 
496
518
  def dispatch_a(
497
519
  self,
498
520
  hidden_states: torch.Tensor,
499
- topk_idx: torch.Tensor,
500
- topk_weights: torch.Tensor,
521
+ topk_output: TopKOutput,
501
522
  ):
502
523
  buffer = self._get_buffer()
503
- topk_idx = topk_idx.to(torch.int64)
524
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
525
+ topk_ids = topk_ids.to(torch.int64)
504
526
  expected_m = (
505
- hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
527
+ hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
506
528
  + self.num_experts
507
529
  ) // self.num_experts
508
530
  hidden_states, masked_m, event, hook = self._dispatch_core(
509
531
  hidden_states,
510
- topk_idx,
511
- # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
512
- use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
532
+ topk_ids,
513
533
  )
514
534
  return (
515
535
  hidden_states,
516
- topk_idx,
536
+ topk_ids,
517
537
  topk_weights,
518
538
  masked_m,
519
539
  expected_m,
@@ -524,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
524
544
  def dispatch_b(
525
545
  self,
526
546
  hidden_states,
527
- topk_idx,
547
+ topk_ids,
528
548
  topk_weights,
529
549
  masked_m,
530
550
  expected_m,
@@ -537,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
537
557
  masked_m
538
558
  )
539
559
 
560
+ if isinstance(hidden_states, tuple):
561
+ hidden_states, hidden_states_scale = hidden_states
562
+ else:
563
+ hidden_states_scale = None
564
+
540
565
  deepep_output = DeepEPLLOutput(
541
566
  hidden_states,
542
- topk_idx,
567
+ hidden_states_scale,
568
+ topk_ids,
543
569
  topk_weights,
544
570
  masked_m,
545
571
  expected_m,
@@ -549,17 +575,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
549
575
  def _dispatch_core(
550
576
  self,
551
577
  hidden_states: torch.Tensor,
552
- topk_idx: torch.Tensor,
553
- use_fp8: bool = False,
578
+ topk_ids: torch.Tensor,
554
579
  ):
580
+ use_nvfp4 = use_fp8 = False
581
+ input_global_scale = self.quant_config.get("input_global_scale", None)
582
+ if input_global_scale is not None:
583
+ use_nvfp4 = True
584
+ elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
585
+ use_fp8 = True
586
+
555
587
  buffer = self._get_buffer()
556
- packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
588
+ packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
557
589
  buffer.low_latency_dispatch(
558
590
  hidden_states,
559
- topk_idx,
591
+ topk_ids,
560
592
  self.num_max_dispatch_tokens_per_rank,
561
593
  self.num_experts,
562
594
  use_fp8=use_fp8,
595
+ **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
596
+ **(
597
+ dict(x_global_scale=input_global_scale)
598
+ if input_global_scale is not None
599
+ else dict()
600
+ ),
563
601
  async_finish=not self.return_recv_hook,
564
602
  return_recv_hook=self.return_recv_hook,
565
603
  round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -568,41 +606,68 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
568
606
  and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
569
607
  )
570
608
  )
571
- return packed_recv_hidden, packed_recv_count, event, hook
609
+ return packed_recv_hidden, self.packed_recv_count, event, hook
572
610
 
573
611
  def combine_a(
574
612
  self,
575
613
  hidden_states: torch.Tensor,
576
- topk_idx: torch.Tensor,
614
+ topk_ids: torch.Tensor,
577
615
  topk_weights: torch.Tensor,
616
+ overlap_args: Optional["CombineOverlapArgs"],
578
617
  ):
579
618
  hidden_states, event, hook = self._combine_core(
580
619
  hidden_states,
581
- topk_idx,
620
+ topk_ids,
582
621
  topk_weights,
622
+ overlap_args=overlap_args,
583
623
  )
584
- return hidden_states, event, hook
624
+ return hidden_states, event, hook, overlap_args
625
+
626
+ def combine_b(self, hidden_states, event, hook, overlap_args):
627
+ if overlap_args is not None:
628
+ overlap_args.stream.wait_stream(self.device_module.current_stream())
585
629
 
586
- def combine_b(self, hidden_states, event, hook):
587
630
  hook() if self.return_recv_hook else event.current_stream_wait()
631
+
632
+ if overlap_args is not None:
633
+ self.device_module.current_stream().wait_stream(overlap_args.stream)
634
+
588
635
  return hidden_states
589
636
 
590
637
  def _combine_core(
591
638
  self,
592
639
  hidden_states: torch.Tensor,
593
- topk_idx: torch.Tensor,
640
+ topk_ids: torch.Tensor,
594
641
  topk_weights: torch.Tensor,
642
+ overlap_args: Optional["CombineOverlapArgs"],
595
643
  ):
596
644
  buffer = self._get_buffer()
597
- combined_hidden_states, event, hook = buffer.low_latency_combine(
598
- hidden_states,
599
- topk_idx,
600
- topk_weights,
601
- self.handle,
602
- async_finish=not self.return_recv_hook,
603
- return_recv_hook=self.return_recv_hook,
604
- )
605
- self.handle = None
645
+
646
+ ctx = nullcontext()
647
+ if overlap_args is not None:
648
+ overlap_args.stream.wait_event(overlap_args.wait_event)
649
+ ctx = torch.cuda.stream(overlap_args.stream)
650
+
651
+ with ctx:
652
+ combined_hidden_states, event, hook = buffer.low_latency_combine(
653
+ x=hidden_states,
654
+ topk_idx=topk_ids,
655
+ topk_weights=topk_weights,
656
+ handle=self.handle,
657
+ async_finish=not self.return_recv_hook,
658
+ return_recv_hook=self.return_recv_hook,
659
+ **(
660
+ dict(
661
+ overlap=overlap_args.overlap,
662
+ src_signals=overlap_args.signal,
663
+ src_signal_expect_value=overlap_args.threshold,
664
+ )
665
+ if overlap_args is not None
666
+ else {}
667
+ ),
668
+ )
669
+
670
+ self.packed_recv_count = self.handle = None
606
671
  return combined_hidden_states, event, hook
607
672
 
608
673
  def _get_buffer(self):
@@ -616,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
616
681
  self.num_experts,
617
682
  )
618
683
 
684
+ def set_quant_config(self, quant_config: dict):
685
+ self.quant_config = quant_config
686
+
619
687
 
620
688
  @dataclass
621
689
  class _Stage(Enum):
@@ -673,23 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
673
741
  def dispatch_a(
674
742
  self,
675
743
  hidden_states: torch.Tensor,
676
- topk_idx: torch.Tensor,
677
- topk_weights: torch.Tensor,
678
- forward_batch: ForwardBatch,
744
+ topk_output: TopKOutput,
679
745
  ):
680
746
  self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
681
- inner_state = self._get_impl(forward_batch).dispatch_a(
747
+ inner_state = self._get_impl().dispatch_a(
682
748
  hidden_states=hidden_states,
683
- topk_idx=topk_idx,
684
- topk_weights=topk_weights,
749
+ topk_output=topk_output,
685
750
  )
686
- self._dispatch_intermediate_state = forward_batch, inner_state
751
+ self._dispatch_intermediate_state = inner_state
687
752
 
688
753
  def dispatch_b(self):
689
754
  self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
690
- forward_batch, inner_state = self._dispatch_intermediate_state
755
+ inner_state = self._dispatch_intermediate_state
691
756
  del self._dispatch_intermediate_state
692
- return self._get_impl(forward_batch).dispatch_b(*inner_state)
757
+ return self._get_impl().dispatch_b(*inner_state)
693
758
 
694
759
  def combine(self, *args, **kwargs) -> Tuple:
695
760
  self.combine_a(*args, **kwargs)
@@ -699,28 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
699
764
  def combine_a(
700
765
  self,
701
766
  hidden_states: torch.Tensor,
702
- topk_idx: torch.Tensor,
767
+ topk_ids: torch.Tensor,
703
768
  topk_weights: torch.Tensor,
704
- forward_batch: ForwardBatch,
769
+ overlap_args: Optional["CombineOverlapArgs"] = None,
705
770
  ):
706
771
  self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
707
- inner_state = self._get_impl(forward_batch).combine_a(
772
+ inner_state = self._get_impl().combine_a(
708
773
  hidden_states=hidden_states,
709
- topk_idx=topk_idx,
774
+ topk_ids=topk_ids,
710
775
  topk_weights=topk_weights,
776
+ overlap_args=overlap_args,
711
777
  )
712
- self._combine_intermediate_state = forward_batch, inner_state
778
+ self._combine_intermediate_state = inner_state
713
779
 
714
780
  def combine_b(self):
715
781
  self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
716
- forward_batch, inner_state = self._combine_intermediate_state
782
+ inner_state = self._combine_intermediate_state
717
783
  del self._combine_intermediate_state
718
- return self._get_impl(forward_batch).combine_b(*inner_state)
784
+ return self._get_impl().combine_b(*inner_state)
719
785
 
720
- def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
721
- resolved_deepep_mode = self.deepep_mode.resolve(
722
- forward_batch.is_extend_in_batch
723
- )
786
+ def _get_impl(self) -> _DeepEPDispatcherImplBase:
787
+ is_extend_in_batch = get_is_extend_in_batch()
788
+ resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
724
789
  if resolved_deepep_mode == DeepEPMode.NORMAL:
725
790
  return self._normal_dispatcher
726
791
  elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
@@ -731,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
731
796
  def _update_stage(self, old_stage, new_stage):
732
797
  assert self._stage == old_stage
733
798
  self._stage = new_stage
799
+
800
+ def set_quant_config(self, quant_config: dict):
801
+ if self.deepep_mode.enable_low_latency():
802
+ self._low_latency_dispatcher.set_quant_config(quant_config)
803
+ if self.deepep_mode.enable_normal():
804
+ self._normal_dispatcher.set_quant_config(quant_config)