sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 3
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 128,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 2
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 2
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 32,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 32,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 64,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 64,
143
+ "num_warps": 4,
144
+ "num_stages": 4
145
+ }
146
+ }
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 128,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 3
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 3
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 64,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 64,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 16,
55
+ "num_warps": 4,
56
+ "num_stages": 2
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 64,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 64,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 64,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 128,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 8,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 8,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 32,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 8,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 128,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 8,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 32,
143
+ "num_warps": 4,
144
+ "num_stages": 2
145
+ }
146
+ }
@@ -51,10 +51,14 @@ def get_moe_configs(
51
51
 
52
52
  # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
53
53
  # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
54
+ config_dir = os.environ.get(
55
+ "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
56
+ )
57
+
54
58
  triton_version = triton.__version__
55
59
  version_dir = f"triton_{triton_version.replace('.', '_')}"
56
60
  config_file_path = os.path.join(
57
- os.path.dirname(os.path.realpath(__file__)),
61
+ config_dir,
58
62
  "configs",
59
63
  version_dir,
60
64
  json_file_name,
@@ -75,7 +79,7 @@ def get_moe_configs(
75
79
  if try_triton_version == triton_version:
76
80
  continue
77
81
  try_config_file_path = os.path.join(
78
- os.path.dirname(os.path.realpath(__file__)),
82
+ config_dir,
79
83
  "configs",
80
84
  f"triton_{try_triton_version.replace('.', '_')}",
81
85
  json_file_name,
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
11
11
  get_moe_expert_parallel_world_size,
12
12
  get_moe_tensor_parallel_rank,
13
13
  get_moe_tensor_parallel_world_size,
14
- get_tp_group,
15
14
  tensor_model_parallel_all_reduce,
16
15
  )
17
- from sglang.srt.distributed.device_communicators.pynccl_allocator import (
18
- use_symmetric_memory,
19
- )
20
16
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
21
17
  from sglang.srt.layers.moe import (
22
18
  MoeRunnerConfig,
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
24
20
  should_use_flashinfer_trtllm_moe,
25
21
  )
26
22
  from sglang.srt.layers.moe.token_dispatcher.standard import (
27
- CombineInput,
28
23
  StandardDispatcher,
29
24
  StandardDispatchOutput,
30
25
  )
@@ -239,6 +234,13 @@ class FusedMoE(torch.nn.Module):
239
234
  self.quant_method.create_moe_runner(self, self.moe_runner_config)
240
235
  self.dispatcher = StandardDispatcher()
241
236
 
237
+ self.should_fuse_routed_scaling_factor_in_topk = isinstance(
238
+ self.quant_method, ModelOptNvFp4FusedMoEMethod
239
+ ) or (
240
+ isinstance(self.quant_method, Fp8MoEMethod)
241
+ and self.quant_method.use_cutlass_fused_experts_fp8
242
+ )
243
+
242
244
  def _load_per_tensor_weight_scale(
243
245
  self,
244
246
  shard_id: str,
@@ -575,7 +577,10 @@ class FusedMoE(torch.nn.Module):
575
577
  )
576
578
 
577
579
  # Flashinfer assumes w31 format for w13_weight. Same for the scales.
578
- if should_use_flashinfer_trtllm_moe():
580
+ if should_use_flashinfer_trtllm_moe() and (
581
+ isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
582
+ or isinstance(self.quant_method, Fp8MoEMethod)
583
+ ):
579
584
  shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
580
585
 
581
586
  WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -938,12 +943,6 @@ class FusedMoE(torch.nn.Module):
938
943
  for shard_id in ["w1", "w2", "w3"]
939
944
  ]
940
945
 
941
- def should_fuse_routed_scaling_factor_in_topk(self):
942
- return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
943
- isinstance(self.quant_method, Fp8MoEMethod)
944
- and self.quant_method.use_cutlass_fused_experts_fp8
945
- )
946
-
947
946
 
948
947
  class FlashInferFusedMoE(FusedMoE):
949
948
  def __init__(self, *args, **kwargs):
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from contextlib import nullcontext
4
5
  from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
6
+ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
6
7
 
7
8
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
8
9
  from sglang.srt.layers.moe.token_dispatcher.base import (
@@ -25,6 +26,9 @@ from sglang.srt.utils import (
25
26
 
26
27
  _is_npu = is_npu()
27
28
 
29
+ if TYPE_CHECKING:
30
+ from sglang.srt.single_batch_overlap import CombineOverlapArgs
31
+
28
32
  try:
29
33
  from deep_ep import Buffer, Config
30
34
 
@@ -164,10 +168,19 @@ class DeepEPBuffer:
164
168
  num_rdma_bytes,
165
169
  )
166
170
 
171
+ # We should calculate num_qps_per_rank consistently with DeepEP's test script logic:
167
172
  if deepep_mode == DeepEPMode.NORMAL:
168
- num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
169
- elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
173
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
174
+ num_qps_per_rank = DeepEPConfig.get_instance().num_sms
175
+ elif deepep_mode == DeepEPMode.LOW_LATENCY:
176
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
170
177
  num_qps_per_rank = num_experts // group.size()
178
+ elif deepep_mode == DeepEPMode.AUTO:
179
+ # low-latency and normal mode all need run
180
+ # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
181
+ num_qps_per_rank = max(
182
+ DeepEPConfig.get_instance().num_sms, num_experts // group.size()
183
+ )
171
184
  else:
172
185
  raise NotImplementedError
173
186
 
@@ -287,6 +300,7 @@ class _DeepEPDispatcherImplBase:
287
300
  def dispatch_a(
288
301
  self,
289
302
  hidden_states: torch.Tensor,
303
+ input_global_scale: Optional[torch.Tensor],
290
304
  topk_idx: torch.Tensor,
291
305
  topk_weights: torch.Tensor,
292
306
  ):
@@ -300,6 +314,7 @@ class _DeepEPDispatcherImplBase:
300
314
  hidden_states: torch.Tensor,
301
315
  topk_idx: torch.Tensor,
302
316
  topk_weights: torch.Tensor,
317
+ overlap_args: Optional["CombineOverlapArgs"],
303
318
  ):
304
319
  raise NotImplementedError
305
320
 
@@ -320,6 +335,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
320
335
  def dispatch_a(
321
336
  self,
322
337
  hidden_states: torch.Tensor,
338
+ input_global_scale: Optional[torch.Tensor],
323
339
  topk_idx: torch.Tensor,
324
340
  topk_weights: torch.Tensor,
325
341
  ):
@@ -417,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
417
433
  hidden_states: torch.Tensor,
418
434
  topk_idx: torch.Tensor,
419
435
  topk_weights: torch.Tensor,
436
+ overlap_args: Optional["CombineOverlapArgs"],
420
437
  ):
