sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to the supported registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/mooncake.py
ADDED

@@ -0,0 +1,386 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import NamedTuple, Optional, Tuple
+
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.utils import get_int_env_var
+
+try:
+    from mooncake.mooncake_ep_buffer import Buffer
+
+    use_mooncake_ep = True
+except ImportError:
+    use_mooncake_ep = False
+
+from enum import Enum, auto
+
+import torch
+import torch.distributed as dist
+
+logger = logging.getLogger(__name__)
+
+
+class MooncakeDispatchOutput(NamedTuple):
+    """Mooncake EP dispatch output."""
+
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeDispatchOutput, DispatchOutput)
+
+
+class MooncakeCombineInput(NamedTuple):
+    """Mooncake EP combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeCombineInput, CombineInput)
+
+
+class EPBuffer:
+    _buffer = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_ep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = -1,
+        num_experts: int = -1,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_ep_buffer_bytes = 0
+        if deepep_mode.enable_normal():
+            raise NotImplementedError(
+                "Normal mode is not supported for Mooncake EP yet."
+            )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank != -1
+            assert num_experts != -1 and num_experts % group.size() == 0
+            num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
+                num_max_dispatch_tokens_per_rank,
+                hidden_size,
+                group.size(),
+                num_experts,
+            )
+
+        cls._buffer = Buffer(group, num_ep_buffer_bytes)
+        return cls._buffer
+
+
+class _MooncakeEPDispatcherImpl:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+        return_recv_hook: bool,
+        deepep_mode: DeepEPMode,
+    ):
+        if not use_mooncake_ep:
+            raise ImportError(
+                "Mooncake EP is not installed. Please install Mooncake package at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
+                "with EP support to run SGLang with Mooncake EP."
+            )
+        self.group = group
+        self.router_topk = router_topk
+        self.permute_fusion = permute_fusion
+        self.num_experts = num_experts
+        self.num_local_experts = num_local_experts
+        self.hidden_size = hidden_size
+        self.params_dtype = params_dtype
+        self.return_recv_hook = return_recv_hook
+        self.deepep_mode = deepep_mode
+
+        self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
+            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+        )
+        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024
+        # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
+
+        self.first_execution = True
+        self.timeout_us = 10000000
+
+        self.active_ranks = ElasticEPStateManager.instance().active_ranks
+
+        self.handle = None
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
+        buffer = self._get_buffer()
+        topk_ids = topk_ids.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
+            + self.num_experts
+        ) // self.num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_core(
+            hidden_states,
+            topk_ids,
+            use_fp8=True,
+        )
+        return (
+            hidden_states,
+            topk_ids,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_ids,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
+        return MooncakeDispatchOutput(
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8: bool = False,
+    ):
+        buffer = self._get_buffer()
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+            buffer.dispatch(
+                hidden_states,
+                topk_ids,
+                self.active_ranks,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
+                -1 if self.first_execution else self.timeout_us,
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+            )
+        )
+        return packed_recv_hidden, packed_recv_count, event, hook
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        hidden_states, event, hook = self._combine_core(
+            hidden_states,
+            topk_ids,
+            topk_weights,
+        )
+        return hidden_states, event, hook
+
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states
+
+    def _combine_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.combine(
+            hidden_states,
+            topk_ids,
+            topk_weights,
+            self.active_ranks,
+            -1 if self.first_execution else self.timeout_us,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
+        )
+        self.first_execution = False
+        self.handle = None
+        return combined_hidden_states, event, hook
+
+    def _get_buffer(self):
+        return EPBuffer.get_ep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
+
+@dataclass
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
+class MooncakeEPDispatcher(BaseDispatcher):
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
+                group=group,
+                router_topk=router_topk,
+                permute_fusion=permute_fusion,
+                num_experts=num_experts,
+                num_local_experts=num_local_experts,
+                hidden_size=hidden_size,
+                params_dtype=params_dtype,
+                return_recv_hook=return_recv_hook,
+                deepep_mode=deepep_mode,
+            )
+        if self.deepep_mode.enable_normal():
+            raise NotImplementedError
+
+        self._stage = _Stage.INITIAL
+
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
+        self.dispatch_a(*args, **kwargs)
+        ret = self.dispatch_b()
+        return ret
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
+        inner_state = self._get_impl().dispatch_a(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+        )
+        self._dispatch_intermediate_state = inner_state
+
+    def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
+        inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl().dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        ret = self.combine_b()
+        return ret
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        overlap_args: Optional = None,
+    ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
+        inner_state = self._get_impl().combine_a(
+            hidden_states=hidden_states,
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+        )
+        self._combine_intermediate_state = inner_state
+
+    def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
+        inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl().combine_b(*inner_state)
+
+    def _get_impl(self) -> _MooncakeEPDispatcherImpl:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
+            raise NotImplementedError
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        pass
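The new dispatcher deliberately splits both dispatch and combine into `_a`/`_b` halves so a caller can kick off communication, do other work, and only then block on the result; the `_Stage` machine asserts the four calls happen in order. Here is a minimal sketch of that staged pattern in isolation (names and payloads are illustrative, not sglang's API):

```python
from enum import Enum, auto


class Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()


class StagedDispatcher:
    """Illustrative only: enforces dispatch_a -> dispatch_b -> combine_a -> combine_b."""

    def __init__(self):
        self._stage = Stage.INITIAL
        self._pending = None

    def _advance(self, old, new):
        assert self._stage == old, f"expected {old}, was {self._stage}"
        self._stage = new

    def dispatch_a(self, payload):
        # Phase a: start the (hypothetical) async transfer and stash its state.
        self._advance(Stage.INITIAL, Stage.AFTER_DISPATCH_A)
        self._pending = payload

    def dispatch_b(self):
        # Phase b: wait for completion and hand back the received data.
        self._advance(Stage.AFTER_DISPATCH_A, Stage.AFTER_DISPATCH_B)
        payload, self._pending = self._pending, None
        return payload

    def combine_a(self, payload):
        self._advance(Stage.AFTER_DISPATCH_B, Stage.AFTER_COMBINE_A)
        self._pending = payload

    def combine_b(self):
        self._advance(Stage.AFTER_COMBINE_A, Stage.INITIAL)
        payload, self._pending = self._pending, None
        return payload


d = StagedDispatcher()
d.dispatch_a("tokens")   # expert computation could overlap here, between a and b
recv = d.dispatch_b()
d.combine_a(recv)
assert d.combine_b() == "tokens" and d._stage is Stage.INITIAL
```

In the real class, the stashed state is the tuple returned by `_MooncakeEPDispatcherImpl.dispatch_a` (including the Mooncake `event` and `hook`), and phase b either calls `hook()` or waits on the stream depending on `return_recv_hook`.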
sglang/srt/layers/moe/token_dispatcher/standard.py
CHANGED

@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
 
 import torch
 
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+)
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
 
 class StandardDispatcher(BaseDispatcher):
 
+    def __init__(self, moe_runner_config: MoeRunnerConfig):
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.enable_flashinfer_cutlass_moe = (
+            get_moe_runner_backend().is_flashinfer_cutlass()
+        )
+        self.num_experts = moe_runner_config.num_experts
+        self.num_local_experts = moe_runner_config.num_local_experts
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.local_expert_mapping = None
+
     def dispatch(
         self, hidden_states: torch.Tensor, topk_output: TopKOutput
     ) -> DispatchOutput:
+
+        if (
+            self.moe_ep_size > 1
+            and not self.enable_flashinfer_cutlass_moe
+            and TopKOutputChecker.format_is_standard(topk_output)
+        ):
+            if self.local_expert_mapping is None:
+                self.local_expert_mapping = torch.full(
+                    (self.num_experts,), -1, dtype=torch.int32, device="cuda"
+                )
+                self.local_expert_mapping[
+                    self.moe_ep_rank
+                    * self.num_local_experts : (self.moe_ep_rank + 1)
+                    * self.num_local_experts
+                ] = torch.arange(
+                    0, self.num_local_experts, dtype=torch.int32, device="cuda"
+                )
+
+        if self.local_expert_mapping is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.local_expert_mapping[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+                raise NotImplementedError()
+
         return StandardDispatchOutput(
             hidden_states=hidden_states, topk_output=topk_output
         )
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
         # TODO: this branch should be removed in the future
         assert isinstance(combine_input, torch.Tensor)
         return combine_input
+
+    def set_quant_config(self, quant_config: dict):
+        pass
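The `local_expert_mapping` table added to `StandardDispatcher.dispatch` turns global expert ids into the local id space of the current EP rank, with `-1` marking experts owned by other ranks. A self-contained sketch of just that remapping (plain PyTorch on CPU; the sizes are made up, and the real code allocates the table on CUDA):

```python
import torch

num_experts, num_local_experts, moe_ep_rank = 8, 2, 1  # illustrative sizes

# Global expert id -> local id on this rank; -1 means "owned by another rank".
local_expert_mapping = torch.full((num_experts,), -1, dtype=torch.int32)
local_expert_mapping[
    moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)

topk_ids = torch.tensor([[0, 2], [3, 5]])  # global ids from the router
print(local_expert_mapping[topk_ids])
# tensor([[-1,  0],
#         [ 1, -1]], dtype=torch.int32)   <- rank 1 owns global experts 2 and 3
```

Downstream MoE kernels can then skip tokens whose remapped id is `-1`; the remap is only applied when `moe_ep_size > 1` and the runner backend is not flashinfer-cutlass, which presumably handles expert-parallel routing itself.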
sglang/srt/layers/moe/topk.py
CHANGED
@@ -365,9 +365,10 @@ class TopK(CustomOp):
     def empty_topk_output(self, device: torch.device) -> TopKOutput:
         topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
         topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-
+        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        # FIXME: router_logits should be of size (0, num_experts)
         router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
-        return StandardTopKOutput(topk_weights,
+        return StandardTopKOutput(topk_weights, topk_ids, router_logits)
 
 
 # ------------------------------- TopK implementation -------------------------------------
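For intuition, the three tensors the fixed `empty_topk_output` now constructs, reproduced standalone (plain PyTorch; `topk=4` is arbitrary, and the `# FIXME` in the diff notes that `router_logits` should really be `(0, num_experts)`):

```python
import torch

topk = 4  # in sglang: top_k minus the fused shared experts
topk_weights = torch.empty((0, topk), dtype=torch.float32)
topk_ids = torch.full((0, topk), -1, dtype=torch.int32)  # -1 = no expert selected
router_logits = torch.empty((0, topk), dtype=torch.float32)

# Zero rows: a valid, fully-typed output for an empty batch.
assert topk_ids.shape == (0, topk) and topk_ids.dtype == torch.int32
```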
sglang/srt/layers/moe/utils.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
+from sglang.srt.utils import log_info_on_rank0
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
@@ -24,6 +25,7 @@ class MoeA2ABackend(Enum):
 
     NONE = "none"
     DEEPEP = "deepep"
+    MOONCAKE = "mooncake"
 
     @classmethod
     def _missing_(cls, value):
@@ -40,20 +42,28 @@ class MoeA2ABackend(Enum):
     def is_deepep(self):
         return self == MoeA2ABackend.DEEPEP
 
+    def is_mooncake(self):
+        return self == MoeA2ABackend.MOONCAKE
+
 
 class MoeRunnerBackend(Enum):
 
     AUTO = "auto"
+    DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
     FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
+    CUTLASS = "cutlass"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
 
+    def is_deep_gemm(self):
+        return self == MoeRunnerBackend.DEEP_GEMM
+
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON
 
@@ -72,6 +82,9 @@ class MoeRunnerBackend(Enum):
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
+    def is_cutlass(self):
+        return self == MoeRunnerBackend.CUTLASS
+
 
 class DeepEPMode(Enum):
 
@@ -147,7 +160,10 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
 def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
    if MOE_RUNNER_BACKEND is None:
-
+        log_info_on_rank0(
+            logger,
+            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected",
+        )
         MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
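Both backend classes are plain string-valued `Enum`s, so a server argument resolves by value lookup. A short sketch of how the new `mooncake` value would be selected (the enum members are copied from the diff; the surrounding wiring is illustrative):

```python
from enum import Enum


class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"
    MOONCAKE = "mooncake"  # new in 0.5.4

    def is_mooncake(self):
        return self == MoeA2ABackend.MOONCAKE


# A value such as "mooncake" (e.g. from a --moe-a2a-backend style flag)
# arrives as a string and is resolved by Enum value lookup:
backend = MoeA2ABackend("mooncake")
assert backend.is_mooncake()
```

The `_missing_` classmethod visible in the hunk context handles values that fail this lookup, and the new `is_mooncake()` predicate is what call sites such as the dispatcher factory would branch on.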
sglang/srt/layers/quantization/__init__.py
CHANGED

@@ -10,10 +10,6 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -72,7 +68,8 @@ if TYPE_CHECKING:
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "modelopt": ModelOptFp8Config,
+    "modelopt": ModelOptFp8Config,  # Auto-detect, defaults to FP8
+    "modelopt_fp8": ModelOptFp8Config,
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
@@ -174,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         return original_isinstance(obj, classinfo)
 
     builtins.isinstance = patched_isinstance
-
-
-def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
-    """
-    Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert sglang arguments to vllm arguments.
-    """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
-
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
-
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "topk_output": topk_output,
-        }
-        return original_apply(**kwargs)
-
-    setattr(class_obj, "apply", new_apply)
-
-
-def monkey_patch_quant_configs():
-    """Apply all monkey patches in one place."""
-
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-
-
-# Only apply monkey patches if vllm is available
-if VLLM_AVAILABLE:
-    monkey_patch_quant_configs()