sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  import logging
4
4
  from contextlib import nullcontext
5
5
  from dataclasses import dataclass
6
- from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
6
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
7
7
 
8
8
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
9
+ from sglang.srt.layers import deep_gemm_wrapper
10
+ from sglang.srt.layers.dp_attention import get_is_extend_in_batch
9
11
  from sglang.srt.layers.moe.token_dispatcher.base import (
10
12
  BaseDispatcher,
11
13
  BaseDispatcherConfig,
@@ -14,8 +16,13 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
14
16
  DispatchOutput,
15
17
  DispatchOutputFormat,
16
18
  )
17
- from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
18
- from sglang.srt.layers.quantization import deep_gemm_wrapper
19
+ from sglang.srt.layers.moe.topk import TopKOutput
20
+ from sglang.srt.layers.moe.utils import (
21
+ DeepEPMode,
22
+ get_deepep_config,
23
+ get_moe_runner_backend,
24
+ is_tbo_enabled,
25
+ )
19
26
  from sglang.srt.utils import (
20
27
  get_bool_env_var,
21
28
  get_int_env_var,
@@ -46,8 +53,6 @@ from enum import Enum, IntEnum, auto
46
53
  import torch
47
54
  import torch.distributed as dist
48
55
 
49
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
50
-
51
56
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
52
57
 
53
58
  logger = logging.getLogger(__name__)
@@ -56,9 +61,9 @@ logger = logging.getLogger(__name__)
56
61
  class DeepEPNormalOutput(NamedTuple):
57
62
  """DeepEP normal dispatch output."""
58
63
 
59
- hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
60
- # hidden_states_scale
61
- topk_idx: torch.Tensor
64
+ hidden_states: torch.Tensor
65
+ hidden_states_scale: Optional[torch.Tensor]
66
+ topk_ids: torch.Tensor
62
67
  topk_weights: torch.Tensor
63
68
  num_recv_tokens_per_expert: List[int]
64
69
 
@@ -70,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
70
75
  class DeepEPLLOutput(NamedTuple):
71
76
  """DeepEP low latency dispatch output."""
72
77
 
73
- hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
74
- topk_idx: torch.Tensor
78
+ hidden_states: torch.Tensor
79
+ hidden_states_scale: Optional[torch.Tensor]
80
+ topk_ids: torch.Tensor
75
81
  topk_weights: torch.Tensor
76
82
  masked_m: torch.Tensor
77
83
  expected_m: int
@@ -230,6 +236,15 @@ class DeepEPBuffer:
230
236
  cls.clean_buffer()
231
237
  cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
232
238
 
239
+ @classmethod
240
+ def set_dispatch_mode(cls, mode: DeepEPMode):
241
+ if mode.is_low_latency():
242
+ cls.set_dispatch_mode_as_low_latency()
243
+ elif mode.is_normal():
244
+ cls.set_dispatch_mode_as_normal()
245
+ else:
246
+ raise Exception("unsupported mode")
247
+
233
248
 
234
249
  class DeepEPConfig(BaseDispatcherConfig):
235
250
  _instance = None
@@ -300,9 +315,7 @@ class _DeepEPDispatcherImplBase:
300
315
  def dispatch_a(
301
316
  self,
302
317
  hidden_states: torch.Tensor,
303
- input_global_scale: Optional[torch.Tensor],
304
- topk_idx: torch.Tensor,
305
- topk_weights: torch.Tensor,
318
+ topk_output: TopKOutput,
306
319
  ):
307
320
  raise NotImplementedError
308
321
 
@@ -312,7 +325,7 @@ class _DeepEPDispatcherImplBase:
312
325
  def combine_a(
313
326
  self,
314
327
  hidden_states: torch.Tensor,
315
- topk_idx: torch.Tensor,
328
+ topk_ids: torch.Tensor,
316
329
  topk_weights: torch.Tensor,
317
330
  overlap_args: Optional["CombineOverlapArgs"],
318
331
  ):
@@ -331,16 +344,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
331
344
 
332
345
  self.async_finish = async_finish
333
346
  self.src2dst = None
347
+ self.quant_config = {}
334
348
 
335
349
  def dispatch_a(
336
350
  self,
337
351
  hidden_states: torch.Tensor,
338
- input_global_scale: Optional[torch.Tensor],
339
- topk_idx: torch.Tensor,
340
- topk_weights: torch.Tensor,
352
+ topk_output: TopKOutput,
341
353
  ):
342
- topk_idx = topk_idx.to(torch.int64)
343
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
354
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
355
+ topk_ids = topk_ids.to(torch.int64)
356
+ if (
357
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
358
+ and not get_moe_runner_backend().is_cutlass()
359
+ ):
344
360
  # TODO hard code 128 block quant,use fp8 communication
345
361
  hidden_states = sglang_per_token_group_quant_fp8(
346
362
  hidden_states,
@@ -350,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
350
366
  scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
351
367
  )
352
368
  previous_event = Buffer.capture() if self.async_finish else None
353
- return hidden_states, topk_idx, topk_weights, previous_event
369
+ return hidden_states, topk_ids, topk_weights, previous_event
354
370
 
355
- def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
371
+ def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
356
372
  (
357
373
  hidden_states,
358
- topk_idx,
374
+ topk_ids,
359
375
  topk_weights,
360
376
  num_recv_tokens_per_expert,
361
377
  event,
362
- ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
378
+ ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
363
379
  event.current_stream_wait() if self.async_finish else ()
380
+
381
+ if isinstance(hidden_states, tuple):
382
+ hidden_states, hidden_states_scale = hidden_states
383
+ else:
384
+ hidden_states_scale = None
385
+
364
386
  return DeepEPNormalOutput(
365
- hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
387
+ hidden_states,
388
+ hidden_states_scale,
389
+ topk_ids,
390
+ topk_weights,
391
+ num_recv_tokens_per_expert,
366
392
  )
367
393
 
368
394
  def _dispatch_core(
369
395
  self,
370
396
  x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
371
- topk_idx: torch.Tensor,
397
+ topk_ids: torch.Tensor,
372
398
  topk_weights: torch.Tensor,
373
399
  previous_event,
374
400
  ):
@@ -380,27 +406,26 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
380
406
  is_token_in_rank,
381
407
  previous_event,
382
408
  ) = buffer.get_dispatch_layout(
383
- topk_idx,
409
+ topk_ids,
384
410
  self.num_experts,
385
411
  previous_event=previous_event,
386
412
  async_finish=self.async_finish,
387
413
  allocate_on_comm_stream=previous_event is not None,
388
414
  )
389
-
390
415
  # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
391
416
  # However, doing this would incur an unknown synchronization error, but keeping
392
417
  # `handle` as a member variable works.
393
418
 
394
419
  (
395
420
  recv_x,
396
- recv_topk_idx,
421
+ recv_topk_ids,
397
422
  recv_topk_weights,
398
423
  num_recv_tokens_per_expert,
399
424
  self.handle,
400
425
  event,
401
426
  ) = buffer.dispatch(
402
427
  x,
403
- topk_idx=topk_idx,
428
+ topk_idx=topk_ids,
404
429
  topk_weights=topk_weights,
405
430
  num_tokens_per_rank=num_tokens_per_rank,
406
431
  num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
@@ -412,7 +437,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
412
437
  expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
413
438
  config=DeepEPConfig.get_instance().normal_dispatch_config,
414
439
  )
415
-
416
440
  get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
417
441
  num_recv_tokens_per_expert,
418
442
  num_tokens_per_rank=num_tokens_per_rank,
@@ -422,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
422
446
 
423
447
  return (
424
448
  recv_x,
425
- recv_topk_idx,
449
+ recv_topk_ids,
426
450
  recv_topk_weights,
427
451
  num_recv_tokens_per_expert,
428
452
  event,
@@ -431,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
431
455
  def combine_a(
432
456
  self,
433
457
  hidden_states: torch.Tensor,
434
- topk_idx: torch.Tensor,
458
+ topk_ids: torch.Tensor,
435
459
  topk_weights: torch.Tensor,
436
460
  overlap_args: Optional["CombineOverlapArgs"],
437
461
  ):
438
- from sglang.srt.layers.moe.ep_moe.kernels import (
439
- deepep_post_reorder_triton_kernel,
440
- )
441
462
 
442
463
  if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
443
464
  output = hidden_states
444
465
  else:
445
- if hidden_states.shape[0] > 0:
446
- num_tokens = self.src2dst.shape[0] // self.router_topk
447
- output = torch.empty(
448
- (num_tokens, hidden_states.shape[1]),
449
- device=hidden_states.device,
450
- dtype=hidden_states.dtype,
451
- )
452
- deepep_post_reorder_triton_kernel[(num_tokens,)](
453
- hidden_states,
454
- output,
455
- self.src2dst,
456
- topk_idx,
457
- topk_weights,
458
- self.router_topk,
459
- hidden_states.shape[1],
460
- BLOCK_SIZE=512,
461
- )
462
- else:
463
- output = torch.zeros(
464
- (0, hidden_states.shape[1]),
465
- device=hidden_states.device,
466
- dtype=hidden_states.dtype,
467
- )
466
+ raise NotImplementedError() # triton runner was supported but it's temporarily disabled
467
+
468
468
  previous_event = Buffer.capture() if self.async_finish else None
469
469
  return output, previous_event
470
470
 
@@ -499,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
499
499
  self.num_experts,
500
500
  )
501
501
 
502
+ def set_quant_config(self, quant_config: dict):
503
+ self.quant_config = quant_config
504
+
502
505
 
503
506
  class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
504
507
  def __init__(self, return_recv_hook: bool, **kwargs):
@@ -510,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
510
513
  """
511
514
  self.return_recv_hook = return_recv_hook
512
515
  self.device_module = torch.get_device_module()
516
+ self.quant_config = {}
513
517
 
514
518
  def dispatch_a(
515
519
  self,
516
520
  hidden_states: torch.Tensor,
517
- input_global_scale: Optional[torch.Tensor],
518
- topk_idx: torch.Tensor,
519
- topk_weights: torch.Tensor,
521
+ topk_output: TopKOutput,
520
522
  ):
521
523
  buffer = self._get_buffer()
522
- topk_idx = topk_idx.to(torch.int64)
524
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
525
+ topk_ids = topk_ids.to(torch.int64)
523
526
  expected_m = (
524
- hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
527
+ hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
525
528
  + self.num_experts
526
529
  ) // self.num_experts
527
530
  hidden_states, masked_m, event, hook = self._dispatch_core(
528
531
  hidden_states,
529
- input_global_scale,
530
- topk_idx,
532
+ topk_ids,
531
533
  )
532
534
  return (
533
535
  hidden_states,
534
- topk_idx,
536
+ topk_ids,
535
537
  topk_weights,
536
538
  masked_m,
537
539
  expected_m,
@@ -542,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
542
544
  def dispatch_b(
543
545
  self,
544
546
  hidden_states,
545
- topk_idx,
547
+ topk_ids,
546
548
  topk_weights,
547
549
  masked_m,
548
550
  expected_m,
@@ -555,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
555
557
  masked_m
556
558
  )
557
559
 
560
+ if isinstance(hidden_states, tuple):
561
+ hidden_states, hidden_states_scale = hidden_states
562
+ else:
563
+ hidden_states_scale = None
564
+
558
565
  deepep_output = DeepEPLLOutput(
559
566
  hidden_states,
560
- topk_idx,
567
+ hidden_states_scale,
568
+ topk_ids,
561
569
  topk_weights,
562
570
  masked_m,
563
571
  expected_m,
@@ -567,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
567
575
  def _dispatch_core(
568
576
  self,
569
577
  hidden_states: torch.Tensor,
570
- input_global_scale: Optional[torch.Tensor],
571
- topk_idx: torch.Tensor,
578
+ topk_ids: torch.Tensor,
572
579
  ):
573
580
  use_nvfp4 = use_fp8 = False
581
+ input_global_scale = self.quant_config.get("input_global_scale", None)
574
582
  if input_global_scale is not None:
575
583
  use_nvfp4 = True
576
584
  elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
@@ -580,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
580
588
  packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
581
589
  buffer.low_latency_dispatch(
582
590
  hidden_states,
583
- topk_idx,
591
+ topk_ids,
584
592
  self.num_max_dispatch_tokens_per_rank,
585
593
  self.num_experts,
586
594
  use_fp8=use_fp8,
@@ -603,19 +611,22 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
603
611
  def combine_a(
604
612
  self,
605
613
  hidden_states: torch.Tensor,
606
- topk_idx: torch.Tensor,
614
+ topk_ids: torch.Tensor,
607
615
  topk_weights: torch.Tensor,
608
616
  overlap_args: Optional["CombineOverlapArgs"],
609
617
  ):
610
618
  hidden_states, event, hook = self._combine_core(
611
619
  hidden_states,
612
- topk_idx,
620
+ topk_ids,
613
621
  topk_weights,
614
622
  overlap_args=overlap_args,
615
623
  )
616
624
  return hidden_states, event, hook, overlap_args
617
625
 
618
626
  def combine_b(self, hidden_states, event, hook, overlap_args):
627
+ if overlap_args is not None:
628
+ overlap_args.stream.wait_stream(self.device_module.current_stream())
629
+
619
630
  hook() if self.return_recv_hook else event.current_stream_wait()
620
631
 
621
632
  if overlap_args is not None:
@@ -626,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
626
637
  def _combine_core(
627
638
  self,
628
639
  hidden_states: torch.Tensor,
629
- topk_idx: torch.Tensor,
640
+ topk_ids: torch.Tensor,
630
641
  topk_weights: torch.Tensor,
631
642
  overlap_args: Optional["CombineOverlapArgs"],
632
643
  ):
@@ -640,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
640
651
  with ctx:
641
652
  combined_hidden_states, event, hook = buffer.low_latency_combine(
642
653
  x=hidden_states,
643
- topk_idx=topk_idx,
654
+ topk_idx=topk_ids,
644
655
  topk_weights=topk_weights,
645
656
  handle=self.handle,
646
657
  async_finish=not self.return_recv_hook,
@@ -670,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
670
681
  self.num_experts,
671
682
  )
672
683
 
684
+ def set_quant_config(self, quant_config: dict):
685
+ self.quant_config = quant_config
686
+
673
687
 
674
688
  @dataclass
675
689
  class _Stage(Enum):
@@ -727,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
727
741
  def dispatch_a(
728
742
  self,
729
743
  hidden_states: torch.Tensor,
730
- input_global_scale: Optional[torch.Tensor],
731
- topk_idx: torch.Tensor,
732
- topk_weights: torch.Tensor,
733
- forward_batch: ForwardBatch,
744
+ topk_output: TopKOutput,
734
745
  ):
735
746
  self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
736
- inner_state = self._get_impl(forward_batch).dispatch_a(
747
+ inner_state = self._get_impl().dispatch_a(
737
748
  hidden_states=hidden_states,
738
- input_global_scale=input_global_scale,
739
- topk_idx=topk_idx,
740
- topk_weights=topk_weights,
749
+ topk_output=topk_output,
741
750
  )
742
- self._dispatch_intermediate_state = forward_batch, inner_state
751
+ self._dispatch_intermediate_state = inner_state
743
752
 
744
753
  def dispatch_b(self):
745
754
  self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
746
- forward_batch, inner_state = self._dispatch_intermediate_state
755
+ inner_state = self._dispatch_intermediate_state
747
756
  del self._dispatch_intermediate_state
748
- return self._get_impl(forward_batch).dispatch_b(*inner_state)
757
+ return self._get_impl().dispatch_b(*inner_state)
749
758
 
750
759
  def combine(self, *args, **kwargs) -> Tuple:
751
760
  self.combine_a(*args, **kwargs)
@@ -755,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
755
764
  def combine_a(
756
765
  self,
757
766
  hidden_states: torch.Tensor,
758
- topk_idx: torch.Tensor,
767
+ topk_ids: torch.Tensor,
759
768
  topk_weights: torch.Tensor,
760
- forward_batch: ForwardBatch,
761
769
  overlap_args: Optional["CombineOverlapArgs"] = None,
762
770
  ):
763
771
  self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
764
- inner_state = self._get_impl(forward_batch).combine_a(
772
+ inner_state = self._get_impl().combine_a(
765
773
  hidden_states=hidden_states,
766
- topk_idx=topk_idx,
774
+ topk_ids=topk_ids,
767
775
  topk_weights=topk_weights,
768
776
  overlap_args=overlap_args,
769
777
  )
770
- self._combine_intermediate_state = forward_batch, inner_state
778
+ self._combine_intermediate_state = inner_state
771
779
 
772
780
  def combine_b(self):
773
781
  self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
774
- forward_batch, inner_state = self._combine_intermediate_state
782
+ inner_state = self._combine_intermediate_state
775
783
  del self._combine_intermediate_state
776
- return self._get_impl(forward_batch).combine_b(*inner_state)
784
+ return self._get_impl().combine_b(*inner_state)
777
785
 
778
- def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
779
- resolved_deepep_mode = self.deepep_mode.resolve(
780
- forward_batch.is_extend_in_batch
781
- )
786
+ def _get_impl(self) -> _DeepEPDispatcherImplBase:
787
+ is_extend_in_batch = get_is_extend_in_batch()
788
+ resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
782
789
  if resolved_deepep_mode == DeepEPMode.NORMAL:
783
790
  return self._normal_dispatcher
784
791
  elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
@@ -789,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
789
796
  def _update_stage(self, old_stage, new_stage):
790
797
  assert self._stage == old_stage
791
798
  self._stage = new_stage
799
+
800
+ def set_quant_config(self, quant_config: dict):
801
+ if self.deepep_mode.enable_low_latency():
802
+ self._low_latency_dispatcher.set_quant_config(quant_config)
803
+ if self.deepep_mode.enable_normal():
804
+ self._normal_dispatcher.set_quant_config(quant_config)