sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED

```diff
@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=
-        self.norm2 = RMSNorm(dim, eps=
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -264,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
```
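The EAGLE3 hooks added to `Qwen2_5_VLForConditionalGeneration` capture auxiliary hidden states from an early, a middle, and a late decoder layer when no explicit layer ids are given. Below is a minimal standalone sketch of that default selection rule; only the index arithmetic is taken from `set_eagle3_layers_to_capture` above, while the function name and the example layer counts are illustrative assumptions.

```python
# Sketch of the default EAGLE3 capture-layer selection shown in the diff above.
# Only the index arithmetic mirrors set_eagle3_layers_to_capture(); the function
# name and the example values are illustrative.
from typing import List, Optional


def eagle3_capture_layers(
    num_hidden_layers: int, layer_ids: Optional[List[int]] = None
) -> List[int]:
    if layer_ids is None:
        # Default: one early, one middle, one late decoder layer.
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    # Caller-provided ids are shifted by one, as in the model code.
    return [val + 1 for val in layer_ids]


print(eagle3_capture_layers(36))                     # [2, 18, 33]
print(eagle3_capture_layers(36, layer_ids=[5, 11]))  # [6, 12]
```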
sglang/srt/models/qwen2_audio.py
CHANGED

```diff
@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
```
sglang/srt/models/qwen2_moe.py
CHANGED

```diff
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -62,13 +65,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -79,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -87,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -95,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -105,11 +117,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +134,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -135,10 +152,11 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -160,19 +178,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-
-
-
-
-
-
-
-
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +211,85 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 shared_output = (
                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                 )
+        return shared_output
+
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
 
+        return final_hidden_states
+
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -343,6 +448,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +631,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config,
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
            config.vocab_size,
```
sglang/srt/models/qwen2_vl.py
CHANGED

```diff
@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
```
sglang/srt/models/qwen3.py
CHANGED

```diff
@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -24,15 +23,25 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class Qwen3Attention(nn.Module):
@@ -232,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states,
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
@@ -458,7 +476,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
```
sglang/srt/models/qwen3_moe.py
CHANGED

```diff
@@ -42,13 +42,16 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -57,10 +60,21 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -88,7 +102,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -119,11 +133,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +154,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +164,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -335,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -393,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -401,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -500,6 +547,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +573,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
```