sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,20 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
+
from contextlib import nullcontext
|
4
5
|
from dataclasses import dataclass
|
5
|
-
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
6
7
|
|
7
8
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
8
|
-
from sglang.srt.layers.moe import
|
9
|
-
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
9
|
+
from sglang.srt.layers.moe.token_dispatcher.base import (
|
10
10
|
BaseDispatcher,
|
11
11
|
BaseDispatcherConfig,
|
12
|
+
CombineInput,
|
13
|
+
CombineInputFormat,
|
12
14
|
DispatchOutput,
|
13
15
|
DispatchOutputFormat,
|
14
16
|
)
|
17
|
+
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
|
15
18
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
16
19
|
from sglang.srt.utils import (
|
17
20
|
get_bool_env_var,
|
@@ -23,6 +26,9 @@ from sglang.srt.utils import (
|
|
23
26
|
|
24
27
|
_is_npu = is_npu()
|
25
28
|
|
29
|
+
if TYPE_CHECKING:
|
30
|
+
from sglang.srt.single_batch_overlap import CombineOverlapArgs
|
31
|
+
|
26
32
|
try:
|
27
33
|
from deep_ep import Buffer, Config
|
28
34
|
|
@@ -40,11 +46,6 @@ from enum import Enum, IntEnum, auto
|
|
40
46
|
import torch
|
41
47
|
import torch.distributed as dist
|
42
48
|
|
43
|
-
from sglang.srt.layers.moe.ep_moe.kernels import (
|
44
|
-
deepep_permute_triton_kernel,
|
45
|
-
deepep_post_reorder_triton_kernel,
|
46
|
-
deepep_run_moe_deep_preprocess,
|
47
|
-
)
|
48
49
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
49
50
|
|
50
51
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
@@ -56,6 +57,7 @@ class DeepEPNormalOutput(NamedTuple):
|
|
56
57
|
"""DeepEP normal dispatch output."""
|
57
58
|
|
58
59
|
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
|
60
|
+
# hidden_states_scale
|
59
61
|
topk_idx: torch.Tensor
|
60
62
|
topk_weights: torch.Tensor
|
61
63
|
num_recv_tokens_per_expert: List[int]
|
@@ -79,24 +81,32 @@ class DeepEPLLOutput(NamedTuple):
|
|
79
81
|
return DispatchOutputFormat.DEEPEP_LL
|
80
82
|
|
81
83
|
|
82
|
-
|
83
|
-
|
84
|
+
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
85
|
+
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
84
86
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
expected_m: int
|
87
|
+
|
88
|
+
class DeepEPNormalCombineInput(NamedTuple):
|
89
|
+
"""DeepEP normal combine input."""
|
90
|
+
|
91
|
+
pass
|
91
92
|
|
92
93
|
@property
|
93
|
-
def format(self) ->
|
94
|
-
return
|
94
|
+
def format(self) -> CombineInputFormat:
|
95
|
+
return CombineInputFormat.DEEPEP_NORMAL
|
95
96
|
|
96
97
|
|
97
|
-
|
98
|
-
|
99
|
-
|
98
|
+
class DeepEPLLCombineInput(NamedTuple):
|
99
|
+
"""DeepEP low latency combine input."""
|
100
|
+
|
101
|
+
pass
|
102
|
+
|
103
|
+
@property
|
104
|
+
def format(self) -> CombineInputFormat:
|
105
|
+
return CombineInputFormat.DEEPEP_LL
|
106
|
+
|
107
|
+
|
108
|
+
assert isinstance(DeepEPNormalCombineInput, CombineInput)
|
109
|
+
assert isinstance(DeepEPLLCombineInput, CombineInput)
|
100
110
|
|
101
111
|
|
102
112
|
class DeepEPDispatchMode(IntEnum):
|
@@ -158,10 +168,19 @@ class DeepEPBuffer:
|
|
158
168
|
num_rdma_bytes,
|
159
169
|
)
|
160
170
|
|
171
|
+
# We should calculate num_qps_per_rank consistently with DeepEP's test script logic:
|
161
172
|
if deepep_mode == DeepEPMode.NORMAL:
|
162
|
-
|
163
|
-
|
173
|
+
# refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
|
174
|
+
num_qps_per_rank = DeepEPConfig.get_instance().num_sms
|
175
|
+
elif deepep_mode == DeepEPMode.LOW_LATENCY:
|
176
|
+
# refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_low_latency.py#L176
|
164
177
|
num_qps_per_rank = num_experts // group.size()
|
178
|
+
elif deepep_mode == DeepEPMode.AUTO:
|
179
|
+
# low-latency and normal mode all need run
|
180
|
+
# refer: https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py#L235
|
181
|
+
num_qps_per_rank = max(
|
182
|
+
DeepEPConfig.get_instance().num_sms, num_experts // group.size()
|
183
|
+
)
|
165
184
|
else:
|
166
185
|
raise NotImplementedError
|
167
186
|
|
@@ -272,12 +291,16 @@ class _DeepEPDispatcherImplBase:
|
|
272
291
|
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
|
273
292
|
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
|
274
293
|
)
|
294
|
+
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
|
295
|
+
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
|
296
|
+
assert self.num_max_dispatch_tokens_per_rank <= 1024
|
275
297
|
|
276
298
|
self.handle = None
|
277
299
|
|
278
300
|
def dispatch_a(
|
279
301
|
self,
|
280
302
|
hidden_states: torch.Tensor,
|
303
|
+
input_global_scale: Optional[torch.Tensor],
|
281
304
|
topk_idx: torch.Tensor,
|
282
305
|
topk_weights: torch.Tensor,
|
283
306
|
):
|
@@ -291,6 +314,7 @@ class _DeepEPDispatcherImplBase:
|
|
291
314
|
hidden_states: torch.Tensor,
|
292
315
|
topk_idx: torch.Tensor,
|
293
316
|
topk_weights: torch.Tensor,
|
317
|
+
overlap_args: Optional["CombineOverlapArgs"],
|
294
318
|
):
|
295
319
|
raise NotImplementedError
|
296
320
|
|
@@ -311,6 +335,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
311
335
|
def dispatch_a(
|
312
336
|
self,
|
313
337
|
hidden_states: torch.Tensor,
|
338
|
+
input_global_scale: Optional[torch.Tensor],
|
314
339
|
topk_idx: torch.Tensor,
|
315
340
|
topk_weights: torch.Tensor,
|
316
341
|
):
|
@@ -408,8 +433,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
408
433
|
hidden_states: torch.Tensor,
|
409
434
|
topk_idx: torch.Tensor,
|
410
435
|
topk_weights: torch.Tensor,
|
436
|
+
overlap_args: Optional["CombineOverlapArgs"],
|
411
437
|
):
|
412
|
-
|
438
|
+
from sglang.srt.layers.moe.ep_moe.kernels import (
|
439
|
+
deepep_post_reorder_triton_kernel,
|
440
|
+
)
|
441
|
+
|
442
|
+
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
413
443
|
output = hidden_states
|
414
444
|
else:
|
415
445
|
if hidden_states.shape[0] > 0:
|
@@ -479,10 +509,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
479
509
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
480
510
|
"""
|
481
511
|
self.return_recv_hook = return_recv_hook
|
512
|
+
self.device_module = torch.get_device_module()
|
482
513
|
|
483
514
|
def dispatch_a(
|
484
515
|
self,
|
485
516
|
hidden_states: torch.Tensor,
|
517
|
+
input_global_scale: Optional[torch.Tensor],
|
486
518
|
topk_idx: torch.Tensor,
|
487
519
|
topk_weights: torch.Tensor,
|
488
520
|
):
|
@@ -494,8 +526,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
494
526
|
) // self.num_experts
|
495
527
|
hidden_states, masked_m, event, hook = self._dispatch_core(
|
496
528
|
hidden_states,
|
529
|
+
input_global_scale,
|
497
530
|
topk_idx,
|
498
|
-
use_fp8=True,
|
499
531
|
)
|
500
532
|
return (
|
501
533
|
hidden_states,
|
@@ -523,39 +555,41 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
523
555
|
masked_m
|
524
556
|
)
|
525
557
|
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
expected_m,
|
534
|
-
)
|
535
|
-
else:
|
536
|
-
deepep_output = DeepEPLLOutput(
|
537
|
-
hidden_states,
|
538
|
-
topk_idx,
|
539
|
-
topk_weights,
|
540
|
-
masked_m,
|
541
|
-
expected_m,
|
542
|
-
)
|
558
|
+
deepep_output = DeepEPLLOutput(
|
559
|
+
hidden_states,
|
560
|
+
topk_idx,
|
561
|
+
topk_weights,
|
562
|
+
masked_m,
|
563
|
+
expected_m,
|
564
|
+
)
|
543
565
|
return deepep_output
|
544
566
|
|
545
567
|
def _dispatch_core(
|
546
568
|
self,
|
547
569
|
hidden_states: torch.Tensor,
|
570
|
+
input_global_scale: Optional[torch.Tensor],
|
548
571
|
topk_idx: torch.Tensor,
|
549
|
-
use_fp8: bool = False,
|
550
572
|
):
|
573
|
+
use_nvfp4 = use_fp8 = False
|
574
|
+
if input_global_scale is not None:
|
575
|
+
use_nvfp4 = True
|
576
|
+
elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
|
577
|
+
use_fp8 = True
|
578
|
+
|
551
579
|
buffer = self._get_buffer()
|
552
|
-
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
580
|
+
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
|
553
581
|
buffer.low_latency_dispatch(
|
554
582
|
hidden_states,
|
555
583
|
topk_idx,
|
556
584
|
self.num_max_dispatch_tokens_per_rank,
|
557
585
|
self.num_experts,
|
558
586
|
use_fp8=use_fp8,
|
587
|
+
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
588
|
+
**(
|
589
|
+
dict(x_global_scale=input_global_scale)
|
590
|
+
if input_global_scale is not None
|
591
|
+
else dict()
|
592
|
+
),
|
559
593
|
async_finish=not self.return_recv_hook,
|
560
594
|
return_recv_hook=self.return_recv_hook,
|
561
595
|
round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
@@ -564,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
564
598
|
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
|
565
599
|
)
|
566
600
|
)
|
567
|
-
return packed_recv_hidden, packed_recv_count, event, hook
|
601
|
+
return packed_recv_hidden, self.packed_recv_count, event, hook
|
568
602
|
|
569
603
|
def combine_a(
|
570
604
|
self,
|
571
605
|
hidden_states: torch.Tensor,
|
572
606
|
topk_idx: torch.Tensor,
|
573
607
|
topk_weights: torch.Tensor,
|
608
|
+
overlap_args: Optional["CombineOverlapArgs"],
|
574
609
|
):
|
575
610
|
hidden_states, event, hook = self._combine_core(
|
576
611
|
hidden_states,
|
577
612
|
topk_idx,
|
578
613
|
topk_weights,
|
614
|
+
overlap_args=overlap_args,
|
579
615
|
)
|
580
|
-
return hidden_states, event, hook
|
616
|
+
return hidden_states, event, hook, overlap_args
|
581
617
|
|
582
|
-
def combine_b(self, hidden_states, event, hook):
|
618
|
+
def combine_b(self, hidden_states, event, hook, overlap_args):
|
583
619
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
620
|
+
|
621
|
+
if overlap_args is not None:
|
622
|
+
self.device_module.current_stream().wait_stream(overlap_args.stream)
|
623
|
+
|
584
624
|
return hidden_states
|
585
625
|
|
586
626
|
def _combine_core(
|
@@ -588,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
588
628
|
hidden_states: torch.Tensor,
|
589
629
|
topk_idx: torch.Tensor,
|
590
630
|
topk_weights: torch.Tensor,
|
631
|
+
overlap_args: Optional["CombineOverlapArgs"],
|
591
632
|
):
|
592
633
|
buffer = self._get_buffer()
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
634
|
+
|
635
|
+
ctx = nullcontext()
|
636
|
+
if overlap_args is not None:
|
637
|
+
overlap_args.stream.wait_event(overlap_args.wait_event)
|
638
|
+
ctx = torch.cuda.stream(overlap_args.stream)
|
639
|
+
|
640
|
+
with ctx:
|
641
|
+
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
642
|
+
x=hidden_states,
|
643
|
+
topk_idx=topk_idx,
|
644
|
+
topk_weights=topk_weights,
|
645
|
+
handle=self.handle,
|
646
|
+
async_finish=not self.return_recv_hook,
|
647
|
+
return_recv_hook=self.return_recv_hook,
|
648
|
+
**(
|
649
|
+
dict(
|
650
|
+
overlap=overlap_args.overlap,
|
651
|
+
src_signals=overlap_args.signal,
|
652
|
+
src_signal_expect_value=overlap_args.threshold,
|
653
|
+
)
|
654
|
+
if overlap_args is not None
|
655
|
+
else {}
|
656
|
+
),
|
657
|
+
)
|
658
|
+
|
659
|
+
self.packed_recv_count = self.handle = None
|
602
660
|
return combined_hidden_states, event, hook
|
603
661
|
|
604
662
|
def _get_buffer(self):
|
@@ -669,6 +727,7 @@ class DeepEPDispatcher(BaseDispatcher):
|
|
669
727
|
def dispatch_a(
|
670
728
|
self,
|
671
729
|
hidden_states: torch.Tensor,
|
730
|
+
input_global_scale: Optional[torch.Tensor],
|
672
731
|
topk_idx: torch.Tensor,
|
673
732
|
topk_weights: torch.Tensor,
|
674
733
|
forward_batch: ForwardBatch,
|
@@ -676,6 +735,7 @@ class DeepEPDispatcher(BaseDispatcher):
|
|
676
735
|
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
677
736
|
inner_state = self._get_impl(forward_batch).dispatch_a(
|
678
737
|
hidden_states=hidden_states,
|
738
|
+
input_global_scale=input_global_scale,
|
679
739
|
topk_idx=topk_idx,
|
680
740
|
topk_weights=topk_weights,
|
681
741
|
)
|
@@ -698,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
|
|
698
758
|
topk_idx: torch.Tensor,
|
699
759
|
topk_weights: torch.Tensor,
|
700
760
|
forward_batch: ForwardBatch,
|
761
|
+
overlap_args: Optional["CombineOverlapArgs"] = None,
|
701
762
|
):
|
702
763
|
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
703
764
|
inner_state = self._get_impl(forward_batch).combine_a(
|
704
765
|
hidden_states=hidden_states,
|
705
766
|
topk_idx=topk_idx,
|
706
767
|
topk_weights=topk_weights,
|
768
|
+
overlap_args=overlap_args,
|
707
769
|
)
|
708
770
|
self._combine_intermediate_state = forward_batch, inner_state
|
709
771
|
|
@@ -1,19 +1,61 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import NamedTuple
|
3
|
+
from typing import TYPE_CHECKING, NamedTuple
|
4
4
|
|
5
|
-
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from sglang.srt.layers.moe.token_dispatcher.base import (
|
8
|
+
BaseDispatcher,
|
9
|
+
CombineInput,
|
10
|
+
CombineInputFormat,
|
6
11
|
DispatchOutput,
|
7
12
|
DispatchOutputFormat,
|
8
13
|
)
|
9
14
|
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from sglang.srt.layers.moe.topk import TopKOutput
|
17
|
+
|
10
18
|
|
11
19
|
class StandardDispatchOutput(NamedTuple):
|
12
20
|
"""Standard dispatch output."""
|
13
21
|
|
22
|
+
hidden_states: torch.Tensor
|
23
|
+
topk_output: TopKOutput
|
24
|
+
|
14
25
|
@property
|
15
26
|
def format(self) -> DispatchOutputFormat:
|
16
27
|
return DispatchOutputFormat.STANDARD
|
17
28
|
|
18
29
|
|
19
30
|
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
31
|
+
|
32
|
+
|
33
|
+
class StandardCombineInput(NamedTuple):
|
34
|
+
"""Standard combine input."""
|
35
|
+
|
36
|
+
hidden_states: torch.Tensor
|
37
|
+
|
38
|
+
@property
|
39
|
+
def format(self) -> CombineInputFormat:
|
40
|
+
return CombineInputFormat.STANDARD
|
41
|
+
|
42
|
+
|
43
|
+
assert isinstance(StandardCombineInput, CombineInput)
|
44
|
+
|
45
|
+
|
46
|
+
class StandardDispatcher(BaseDispatcher):
|
47
|
+
|
48
|
+
def dispatch(
|
49
|
+
self, hidden_states: torch.Tensor, topk_output: TopKOutput
|
50
|
+
) -> DispatchOutput:
|
51
|
+
return StandardDispatchOutput(
|
52
|
+
hidden_states=hidden_states, topk_output=topk_output
|
53
|
+
)
|
54
|
+
|
55
|
+
def combine(self, combine_input: CombineInput) -> torch.Tensor:
|
56
|
+
if isinstance(combine_input, StandardCombineInput):
|
57
|
+
return combine_input.hidden_states
|
58
|
+
else:
|
59
|
+
# TODO: this branch should be removed in the future
|
60
|
+
assert isinstance(combine_input, torch.Tensor)
|
61
|
+
return combine_input
|
sglang/srt/layers/moe/topk.py
CHANGED
@@ -19,6 +19,7 @@ import math
|
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from enum import Enum, auto
|
21
21
|
from typing import (
|
22
|
+
TYPE_CHECKING,
|
22
23
|
Callable,
|
23
24
|
NamedTuple,
|
24
25
|
Optional,
|
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
|
|
51
52
|
is_npu,
|
52
53
|
)
|
53
54
|
|
55
|
+
if TYPE_CHECKING:
|
56
|
+
from sglang.srt.layers.quantization import QuantizationConfig
|
57
|
+
|
54
58
|
try:
|
55
59
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
56
60
|
except ImportError:
|
@@ -94,6 +98,7 @@ class TopKConfig:
|
|
94
98
|
torch_native: bool = False
|
95
99
|
routed_scaling_factor: Optional[float] = None
|
96
100
|
apply_routed_scaling_factor_on_output: bool = False
|
101
|
+
output_format: Optional[TopKOutputFormat] = None
|
97
102
|
|
98
103
|
|
99
104
|
# -------------------------------- TopKOutput ---------------------------------------
|
@@ -196,9 +201,10 @@ class TopK(CustomOp):
|
|
196
201
|
custom_routing_function: Optional[Callable] = None,
|
197
202
|
scoring_func: str = "softmax",
|
198
203
|
correction_bias: Optional[torch.Tensor] = None,
|
204
|
+
quant_config: Optional[QuantizationConfig] = None,
|
199
205
|
routed_scaling_factor: Optional[float] = None,
|
200
206
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
201
|
-
|
207
|
+
output_format: Optional[TopKOutputFormat] = None,
|
202
208
|
):
|
203
209
|
# NOTE: scoring_func is not used for now, but we keep it for future use
|
204
210
|
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
@@ -218,11 +224,9 @@ class TopK(CustomOp):
|
|
218
224
|
correction_bias=correction_bias,
|
219
225
|
routed_scaling_factor=routed_scaling_factor,
|
220
226
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
227
|
+
output_format=output_format,
|
221
228
|
)
|
222
229
|
|
223
|
-
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
224
|
-
self.force_topk = force_topk
|
225
|
-
|
226
230
|
def forward_native(
|
227
231
|
self,
|
228
232
|
hidden_states: torch.Tensor,
|
@@ -248,7 +252,19 @@ class TopK(CustomOp):
|
|
248
252
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
249
253
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
250
254
|
) -> TopKOutput:
|
251
|
-
if self.
|
255
|
+
if self.topk_config.output_format is not None:
|
256
|
+
output_format = self.topk_config.output_format
|
257
|
+
elif get_moe_runner_backend().is_triton_kernel():
|
258
|
+
output_format = TopKOutputFormat.TRITON_KERNEL
|
259
|
+
elif (
|
260
|
+
should_use_flashinfer_trtllm_moe()
|
261
|
+
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
262
|
+
):
|
263
|
+
output_format = TopKOutputFormat.BYPASSED
|
264
|
+
else:
|
265
|
+
output_format = TopKOutputFormat.STANDARD
|
266
|
+
|
267
|
+
if output_format == TopKOutputFormat.TRITON_KERNEL:
|
252
268
|
# renormalize=True is equivalent to sm_first=False
|
253
269
|
routing_data, gather_idx, scatter_idx = routing(
|
254
270
|
router_logits,
|
@@ -256,10 +272,7 @@ class TopK(CustomOp):
|
|
256
272
|
sm_first=not self.topk_config.renormalize,
|
257
273
|
)
|
258
274
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
259
|
-
elif
|
260
|
-
should_use_flashinfer_trtllm_moe()
|
261
|
-
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
262
|
-
):
|
275
|
+
elif output_format == TopKOutputFormat.BYPASSED:
|
263
276
|
return BypassedTopKOutput(
|
264
277
|
hidden_states=hidden_states,
|
265
278
|
router_logits=router_logits,
|
@@ -330,6 +343,14 @@ class TopK(CustomOp):
|
|
330
343
|
)
|
331
344
|
topk_weights = topk_weights / topk_weights_sum
|
332
345
|
|
346
|
+
if expert_location_dispatch_info is not None:
|
347
|
+
topk_ids = topk_ids_logical_to_physical(
|
348
|
+
topk_ids, expert_location_dispatch_info
|
349
|
+
)
|
350
|
+
get_global_expert_distribution_recorder().on_select_experts(
|
351
|
+
topk_ids=topk_ids
|
352
|
+
)
|
353
|
+
|
333
354
|
return StandardTopKOutput(topk_weights, topk_ids, _)
|
334
355
|
else:
|
335
356
|
self.topk_config.torch_native = True
|
sglang/srt/layers/moe/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib.util
|
4
|
+
import logging
|
4
5
|
from enum import Enum
|
5
6
|
from functools import lru_cache
|
6
7
|
from typing import TYPE_CHECKING, Optional
|
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
|
|
12
13
|
get_attention_dp_size,
|
13
14
|
is_dp_attention_enabled,
|
14
15
|
)
|
15
|
-
from sglang.srt.utils import logger
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
from sglang.srt.server_args import ServerArgs
|
19
19
|
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
20
22
|
|
21
23
|
class MoeA2ABackend(Enum):
|
22
24
|
|
@@ -44,9 +46,10 @@ class MoeRunnerBackend(Enum):
|
|
44
46
|
AUTO = "auto"
|
45
47
|
TRITON = "triton"
|
46
48
|
TRITON_KERNEL = "triton_kernel"
|
47
|
-
|
49
|
+
FLASHINFER_TRTLLM = "flashinfer_trtllm"
|
48
50
|
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
49
51
|
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
52
|
+
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
|
50
53
|
|
51
54
|
def is_auto(self):
|
52
55
|
return self == MoeRunnerBackend.AUTO
|
@@ -58,11 +61,14 @@ class MoeRunnerBackend(Enum):
|
|
58
61
|
return self == MoeRunnerBackend.TRITON_KERNEL
|
59
62
|
|
60
63
|
def is_flashinfer_trtllm(self):
|
61
|
-
return self == MoeRunnerBackend.
|
64
|
+
return self == MoeRunnerBackend.FLASHINFER_TRTLLM
|
62
65
|
|
63
66
|
def is_flashinfer_cutlass(self):
|
64
67
|
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
|
65
68
|
|
69
|
+
def is_flashinfer_cutedsl(self):
|
70
|
+
return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
|
71
|
+
|
66
72
|
def is_flashinfer_mxfp4(self):
|
67
73
|
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
68
74
|
|
@@ -102,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
|
102
108
|
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
103
109
|
DEEPEP_MODE: Optional[DeepEPMode] = None
|
104
110
|
IS_TBO_ENABLED: Optional[bool] = None
|
111
|
+
IS_SBO_ENABLED: Optional[bool] = None
|
105
112
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
106
113
|
DEEPEP_CONFIG: Optional[str] = None
|
107
114
|
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
@@ -113,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
|
113
120
|
global DEEPEP_MODE
|
114
121
|
global DEEPEP_CONFIG
|
115
122
|
global IS_TBO_ENABLED
|
123
|
+
global IS_SBO_ENABLED
|
116
124
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
117
125
|
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
118
126
|
|
@@ -121,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
|
121
129
|
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
122
130
|
DEEPEP_CONFIG = server_args.deepep_config or ""
|
123
131
|
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
132
|
+
IS_SBO_ENABLED = server_args.enable_single_batch_overlap
|
124
133
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
125
134
|
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
126
135
|
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
@@ -131,7 +140,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
|
|
131
140
|
global MOE_A2A_BACKEND
|
132
141
|
if MOE_A2A_BACKEND is None:
|
133
142
|
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
134
|
-
MOE_A2A_BACKEND = MoeA2ABackend
|
143
|
+
MOE_A2A_BACKEND = MoeA2ABackend.NONE
|
135
144
|
return MOE_A2A_BACKEND
|
136
145
|
|
137
146
|
|
@@ -139,7 +148,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
|
|
139
148
|
global MOE_RUNNER_BACKEND
|
140
149
|
if MOE_RUNNER_BACKEND is None:
|
141
150
|
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
142
|
-
MOE_RUNNER_BACKEND = MoeRunnerBackend
|
151
|
+
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
143
152
|
return MOE_RUNNER_BACKEND
|
144
153
|
|
145
154
|
|
@@ -147,7 +156,7 @@ def get_deepep_mode() -> DeepEPMode:
|
|
147
156
|
global DEEPEP_MODE
|
148
157
|
if DEEPEP_MODE is None:
|
149
158
|
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
|
150
|
-
DEEPEP_MODE = DeepEPMode
|
159
|
+
DEEPEP_MODE = DeepEPMode.AUTO
|
151
160
|
return DEEPEP_MODE
|
152
161
|
|
153
162
|
|
@@ -166,6 +175,13 @@ def is_tbo_enabled() -> bool:
|
|
166
175
|
return IS_TBO_ENABLED
|
167
176
|
|
168
177
|
|
178
|
+
def is_sbo_enabled() -> bool:
|
179
|
+
global IS_SBO_ENABLED
|
180
|
+
if IS_SBO_ENABLED is None:
|
181
|
+
IS_SBO_ENABLED = False
|
182
|
+
return IS_SBO_ENABLED
|
183
|
+
|
184
|
+
|
169
185
|
def get_tbo_token_distribution_threshold() -> float:
|
170
186
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
171
187
|
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
sglang/srt/layers/parameter.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
|
|
7
7
|
import torch
|
8
8
|
from torch.nn import Parameter
|
9
9
|
|
10
|
+
from sglang.srt.layers.utils import pad_or_narrow_weight
|
10
11
|
from sglang.srt.utils import is_cpu
|
11
12
|
|
12
13
|
__all__ = [
|
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
|
156
157
|
)
|
157
158
|
else:
|
158
159
|
if not use_presharded_weights:
|
159
|
-
|
160
|
-
|
161
|
-
|
160
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
161
|
+
start_idx = tp_rank * shard_size
|
162
|
+
end_idx = start_idx + shard_size
|
163
|
+
if end_idx > loaded_weight.shape[self.output_dim]:
|
164
|
+
loaded_weight = pad_or_narrow_weight(
|
165
|
+
loaded_weight, self.output_dim, start_idx, shard_size
|
166
|
+
)
|
167
|
+
else:
|
168
|
+
loaded_weight = loaded_weight.narrow(
|
169
|
+
self.output_dim, start_idx, shard_size
|
170
|
+
)
|
162
171
|
|
163
172
|
assert param_data.shape == loaded_weight.shape
|
164
173
|
param_data.copy_(loaded_weight)
|
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
|
|
258
267
|
|
259
268
|
return
|
260
269
|
else:
|
261
|
-
|
262
|
-
|
263
|
-
|
270
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
271
|
+
start_idx = tp_rank * shard_size
|
272
|
+
end_idx = start_idx + shard_size
|
273
|
+
if end_idx > loaded_weight.shape[self.input_dim]:
|
274
|
+
loaded_weight = pad_or_narrow_weight(
|
275
|
+
loaded_weight, self.input_dim, start_idx, shard_size
|
276
|
+
)
|
277
|
+
else:
|
278
|
+
loaded_weight = loaded_weight.narrow(
|
279
|
+
self.input_dim, start_idx, shard_size
|
280
|
+
)
|
264
281
|
|
265
282
|
if len(loaded_weight.shape) == 0:
|
266
283
|
loaded_weight = loaded_weight.reshape(1)
|