421
438
  from sglang.srt.layers.moe.ep_moe.kernels import (
422
439
  deepep_post_reorder_triton_kernel,
@@ -492,10 +509,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
492
509
  https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
493
510
  """
494
511
  self.return_recv_hook = return_recv_hook
512
+ self.device_module = torch.get_device_module()
495
513
 
496
514
  def dispatch_a(
497
515
  self,
498
516
  hidden_states: torch.Tensor,
517
+ input_global_scale: Optional[torch.Tensor],
499
518
  topk_idx: torch.Tensor,
500
519
  topk_weights: torch.Tensor,
501
520
  ):
@@ -507,9 +526,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
507
526
  ) // self.num_experts
508
527
  hidden_states, masked_m, event, hook = self._dispatch_core(
509
528
  hidden_states,
529
+ input_global_scale,
510
530
  topk_idx,
511
- # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
512
- use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
513
531
  )
514
532
  return (
515
533
  hidden_states,
@@ -549,17 +567,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
549
567
  def _dispatch_core(
550
568
  self,
551
569
  hidden_states: torch.Tensor,
570
+ input_global_scale: Optional[torch.Tensor],
552
571
  topk_idx: torch.Tensor,
553
- use_fp8: bool = False,
554
572
  ):
573
+ use_nvfp4 = use_fp8 = False
574
+ if input_global_scale is not None:
575
+ use_nvfp4 = True
576
+ elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
577
+ use_fp8 = True
578
+
555
579
  buffer = self._get_buffer()
556
- packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
580
+ packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
557
581
  buffer.low_latency_dispatch(
558
582
  hidden_states,
559
583
  topk_idx,
560
584
  self.num_max_dispatch_tokens_per_rank,
561
585
  self.num_experts,
562
586
  use_fp8=use_fp8,
587
+ **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
588
+ **(
589
+ dict(x_global_scale=input_global_scale)
590
+ if input_global_scale is not None
591
+ else dict()
592
+ ),
563
593
  async_finish=not self.return_recv_hook,
564
594
  return_recv_hook=self.return_recv_hook,
565
595
  round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -568,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
568
598
  and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
569
599
  )
570
600
  )
571
- return packed_recv_hidden, packed_recv_count, event, hook
601
+ return packed_recv_hidden, self.packed_recv_count, event, hook
572
602
 
573
603
  def combine_a(
574
604
  self,
575
605
  hidden_states: torch.Tensor,
576
606
  topk_idx: torch.Tensor,
577
607
  topk_weights: torch.Tensor,
608
+ overlap_args: Optional["CombineOverlapArgs"],
578
609
  ):
579
610
  hidden_states, event, hook = self._combine_core(
580
611
  hidden_states,
581
612
  topk_idx,
582
613
  topk_weights,
614
+ overlap_args=overlap_args,
583
615
  )
584
- return hidden_states, event, hook
616
+ return hidden_states, event, hook, overlap_args
585
617
 
586
- def combine_b(self, hidden_states, event, hook):
618
+ def combine_b(self, hidden_states, event, hook, overlap_args):
587
619
  hook() if self.return_recv_hook else event.current_stream_wait()
620
+
621
+ if overlap_args is not None:
622
+ self.device_module.current_stream().wait_stream(overlap_args.stream)
623
+
588
624
  return hidden_states
589
625
 
590
626
  def _combine_core(
@@ -592,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
592
628
  hidden_states: torch.Tensor,
593
629
  topk_idx: torch.Tensor,
594
630
  topk_weights: torch.Tensor,
631
+ overlap_args: Optional["CombineOverlapArgs"],
595
632
  ):
596
633
  buffer = self._get_buffer()
597
- combined_hidden_states, event, hook = buffer.low_latency_combine(
598
- hidden_states,
599
- topk_idx,
600
- topk_weights,
601
- self.handle,
602
- async_finish=not self.return_recv_hook,
603
- return_recv_hook=self.return_recv_hook,
604
- )
605
- self.handle = None
634
+
635
+ ctx = nullcontext()
636
+ if overlap_args is not None:
637
+ overlap_args.stream.wait_event(overlap_args.wait_event)
638
+ ctx = torch.cuda.stream(overlap_args.stream)
639
+
640
+ with ctx:
641
+ combined_hidden_states, event, hook = buffer.low_latency_combine(
642
+ x=hidden_states,
643
+ topk_idx=topk_idx,
644
+ topk_weights=topk_weights,
645
+ handle=self.handle,
646
+ async_finish=not self.return_recv_hook,
647
+ return_recv_hook=self.return_recv_hook,
648
+ **(
649
+ dict(
650
+ overlap=overlap_args.overlap,
651
+ src_signals=overlap_args.signal,
652
+ src_signal_expect_value=overlap_args.threshold,
653
+ )
654
+ if overlap_args is not None
655
+ else {}
656
+ ),
657
+ )
658
+
659
+ self.packed_recv_count = self.handle = None
606
660
  return combined_hidden_states, event, hook
607
661
 
608
662
  def _get_buffer(self):
@@ -673,6 +727,7 @@ class DeepEPDispatcher(BaseDispatcher):
673
727
  def dispatch_a(
674
728
  self,
675
729
  hidden_states: torch.Tensor,
730
+ input_global_scale: Optional[torch.Tensor],
676
731
  topk_idx: torch.Tensor,
677
732
  topk_weights: torch.Tensor,
678
733
  forward_batch: ForwardBatch,
@@ -680,6 +735,7 @@ class DeepEPDispatcher(BaseDispatcher):
680
735
  self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
681
736
  inner_state = self._get_impl(forward_batch).dispatch_a(
682
737
  hidden_states=hidden_states,
738
+ input_global_scale=input_global_scale,
683
739
  topk_idx=topk_idx,
684
740
  topk_weights=topk_weights,
685
741
  )
@@ -702,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
702
758
  topk_idx: torch.Tensor,
703
759
  topk_weights: torch.Tensor,
704
760
  forward_batch: ForwardBatch,
761
+ overlap_args: Optional["CombineOverlapArgs"] = None,
705
762
  ):
706
763
  self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
707
764
  inner_state = self._get_impl(forward_batch).combine_a(
708
765
  hidden_states=hidden_states,
709
766
  topk_idx=topk_idx,
710
767
  topk_weights=topk_weights,
768
+ overlap_args=overlap_args,
711
769
  )
712
770
  self._combine_intermediate_state = forward_batch, inner_state
713
771
 
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
108
108
  MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
109
109
  DEEPEP_MODE: Optional[DeepEPMode] = None
110
110
  IS_TBO_ENABLED: Optional[bool] = None
111
+ IS_SBO_ENABLED: Optional[bool] = None
111
112
  TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
112
113
  DEEPEP_CONFIG: Optional[str] = None
113
114
  DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
119
120
  global DEEPEP_MODE
120
121
  global DEEPEP_CONFIG
121
122
  global IS_TBO_ENABLED
123
+ global IS_SBO_ENABLED
122
124
  global TBO_TOKEN_DISTRIBUTION_THRESHOLD
123
125
  global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
124
126
 
@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
127
129
  DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
128
130
  DEEPEP_CONFIG = server_args.deepep_config or ""
129
131
  IS_TBO_ENABLED = server_args.enable_two_batch_overlap
132
+ IS_SBO_ENABLED = server_args.enable_single_batch_overlap
130
133
  TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
131
134
  DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
132
135
  server_args.disable_flashinfer_cutlass_moe_fp4_allgather
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
172
175
  return IS_TBO_ENABLED
173
176
 
174
177
 
178
+ def is_sbo_enabled() -> bool:
179
+ global IS_SBO_ENABLED
180
+ if IS_SBO_ENABLED is None:
181
+ IS_SBO_ENABLED = False
182
+ return IS_SBO_ENABLED
183
+
184
+
175
185
  def get_tbo_token_distribution_threshold() -> float:
176
186
  global TBO_TOKEN_DISTRIBUTION_THRESHOLD
177
187
  if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None: