sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
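
For reference: these tuned fused-MoE kernel config files map a token count (the JSON keys) to Triton tile parameters for the fused MoE kernel on the named GPU. Below is a minimal sketch of how such a file could be consumed, assuming the keys are token counts and that the nearest entry is a reasonable fallback; load_moe_kernel_config is a hypothetical helper for illustration, not sglang's loader (the real lookup lives in get_moe_configs, patched further below).

import json

def load_moe_kernel_config(path: str, num_tokens: int) -> dict:
    # Keys are token counts stored as strings; values are kernel launch parameters.
    with open(path) as f:
        configs = json.load(f)
    # Pick the entry whose token count is closest to the requested batch size.
    best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
    return configs[best_key]

# Illustrative use: for 300 tokens the "256" entry of a config like the one above
# would be selected.
# cfg = load_moe_kernel_config("E=512,N=256,device_name=NVIDIA_B200.json", 300)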
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
CHANGED
@@ -51,10 +51,14 @@ def get_moe_configs(

     # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
     # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
+    config_dir = os.environ.get(
+        "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
+    )
+
     triton_version = triton.__version__
     version_dir = f"triton_{triton_version.replace('.', '_')}"
     config_file_path = os.path.join(
-
+        config_dir,
         "configs",
         version_dir,
         json_file_name,
@@ -75,7 +79,7 @@ def get_moe_configs(
         if try_triton_version == triton_version:
             continue
         try_config_file_path = os.path.join(
-
+            config_dir,
            "configs",
            f"triton_{try_triton_version.replace('.', '_')}",
            json_file_name,
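
The hunk above adds an SGLANG_MOE_CONFIG_DIR environment variable that overrides where get_moe_configs looks for these JSON files. A small sketch of the resulting lookup, with illustrative paths; the export line is an assumed deployment step, not documented sglang usage.

import os

# Default to the package's own directory; allow redirecting to locally tuned configs.
config_dir = os.environ.get(
    "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
)
config_file_path = os.path.join(
    config_dir,
    "configs",
    "triton_3_4_0",  # derived from triton.__version__ in the real code
    "E=512,N=256,device_name=NVIDIA_B200.json",
)

# Example override before launching the server (assumed usage):
#   export SGLANG_MOE_CONFIG_DIR=/opt/tuned_moe_configs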
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-    use_symmetric_memory,
-)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
-    CombineInput,
     StandardDispatcher,
     StandardDispatchOutput,
 )
@@ -239,6 +234,13 @@ class FusedMoE(torch.nn.Module):
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
         self.dispatcher = StandardDispatcher()

+        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
+            self.quant_method, ModelOptNvFp4FusedMoEMethod
+        ) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -575,7 +577,10 @@ class FusedMoE(torch.nn.Module):
         )

         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if should_use_flashinfer_trtllm_moe()
+        if should_use_flashinfer_trtllm_moe() and (
+            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+            or isinstance(self.quant_method, Fp8MoEMethod)
+        ):
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -938,12 +943,6 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]

-    def should_fuse_routed_scaling_factor_in_topk(self):
-        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
-            isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.use_cutlass_fused_experts_fp8
-        )
-

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
sglang/srt/layers/moe/token_dispatcher/deepep.py
CHANGED
@@ -1,8 +1,9 @@
 from __future__ import annotations

 import logging
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.moe.token_dispatcher.base import (
@@ -25,6 +26,9 @@ from sglang.srt.utils import (

 _is_npu = is_npu()

+if TYPE_CHECKING:
+    from sglang.srt.single_batch_overlap import CombineOverlapArgs
+
 try:
     from deep_ep import Buffer, Config

@@ -164,10 +168,19 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )

+        # We should calculate num_qps_per_rank consistently with DeepEP's test script logic:
         if deepep_mode == DeepEPMode.NORMAL:
-
-
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
+            num_qps_per_rank = DeepEPConfig.get_instance().num_sms
+        elif deepep_mode == DeepEPMode.LOW_LATENCY:
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
             num_qps_per_rank = num_experts // group.size()
+        elif deepep_mode == DeepEPMode.AUTO:
+            # low-latency and normal mode all need run
+            # refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
+            num_qps_per_rank = max(
+                DeepEPConfig.get_instance().num_sms, num_experts // group.size()
+            )
         else:
             raise NotImplementedError

@@ -287,6 +300,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -300,6 +314,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         raise NotImplementedError

@@ -320,6 +335,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -417,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             deepep_post_reorder_triton_kernel,
@@ -492,10 +509,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
         self.return_recv_hook = return_recv_hook
+        self.device_module = torch.get_device_module()

     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
@@ -507,9 +526,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
+            input_global_scale,
             topk_idx,
-            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
-            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
@@ -549,17 +567,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
-        use_fp8: bool = False,
     ):
+        use_nvfp4 = use_fp8 = False
+        if input_global_scale is not None:
+            use_nvfp4 = True
+        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
+            use_fp8 = True
+
         buffer = self._get_buffer()
-        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
+                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                **(
+                    dict(x_global_scale=input_global_scale)
+                    if input_global_scale is not None
+                    else dict()
+                ),
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -568,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
-        return packed_recv_hidden, packed_recv_count, event, hook
+        return packed_recv_hidden, self.packed_recv_count, event, hook

     def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_idx,
             topk_weights,
+            overlap_args=overlap_args,
         )
-        return hidden_states, event, hook
+        return hidden_states, event, hook, overlap_args

-    def combine_b(self, hidden_states, event, hook):
+    def combine_b(self, hidden_states, event, hook, overlap_args):
         hook() if self.return_recv_hook else event.current_stream_wait()
+
+        if overlap_args is not None:
+            self.device_module.current_stream().wait_stream(overlap_args.stream)
+
         return hidden_states

     def _combine_core(
@@ -592,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         buffer = self._get_buffer()
-
-
-
-
-
-
-
-
-
+
+        ctx = nullcontext()
+        if overlap_args is not None:
+            overlap_args.stream.wait_event(overlap_args.wait_event)
+            ctx = torch.cuda.stream(overlap_args.stream)
+
+        with ctx:
+            combined_hidden_states, event, hook = buffer.low_latency_combine(
+                x=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                handle=self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+                **(
+                    dict(
+                        overlap=overlap_args.overlap,
+                        src_signals=overlap_args.signal,
+                        src_signal_expect_value=overlap_args.threshold,
+                    )
+                    if overlap_args is not None
+                    else {}
+                ),
+            )
+
+        self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook

     def _get_buffer(self):
@@ -673,6 +727,7 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
@@ -680,6 +735,7 @@ class DeepEPDispatcher(BaseDispatcher):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
+            input_global_scale=input_global_scale,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
@@ -702,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
+            overlap_args=overlap_args,
         )
         self._combine_intermediate_state = forward_batch, inner_state

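The deepep.py changes above thread a CombineOverlapArgs object through combine_a/_combine_core so that low_latency_combine can run on a side CUDA stream and overlap with other work. Below is a minimal sketch of that stream/event pattern, assuming CUDA and the field names (stream, wait_event) shown in the diff; run_combine_on_side_stream is a hypothetical wrapper for illustration, not sglang code.

from contextlib import nullcontext

import torch

def run_combine_on_side_stream(overlap_args, combine_fn):
    ctx = nullcontext()
    if overlap_args is not None:
        # The side stream waits until the producer has recorded wait_event ...
        overlap_args.stream.wait_event(overlap_args.wait_event)
        # ... and the combine work is then launched on that stream.
        ctx = torch.cuda.stream(overlap_args.stream)
    with ctx:
        out = combine_fn()
    if overlap_args is not None:
        # The caller's stream later waits for the side stream (done in combine_b above).
        torch.cuda.current_stream().wait_stream(overlap_args.stream)
    return out
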
sglang/srt/layers/moe/utils.py
CHANGED
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
 DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
+IS_SBO_ENABLED: Optional[bool] = None
 TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
 DEEPEP_CONFIG: Optional[str] = None
 DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global DEEPEP_MODE
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
+    global IS_SBO_ENABLED
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER

@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
     DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    IS_SBO_ENABLED = server_args.enable_single_batch_overlap
     TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
     DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
         server_args.disable_flashinfer_cutlass_moe_fp4_allgather
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
     return IS_TBO_ENABLED


+def is_sbo_enabled() -> bool:
+    global IS_SBO_ENABLED
+    if IS_SBO_ENABLED is None:
+        IS_SBO_ENABLED = False
+    return IS_SBO_ENABLED
+
+
 def get_tbo_token_distribution_threshold() -> float:
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None: