sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
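This hunk wires Qwen2.5-VL into EAGLE3 speculative decoding by letting the target model hand back auxiliary hidden states from a few intermediate layers. Below is a minimal, self-contained sketch of that capture pattern in plain PyTorch; the class and names are illustrative, not the sglang implementation, but the layer heuristic mirrors the diff.

```python
# Minimal sketch of aux-hidden-state capture (plain PyTorch, not the sglang classes).
# Layer indices mirror the heuristic in the diff: [2, num_layers // 2, num_layers - 3].
from typing import List

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    def __init__(self, num_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.layers_to_capture: List[int] = []

    def forward(self, x: torch.Tensor):
        aux: List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                aux.append(x)  # stash the hidden state entering this layer
            x = layer(x)
        # With capture enabled the model returns a tuple, which the caller unpacks
        # exactly like `hidden_states, aux_hidden_states = hidden_states` in the diff.
        return (x, aux) if self.layers_to_capture else x


model = TinyDecoder()
n = len(model.layers)
model.layers_to_capture = [2, n // 2, n - 3]
out, aux_hidden_states = model(torch.randn(1, 16))
assert len(aux_hidden_states) == 3
```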
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -62,13 +62,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -105,11 +108,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +125,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -135,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
@@ -165,14 +173,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +181,52 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 shared_output = (
                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                 )
+        return shared_output
 
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -343,6 +385,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +568,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
sglang/srt/models/qwen3.py
CHANGED
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -88,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
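The qwen3_moe.py changes thread a `should_allreduce_fusion` flag from the decoder layer down to the MoE block and the MLP down-projection: when a later fused kernel, a reduce-scatter path, or the FlashInfer FP4 all-gather path will combine the partial expert outputs, the plain tensor-parallel all-reduce is skipped. The snippet below is a hedged sketch of that gating condition, mirroring `forward_normal` above; the helper function and assertions are illustrative, not sglang APIs.

```python
# Hedged sketch of the all-reduce gating introduced in forward_normal above.
# The helper only mirrors the boolean condition; it is not an sglang API.
def needs_plain_allreduce(
    tp_size: int,
    should_allreduce_fusion: bool,
    use_reduce_scatter: bool,
    use_fp4_allgather: bool,
) -> bool:
    return (
        tp_size > 1
        and not should_allreduce_fusion
        and not use_reduce_scatter
        and not use_fp4_allgather
    )


assert needs_plain_allreduce(8, False, False, False)      # default TP path: all-reduce here
assert not needs_plain_allreduce(8, True, False, False)   # deferred to the fused allreduce+norm
assert not needs_plain_allreduce(8, False, True, False)   # reduce-scatter path handles it
assert not needs_plain_allreduce(1, False, False, False)  # single rank: nothing to reduce
```

When the fusion path is taken, the layer marks the activation with `_sglang_needs_allreduce_fusion` instead of calling `postprocess_layer`, so the reduction is performed by the next layer's fused kernel.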
|