sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -264,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
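Note on the `mlp_hidden_size` change above: the vision MLP width is now rounded up to the next multiple of 8 instead of being taken verbatim from `vision_config.intermediate_size`, presumably to keep the projection width evenly divisible for tensor-parallel sharding and kernel alignment. A standalone sketch of the arithmetic (illustrative values, not sglang code):

    def round_up(x: int, multiple: int = 8) -> int:
        # ((x + multiple - 1) // multiple) * multiple rounds x up to the nearest multiple
        return ((x + multiple - 1) // multiple) * multiple

    print(round_up(3420))  # 3424: an unaligned width gets padded up
    print(round_up(3456))  # 3456: an already-aligned width is unchanged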
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
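The only change in qwen2_audio.py is the import-path move that also appears in the other model files: `hf_transformers_utils` now lives under `sglang.srt.utils` (see the `sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py}` rename in the file list). External code importing the old module path needs the same one-line update; a small compatibility sketch (the try/except fallback is illustrative, not part of sglang):

    try:
        # layout in 0.5.3.post1
        from sglang.srt.utils.hf_transformers_utils import get_processor
    except ImportError:
        # layout in 0.5.2rc2 and earlier
        from sglang.srt.hf_transformers_utils import get_processor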
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -62,13 +65,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -79,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -87,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -95,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -122,11 +134,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -138,10 +152,11 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -163,19 +178,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -183,11 +211,85 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
+
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
 
+        return final_hidden_states
+
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -346,6 +448,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -528,8 +631,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
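The new `forward_normal_dual_stream` above overlaps the shared-expert MLP with the routed experts by issuing them on two CUDA streams and joining with `wait_stream`; it is only taken for small batches (up to 1024 tokens) during CUDA-graph capture. A self-contained sketch of the same pattern with generic callables (not the sglang classes), assuming both branches read the same input and produce independent outputs:

    import torch

    def dual_stream(x: torch.Tensor, branch_a, branch_b, alt_stream: torch.cuda.Stream):
        current = torch.cuda.current_stream()
        alt_stream.wait_stream(current)        # side stream waits for the work that produced x
        out_a = branch_a(x.clone())            # runs on the current stream (clone keeps the inputs disjoint)
        with torch.cuda.stream(alt_stream):    # runs concurrently on the side stream
            out_b = branch_b(x)
        current.wait_stream(alt_stream)        # join before combining the results
        return out_a + out_b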
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/qwen3.py
CHANGED
@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -30,12 +29,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class Qwen3Attention(nn.Module):
@@ -235,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
        )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -98,7 +102,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -354,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -412,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -420,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
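In the last two hunks above, the rotary-embedding call can be handed a `fused_set_kv_buffer_arg`, which appears to let the RoPE kernel write k/v into the KV cache; when that path is taken, the attention call passes `save_kv_cache=False` so the cache is not written twice, and both sides are gated on the same condition (fused writes enabled and a rope type other than `MRotaryEmbedding`). A minimal, runnable sketch of that invariant, not the sglang API, only to show that exactly one of the two paths stores k/v:

    def attention_saves_kv(fused_write_enabled: bool, rope_is_mrope: bool) -> bool:
        # Mirrors the gating in the diff: the fused RoPE+cache write is only used
        # with non-MRoPE rotary embeddings, and the attention op stores k/v
        # exactly when the fused path did not.
        fused_write_used = fused_write_enabled and not rope_is_mrope
        return not fused_write_used

    for fused in (False, True):
        for mrope in (False, True):
            print(f"fused={fused} mrope={mrope} -> save_kv_cache={attention_saves_kv(fused, mrope)}")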