sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -3,9 +3,11 @@ from __future__ import annotations
  import logging
  from contextlib import nullcontext
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.layers import deep_gemm_wrapper
+ from sglang.srt.layers.dp_attention import get_is_extend_in_batch
  from sglang.srt.layers.moe.token_dispatcher.base import (
      BaseDispatcher,
      BaseDispatcherConfig,
@@ -14,8 +16,13 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
      DispatchOutput,
      DispatchOutputFormat,
  )
- from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
- from sglang.srt.layers.quantization import deep_gemm_wrapper
+ from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.utils import (
+     DeepEPMode,
+     get_deepep_config,
+     get_moe_runner_backend,
+     is_tbo_enabled,
+ )
  from sglang.srt.utils import (
      get_bool_env_var,
      get_int_env_var,
@@ -46,19 +53,17 @@ from enum import Enum, IntEnum, auto
  import torch
  import torch.distributed as dist

- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

  logger = logging.getLogger(__name__)


- class DeepEPNormalOutput(NamedTuple):
+ class DeepEPNormalDispatchOutput(NamedTuple):
      """DeepEP normal dispatch output."""

-     hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
-     # hidden_states_scale
-     topk_idx: torch.Tensor
+     hidden_states: torch.Tensor
+     hidden_states_scale: Optional[torch.Tensor]
+     topk_ids: torch.Tensor
      topk_weights: torch.Tensor
      num_recv_tokens_per_expert: List[int]

@@ -67,11 +72,12 @@ class DeepEPNormalOutput(NamedTuple):
          return DispatchOutputFormat.DEEPEP_NORMAL


- class DeepEPLLOutput(NamedTuple):
+ class DeepEPLLDispatchOutput(NamedTuple):
      """DeepEP low latency dispatch output."""

-     hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-     topk_idx: torch.Tensor
+     hidden_states: torch.Tensor
+     hidden_states_scale: Optional[torch.Tensor]
+     topk_ids: torch.Tensor
      topk_weights: torch.Tensor
      masked_m: torch.Tensor
      expected_m: int
@@ -81,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
          return DispatchOutputFormat.DEEPEP_LL


- assert isinstance(DeepEPNormalOutput, DispatchOutput)
- assert isinstance(DeepEPLLOutput, DispatchOutput)
+ assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
+ assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)


  class DeepEPNormalCombineInput(NamedTuple):
      """DeepEP normal combine input."""

-     pass
+     hidden_states: torch.Tensor
+     topk_ids: torch.Tensor
+     topk_weights: torch.Tensor
+     overlap_args: Optional[CombineOverlapArgs] = None

      @property
      def format(self) -> CombineInputFormat:
@@ -98,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
  class DeepEPLLCombineInput(NamedTuple):
      """DeepEP low latency combine input."""

-     pass
+     hidden_states: torch.Tensor
+     topk_ids: torch.Tensor
+     topk_weights: torch.Tensor
+     overlap_args: Optional[CombineOverlapArgs] = None

      @property
      def format(self) -> CombineInputFormat:
@@ -230,6 +242,15 @@ class DeepEPBuffer:
          cls.clean_buffer()
          cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY

+     @classmethod
+     def set_dispatch_mode(cls, mode: DeepEPMode):
+         if mode.is_low_latency():
+             cls.set_dispatch_mode_as_low_latency()
+         elif mode.is_normal():
+             cls.set_dispatch_mode_as_normal()
+         else:
+             raise Exception("unsupported mode")
+

  class DeepEPConfig(BaseDispatcherConfig):
      _instance = None
@@ -300,9 +321,7 @@ class _DeepEPDispatcherImplBase:
      def dispatch_a(
          self,
          hidden_states: torch.Tensor,
-         input_global_scale: Optional[torch.Tensor],
-         topk_idx: torch.Tensor,
-         topk_weights: torch.Tensor,
+         topk_output: TopKOutput,
      ):
          raise NotImplementedError

@@ -312,7 +331,7 @@
      def combine_a(
          self,
          hidden_states: torch.Tensor,
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
          topk_weights: torch.Tensor,
          overlap_args: Optional["CombineOverlapArgs"],
      ):
@@ -331,16 +350,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):

          self.async_finish = async_finish
          self.src2dst = None
+         self.quant_config = {}

      def dispatch_a(
          self,
          hidden_states: torch.Tensor,
-         input_global_scale: Optional[torch.Tensor],
-         topk_idx: torch.Tensor,
-         topk_weights: torch.Tensor,
+         topk_output: TopKOutput,
      ):
-         topk_idx = topk_idx.to(torch.int64)
-         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+         topk_ids = topk_ids.to(torch.int64)
+         if (
+             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+             and not get_moe_runner_backend().is_cutlass()
+         ):
              # TODO hard code 128 block quant,use fp8 communication
              hidden_states = sglang_per_token_group_quant_fp8(
                  hidden_states,
@@ -350,25 +372,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                  scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
              )
          previous_event = Buffer.capture() if self.async_finish else None
-         return hidden_states, topk_idx, topk_weights, previous_event
+         return hidden_states, topk_ids, topk_weights, previous_event

-     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
+     def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
          (
              hidden_states,
-             topk_idx,
+             topk_ids,
              topk_weights,
              num_recv_tokens_per_expert,
              event,
-         ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+         ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
          event.current_stream_wait() if self.async_finish else ()
-         return DeepEPNormalOutput(
-             hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+
+         if isinstance(hidden_states, tuple):
+             hidden_states, hidden_states_scale = hidden_states
+         else:
+             hidden_states_scale = None
+
+         return DeepEPNormalDispatchOutput(
+             hidden_states,
+             hidden_states_scale,
+             topk_ids,
+             topk_weights,
+             num_recv_tokens_per_expert,
          )

      def _dispatch_core(
          self,
          x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
          topk_weights: torch.Tensor,
          previous_event,
      ):
@@ -380,27 +412,26 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
              is_token_in_rank,
              previous_event,
          ) = buffer.get_dispatch_layout(
-             topk_idx,
+             topk_ids,
              self.num_experts,
              previous_event=previous_event,
              async_finish=self.async_finish,
              allocate_on_comm_stream=previous_event is not None,
          )
-
          # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
          # However, doing this would incur an unknown synchronization error, but keeping
          # `handle` as a member variable works.

          (
              recv_x,
-             recv_topk_idx,
+             recv_topk_ids,
              recv_topk_weights,
              num_recv_tokens_per_expert,
              self.handle,
              event,
          ) = buffer.dispatch(
              x,
-             topk_idx=topk_idx,
+             topk_idx=topk_ids,
              topk_weights=topk_weights,
              num_tokens_per_rank=num_tokens_per_rank,
              num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
@@ -412,7 +443,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
              expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
              config=DeepEPConfig.get_instance().normal_dispatch_config,
          )
-
          get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
              num_recv_tokens_per_expert,
              num_tokens_per_rank=num_tokens_per_rank,
@@ -422,7 +452,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):

          return (
              recv_x,
-             recv_topk_idx,
+             recv_topk_ids,
              recv_topk_weights,
              num_recv_tokens_per_expert,
              event,
@@ -431,40 +461,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
      def combine_a(
          self,
          hidden_states: torch.Tensor,
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
          topk_weights: torch.Tensor,
          overlap_args: Optional["CombineOverlapArgs"],
      ):
-         from sglang.srt.layers.moe.ep_moe.kernels import (
-             deepep_post_reorder_triton_kernel,
-         )

          if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
              output = hidden_states
          else:
-             if hidden_states.shape[0] > 0:
-                 num_tokens = self.src2dst.shape[0] // self.router_topk
-                 output = torch.empty(
-                     (num_tokens, hidden_states.shape[1]),
-                     device=hidden_states.device,
-                     dtype=hidden_states.dtype,
-                 )
-                 deepep_post_reorder_triton_kernel[(num_tokens,)](
-                     hidden_states,
-                     output,
-                     self.src2dst,
-                     topk_idx,
-                     topk_weights,
-                     self.router_topk,
-                     hidden_states.shape[1],
-                     BLOCK_SIZE=512,
-                 )
-             else:
-                 output = torch.zeros(
-                     (0, hidden_states.shape[1]),
-                     device=hidden_states.device,
-                     dtype=hidden_states.dtype,
-                 )
+             raise NotImplementedError()  # triton runner was supported but it's temporarily disabled
+
          previous_event = Buffer.capture() if self.async_finish else None
          return output, previous_event

@@ -499,6 +505,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
              self.num_experts,
          )

+     def set_quant_config(self, quant_config: dict):
+         self.quant_config = quant_config
+

  class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
      def __init__(self, return_recv_hook: bool, **kwargs):
@@ -510,28 +519,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
          """
          self.return_recv_hook = return_recv_hook
          self.device_module = torch.get_device_module()
+         self.quant_config = {}

      def dispatch_a(
          self,
          hidden_states: torch.Tensor,
-         input_global_scale: Optional[torch.Tensor],
-         topk_idx: torch.Tensor,
-         topk_weights: torch.Tensor,
+         topk_output: TopKOutput,
      ):
          buffer = self._get_buffer()
-         topk_idx = topk_idx.to(torch.int64)
+         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+         topk_ids = topk_ids.to(torch.int64)
          expected_m = (
-             hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+             hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
              + self.num_experts
          ) // self.num_experts
          hidden_states, masked_m, event, hook = self._dispatch_core(
              hidden_states,
-             input_global_scale,
-             topk_idx,
+             topk_ids,
          )
          return (
              hidden_states,
-             topk_idx,
+             topk_ids,
              topk_weights,
              masked_m,
              expected_m,
@@ -542,7 +550,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
      def dispatch_b(
          self,
          hidden_states,
-         topk_idx,
+         topk_ids,
          topk_weights,
          masked_m,
          expected_m,
@@ -555,9 +563,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
              masked_m
          )

-         deepep_output = DeepEPLLOutput(
+         if isinstance(hidden_states, tuple):
+             hidden_states, hidden_states_scale = hidden_states
+         else:
+             hidden_states_scale = None
+
+         deepep_output = DeepEPLLDispatchOutput(
              hidden_states,
-             topk_idx,
+             hidden_states_scale,
+             topk_ids,
              topk_weights,
              masked_m,
              expected_m,
@@ -567,10 +581,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
      def _dispatch_core(
          self,
          hidden_states: torch.Tensor,
-         input_global_scale: Optional[torch.Tensor],
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
      ):
          use_nvfp4 = use_fp8 = False
+         input_global_scale = self.quant_config.get("input_global_scale", None)
          if input_global_scale is not None:
              use_nvfp4 = True
          elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
@@ -580,7 +594,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
          packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
              buffer.low_latency_dispatch(
                  hidden_states,
-                 topk_idx,
+                 topk_ids,
                  self.num_max_dispatch_tokens_per_rank,
                  self.num_experts,
                  use_fp8=use_fp8,
@@ -603,19 +617,22 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
      def combine_a(
          self,
          hidden_states: torch.Tensor,
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
          topk_weights: torch.Tensor,
          overlap_args: Optional["CombineOverlapArgs"],
      ):
          hidden_states, event, hook = self._combine_core(
              hidden_states,
-             topk_idx,
+             topk_ids,
              topk_weights,
              overlap_args=overlap_args,
          )
          return hidden_states, event, hook, overlap_args

      def combine_b(self, hidden_states, event, hook, overlap_args):
+         if overlap_args is not None:
+             overlap_args.stream.wait_stream(self.device_module.current_stream())
+
          hook() if self.return_recv_hook else event.current_stream_wait()

          if overlap_args is not None:
@@ -626,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
      def _combine_core(
          self,
          hidden_states: torch.Tensor,
-         topk_idx: torch.Tensor,
+         topk_ids: torch.Tensor,
          topk_weights: torch.Tensor,
          overlap_args: Optional["CombineOverlapArgs"],
      ):
@@ -640,7 +657,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
          with ctx:
              combined_hidden_states, event, hook = buffer.low_latency_combine(
                  x=hidden_states,
-                 topk_idx=topk_idx,
+                 topk_idx=topk_ids,
                  topk_weights=topk_weights,
                  handle=self.handle,
                  async_finish=not self.return_recv_hook,
@@ -670,6 +687,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
              self.num_experts,
          )

+     def set_quant_config(self, quant_config: dict):
+         self.quant_config = quant_config
+

  @dataclass
  class _Stage(Enum):
@@ -727,58 +747,49 @@ class DeepEPDispatcher(BaseDispatcher):
      def dispatch_a(
          self,
          hidden_states: torch.Tensor,
-         input_global_scale: Optional[torch.Tensor],
-         topk_idx: torch.Tensor,
-         topk_weights: torch.Tensor,
-         forward_batch: ForwardBatch,
+         topk_output: TopKOutput,
      ):
          self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-         inner_state = self._get_impl(forward_batch).dispatch_a(
+         inner_state = self._get_impl().dispatch_a(
              hidden_states=hidden_states,
-             input_global_scale=input_global_scale,
-             topk_idx=topk_idx,
-             topk_weights=topk_weights,
+             topk_output=topk_output,
          )
-         self._dispatch_intermediate_state = forward_batch, inner_state
+         self._dispatch_intermediate_state = inner_state

      def dispatch_b(self):
          self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-         forward_batch, inner_state = self._dispatch_intermediate_state
+         inner_state = self._dispatch_intermediate_state
          del self._dispatch_intermediate_state
-         return self._get_impl(forward_batch).dispatch_b(*inner_state)
+         return self._get_impl().dispatch_b(*inner_state)

-     def combine(self, *args, **kwargs) -> Tuple:
-         self.combine_a(*args, **kwargs)
+     def combine(self, combine_input: CombineInput) -> Tuple:
+         self.combine_a(combine_input)
          ret = self.combine_b()
          return ret

      def combine_a(
          self,
-         hidden_states: torch.Tensor,
-         topk_idx: torch.Tensor,
-         topk_weights: torch.Tensor,
-         forward_batch: ForwardBatch,
-         overlap_args: Optional["CombineOverlapArgs"] = None,
+         combine_input: CombineInput,
      ):
+         hidden_states, topk_ids, topk_weights, overlap_args = combine_input
          self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-         inner_state = self._get_impl(forward_batch).combine_a(
+         inner_state = self._get_impl().combine_a(
              hidden_states=hidden_states,
-             topk_idx=topk_idx,
+             topk_ids=topk_ids,
              topk_weights=topk_weights,
              overlap_args=overlap_args,
          )
-         self._combine_intermediate_state = forward_batch, inner_state
+         self._combine_intermediate_state = inner_state

      def combine_b(self):
          self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-         forward_batch, inner_state = self._combine_intermediate_state
+         inner_state = self._combine_intermediate_state
          del self._combine_intermediate_state
-         return self._get_impl(forward_batch).combine_b(*inner_state)
+         return self._get_impl().combine_b(*inner_state)

-     def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
-         resolved_deepep_mode = self.deepep_mode.resolve(
-             forward_batch.is_extend_in_batch
-         )
+     def _get_impl(self) -> _DeepEPDispatcherImplBase:
+         is_extend_in_batch = get_is_extend_in_batch()
+         resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
          if resolved_deepep_mode == DeepEPMode.NORMAL:
              return self._normal_dispatcher
          elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
@@ -789,3 +800,9 @@ class DeepEPDispatcher(BaseDispatcher):
      def _update_stage(self, old_stage, new_stage):
          assert self._stage == old_stage
          self._stage = new_stage
+
+     def set_quant_config(self, quant_config: dict):
+         if self.deepep_mode.enable_low_latency():
+             self._low_latency_dispatcher.set_quant_config(quant_config)
+         if self.deepep_mode.enable_normal():
+             self._normal_dispatcher.set_quant_config(quant_config)
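
Taken together, the hunks above change the dispatcher's calling convention between 0.5.3rc2 and 0.5.4.post1: dispatch_a now receives a single TopKOutput instead of separate input_global_scale/topk_idx/topk_weights/forward_batch arguments, combine takes a CombineInput NamedTuple, the normal vs. low-latency implementation is resolved internally via get_is_extend_in_batch(), and quantization metadata is installed once through the new set_quant_config methods. The sketch below is illustrative only and is pieced together from the signatures visible in this diff, not code shipped in the wheel; it assumes the file shown is sglang/srt/layers/moe/token_dispatcher/deepep.py (entry 170 in the list above), a DeepEP-enabled install, an already-constructed DeepEPDispatcher, a TopKOutput produced by the MoE router, and a placeholder run_experts callable.

import torch

# Assumed import path, matching entry 170 in the file list above.
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
from sglang.srt.layers.moe.topk import TopKOutput


def moe_dispatch_and_combine(
    dispatcher, hidden_states: torch.Tensor, topk_output: TopKOutput, run_experts
):
    # Quantization metadata is now installed once on the dispatcher; for nvfp4
    # dispatch the low-latency path reads quant_config["input_global_scale"].
    dispatcher.set_quant_config({})

    # Two-phase dispatch: no ForwardBatch is passed anymore; the normal vs.
    # low-latency implementation is chosen internally per batch.
    dispatcher.dispatch_a(hidden_states=hidden_states, topk_output=topk_output)
    dispatch_output = dispatcher.dispatch_b()
    # dispatch_output is a DeepEPNormalDispatchOutput or DeepEPLLDispatchOutput;
    # both now carry hidden_states_scale and topk_ids (renamed from topk_idx).

    expert_out = run_experts(dispatch_output)  # placeholder for the expert GEMMs

    # Combine now takes a single CombineInput NamedTuple instead of loose
    # hidden_states/topk_idx/topk_weights/forward_batch arguments;
    # DeepEPNormalCombineInput is the analogous type for the normal path.
    combine_input = DeepEPLLCombineInput(
        hidden_states=expert_out,
        topk_ids=dispatch_output.topk_ids,
        topk_weights=dispatch_output.topk_weights,
    )
    return dispatcher.combine(combine_input)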