sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,7 +42,12 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 
 
 class Qwen2_5_VLMLP(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.
-            in_features,
-            hidden_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=add_prefix("
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down
 
 
 class Qwen2_5_VisionBlock(nn.Module):
@@ -118,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 =
-        self.norm2 =
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -174,18 +170,29 @@ class Qwen2_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-
-
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s
-
-        norm2
-
-
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
 
 
@@ -201,7 +208,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q =
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +230,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x
-
-
+        # x expected shape: [S, B, context_dim]
+        S, B, D = x.shape
+        x2d = x.reshape(-1, D)
+        x2d = self.ln_q(x2d)  # RMSNorm expects 2D
+        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
         mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(
+        x_parallel, _ = mlp_fc1(x2d)
         x_parallel = mlp_act(x_parallel)
         out, _ = mlp_fc2(x_parallel)
         return out
@@ -340,7 +349,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def device(self) -> torch.device:
-        return self.
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -394,6 +403,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
         seq_len, _ = x.size()
 
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +421,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )
 
-        # compute cu_seqlens
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
@@ -442,9 +464,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -497,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -526,6 +550,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -566,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -590,7 +619,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                if
+                if (
+                    "visual" in name
+                    and "up_proj" not in name
+                    and "gate_proj" not in name
+                ):
                     continue
                 name = name.replace(weight_name, param_name)
 
@@ -618,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
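The Qwen2_5_VLMLP change above folds the separate gate and up projections into a single MergedColumnParallelLinear and chunks its output, so the gated MLP needs one matmul instead of two. Below is a minimal single-GPU sketch of that pattern using plain nn.Linear; the FusedGateUpMLP name, the SiLU activation, and the example sizes are illustrative assumptions (the real layer takes its activation from ACT2FN and shards the projection across tensor-parallel ranks).

```python
import torch
import torch.nn as nn


class FusedGateUpMLP(nn.Module):
    """Illustrative stand-in for the merged gate/up projection in the diff."""

    def __init__(self, in_features: int, hidden_features: int, bias: bool = True):
        super().__init__()
        # One matmul produces both halves, mirroring
        # MergedColumnParallelLinear(output_sizes=[hidden_features] * 2).
        self.gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.down_proj = nn.Linear(hidden_features, in_features, bias=bias)
        self.act = nn.SiLU()  # assumption; the model uses ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)  # split the fused output
        return self.down_proj(self.act(gate) * up)


if __name__ == "__main__":
    x = torch.randn(4, 1280)
    print(FusedGateUpMLP(1280, 3420)(x).shape)  # torch.Size([4, 1280])
```

Merging the two column-parallel projections is also why the weight-loading loop and the bitsandbytes target list now track ".gate_up_proj." instead of separate ".gate_proj."/".up_proj." entries.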
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -17,7 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -105,11 +107,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -165,14 +172,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +180,51 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
 
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -343,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config,
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
@@ -536,6 +581,8 @@ class Qwen2MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -553,9 +600,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -705,5 +755,20 @@ class Qwen2MoeForCausalLM(nn.Module):
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2MoeForCausalLM
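The new forward_normal_dual_stream path overlaps the shared expert with the routed experts on a second CUDA stream for small batches (at most DUAL_STREAM_TOKEN_THRESHOLD = 1024 tokens). A minimal sketch of that stream-overlap pattern follows; shared_fn and routed_fn are placeholders standing in for _forward_shared_experts and _forward_router_experts.

```python
import torch


def dual_stream_overlap(hidden_states, shared_fn, routed_fn, alt_stream):
    """Run the shared expert and the routed experts concurrently.

    Mirrors the synchronization in the diff: the alternate stream waits on the
    current stream so both branches see hidden_states, and the current stream
    waits on the alternate stream before the two outputs are combined.
    """
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)

    shared_output = shared_fn(hidden_states)  # runs on the current stream

    with torch.cuda.stream(alt_stream):
        routed_output = routed_fn(hidden_states)  # runs on the alternate stream

    current_stream.wait_stream(alt_stream)
    return routed_output, shared_output
```

Per the diff, Qwen2MoeForCausalLM creates the alternate stream once (torch.cuda.Stream() when CUDA is available) and passes it down through the model to every sparse MoE block, so the overlap adds no per-layer stream allocation.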
sglang/srt/models/qwen3.py
CHANGED
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
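The qwen3.py change routes checkpoint entries containing "scale" through maybe_remap_kv_scale_name before the normal stacked-parameter handling, so KV-cache scale tensors saved under legacy names are either renamed to a registered parameter or skipped. The helper below is only a hedged sketch of that idea; the name remap_scale_name and the single rename rule are assumptions, not sglang's actual implementation (which lives in sglang.srt.model_loader.weight_utils).

```python
from typing import Dict, Optional

import torch


def remap_scale_name(
    name: str, params_dict: Dict[str, torch.nn.Parameter]
) -> Optional[str]:
    """Map a checkpoint scale name onto a registered parameter name, if any."""
    if name in params_dict:
        return name  # already matches a registered parameter
    # Illustrative legacy-name rule only, e.g. "...attn.kv_scale" -> "...attn.k_scale".
    candidate = name.replace(".kv_scale", ".k_scale")
    if candidate in params_dict:
        return candidate
    return None  # caller skips the weight, mirroring `if name is None: continue`
```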
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
        )
 
-        hidden_states = self.mlp(
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
        )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
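The Qwen3MoE hunks thread a should_allreduce_fusion flag from the decoder layer into the sparse MoE block: when the layer communicator reports that the MLP all-reduce can be fused with the next layer, or when reduce-scatter / the FlashInfer CUTLASS FP4 all-gather path already handles the reduction, forward_normal skips its tensor-parallel all-reduce and the layer tags the output with _sglang_needs_allreduce_fusion instead of running postprocess_layer. Below is a small sketch of that gating decision; the helper name needs_tp_all_reduce and its boolean arguments are illustrative, not part of the package API.

```python
def needs_tp_all_reduce(
    tp_size: int,
    should_allreduce_fusion: bool,
    use_reduce_scatter: bool,
    fp4_allgather: bool,
) -> bool:
    """Mirror the condition in Qwen3MoeSparseMoeBlock.forward_normal.

    The per-layer all-reduce is only issued when tensor parallelism is active
    and no other mechanism (fused all-reduce in the next layer, reduce-scatter
    under DP padding, or the FP4 all-gather path) combines the partial expert
    outputs.
    """
    return (
        tp_size > 1
        and not should_allreduce_fusion
        and not use_reduce_scatter
        and not fp4_allgather
    )


# Example: with TP=4 and all-reduce fusion enabled, the block skips its own all-reduce.
assert needs_tp_all_reduce(4, True, False, False) is False
assert needs_tp_all_reduce(4, False, False, False) is True
```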
|