sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,200 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Adapted from
|
3
|
+
# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
|
4
|
+
# Copyright 2024 The Qwen team.
|
5
|
+
# Copyright 2023 The vLLM team.
|
6
|
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
7
|
+
#
|
8
|
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
9
|
+
# and OPT implementations in this library. It has been modified from its
|
10
|
+
# original forms to accommodate minor architectural differences compared
|
11
|
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
12
|
+
#
|
13
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
14
|
+
# you may not use this file except in compliance with the License.
|
15
|
+
# You may obtain a copy of the License at
|
16
|
+
#
|
17
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
18
|
+
#
|
19
|
+
# Unless required by applicable law or agreed to in writing, software
|
20
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
21
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
22
|
+
# See the License for the specific language governing permissions and
|
23
|
+
# limitations under the License.
|
24
|
+
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
25
|
+
import logging
|
26
|
+
import math
|
27
|
+
from functools import lru_cache, partial
|
28
|
+
from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
|
29
|
+
|
30
|
+
import torch
|
31
|
+
import torch.nn as nn
|
32
|
+
import torch.nn.functional as F
|
33
|
+
from einops import rearrange
|
34
|
+
from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
|
35
|
+
from transformers.activations import ACT2FN
|
36
|
+
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
|
37
|
+
from transformers.models.qwen2_audio.modeling_qwen2_audio import (
|
38
|
+
Qwen2AudioEncoder,
|
39
|
+
Qwen2AudioMultiModalProjector,
|
40
|
+
)
|
41
|
+
|
42
|
+
from sglang.srt.hf_transformers_utils import get_processor
|
43
|
+
from sglang.srt.layers.activation import QuickGELU
|
44
|
+
from sglang.srt.layers.attention.vision import VisionAttention
|
45
|
+
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
46
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
47
|
+
from sglang.srt.layers.pooler import Pooler, PoolingType
|
48
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
49
|
+
from sglang.srt.layers.utils import get_layer_id
|
50
|
+
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
51
|
+
from sglang.srt.managers.mm_utils import (
|
52
|
+
MultiModalityDataPaddingPatternMultimodalTokens,
|
53
|
+
general_mm_embed_routine,
|
54
|
+
)
|
55
|
+
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
56
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
57
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
58
|
+
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
59
|
+
from sglang.srt.utils import add_prefix
|
60
|
+
|
61
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
62
|
+
|
63
|
+
|
64
|
+
class Qwen2AudioForConditionalGeneration(nn.Module):
    """Inference-only Qwen2-Audio model.

    Combines the HuggingFace ``Qwen2AudioEncoder`` and
    ``Qwen2AudioMultiModalProjector`` with the ``Qwen2ForCausalLM`` language
    model; audio embeddings are merged into the token stream by
    ``general_mm_embed_routine``.
    """

    # bitsandbytes-specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the audio tower, the projector and the language model.

        Args:
            config: Qwen2-Audio configuration. Missing ``audio_config`` /
                ``text_config`` sub-configs are created here and stored back
                on ``config``.
            quant_config: Quantization settings; forwarded to the language
                model only (the audio tower and projector do not receive it).
            prefix: Name prefix; the language model is created under
                ``add_prefix("model", prefix)``.
        """
        super().__init__()

        self.config = config

        # NOTE(review): both fallbacks pass `_name_or_path` as the first
        # positional argument of the sub-config constructor. Confirm that this
        # is what those constructors expect; a keyword argument may have been
        # intended.
        if getattr(self.config, "audio_config", None) is None:
            self.config.audio_config = Qwen2AudioEncoderConfig(
                self.config._name_or_path
            )

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.audio_tower = Qwen2AudioEncoder(
            config.audio_config,
        )
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad the audio placeholder tokens in ``input_ids``.

        Args:
            input_ids: Prompt token ids.
            mm_inputs: Multimodal inputs carrying the special token ids.

        Returns:
            The result of ``pad_input_tokens`` on the padding pattern.
        """
        # Resolve the fallback lazily: a `getattr` default is evaluated even
        # when the attribute exists, so `mm_inputs.im_token_id` must not be
        # touched unless it is actually needed. An `audio_token_id` that is
        # present but None is treated as missing as well.
        audio_token_id = getattr(mm_inputs, "audio_token_id", None)
        if audio_token_id is None:
            audio_token_id = mm_inputs.im_token_id

        # NOTE(review): other models in this release build this pattern with
        # no constructor arguments; confirm that a token-id list is still
        # accepted here.
        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode the audio items and return their projected embeddings.

        Args:
            items: Multimodal items providing ``audio_features`` and
                ``audio_feature_lens`` tensors.

        Returns:
            The per-clip embeddings, each cut to its own feature length and
            concatenated along dim 0.
        """
        # Batch every clip and match the encoder's dtype before encoding.
        input_features = torch.cat(
            [item.audio_features for item in items], dim=0
        ).type(self.audio_tower.dtype)

        audio_embeds = self.audio_tower(input_features).last_hidden_state
        audio_embeds = self.multi_modal_projector(audio_embeds)

        # Keep only the first `length` rows of each clip's embedding
        # (presumably the rest is padding -- confirm against the processor).
        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
        trimmed_embeds = [
            embed[: length.item()]
            for length, embed in zip(audio_feature_lens, audio_embeds)
        ]

        return torch.cat(trimmed_embeds, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model with audio embeddings merged in.

        Delegates to ``general_mm_embed_routine``, passing
        ``get_audio_feature`` as the audio embedding callback. Extra
        ``kwargs`` are accepted but not used.
        """
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            audio_data_embedding_func=self.get_audio_feature,
            positions=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a checkpoint name (other than a skipped one) has no
                matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Audio-tower weights are never remapped to fused names; they
                # fall through to the direct lookup below.
                if weight_name not in name or "audio_tower" in name:
                    continue
                name_tmp = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
                    continue
                param = params_dict[name_tmp]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    # Log the known names through the module logger (rather
                    # than stdout) to help diagnose checkpoint mismatches.
                    logger.error(
                        "No parameter named %r; known parameters: %s",
                        name,
                        sorted(params_dict),
                    )
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
198
|
+
|
199
|
+
|
200
|
+
# Module-level alias for the model class defined in this file
# (presumably the name the model loader looks up -- confirm against the registry).
EntryClass = Qwen2AudioForConditionalGeneration
|
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -31,6 +31,11 @@ from sglang.srt.distributed import (
|
|
31
31
|
get_tensor_model_parallel_world_size,
|
32
32
|
tensor_model_parallel_all_reduce,
|
33
33
|
)
|
34
|
+
from sglang.srt.eplb.expert_distribution import (
|
35
|
+
ExpertDistributionRecorder,
|
36
|
+
get_global_expert_distribution_recorder,
|
37
|
+
)
|
38
|
+
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
34
39
|
from sglang.srt.layers.activation import SiluAndMul
|
35
40
|
from sglang.srt.layers.communicator import (
|
36
41
|
LayerCommunicator,
|
@@ -64,11 +69,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|
64
69
|
ParallelLMHead,
|
65
70
|
VocabParallelEmbedding,
|
66
71
|
)
|
67
|
-
from sglang.srt.managers.expert_distribution import (
|
68
|
-
ExpertDistributionRecorder,
|
69
|
-
get_global_expert_distribution_recorder,
|
70
|
-
)
|
71
|
-
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
72
72
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
73
73
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
74
74
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
@@ -143,6 +143,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|
143
143
|
renormalize=config.norm_topk_prob,
|
144
144
|
quant_config=quant_config,
|
145
145
|
prefix=add_prefix("experts", prefix),
|
146
|
+
# Additional args for FusedMoE
|
147
|
+
**(
|
148
|
+
dict(
|
149
|
+
enable_flashinfer_moe=True,
|
150
|
+
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
|
151
|
+
)
|
152
|
+
if global_server_args_dict["enable_flashinfer_moe"]
|
153
|
+
else {}
|
154
|
+
),
|
146
155
|
)
|
147
156
|
|
148
157
|
self.gate = ReplicatedLinear(
|
@@ -291,6 +300,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|
291
300
|
layer_id: int,
|
292
301
|
quant_config: Optional[QuantizationConfig] = None,
|
293
302
|
prefix: str = "",
|
303
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
294
304
|
) -> None:
|
295
305
|
super().__init__()
|
296
306
|
self.config = config
|
@@ -393,6 +403,7 @@ class Qwen2MoeModel(nn.Module):
|
|
393
403
|
quant_config: Optional[QuantizationConfig] = None,
|
394
404
|
prefix: str = "",
|
395
405
|
decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
|
406
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
396
407
|
) -> None:
|
397
408
|
super().__init__()
|
398
409
|
self.padding_idx = config.pad_token_id
|
@@ -418,6 +429,7 @@ class Qwen2MoeModel(nn.Module):
|
|
418
429
|
config=config,
|
419
430
|
quant_config=quant_config,
|
420
431
|
prefix=prefix,
|
432
|
+
alt_stream=alt_stream,
|
421
433
|
),
|
422
434
|
pp_rank=self.pp_group.rank_in_group,
|
423
435
|
pp_size=self.pp_group.world_size,
|
@@ -428,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
|
|
428
440
|
else:
|
429
441
|
self.norm = PPMissingLayer(return_tuple=True)
|
430
442
|
|
443
|
+
# For EAGLE3 support
|
444
|
+
self.layers_to_capture = []
|
445
|
+
|
431
446
|
def forward(
|
432
447
|
self,
|
433
448
|
input_ids: torch.Tensor,
|
@@ -447,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
|
|
447
462
|
hidden_states = pp_proxy_tensors["hidden_states"]
|
448
463
|
residual = pp_proxy_tensors["residual"]
|
449
464
|
|
465
|
+
aux_hidden_states = []
|
450
466
|
if forward_batch.can_run_tbo:
|
451
467
|
hidden_states, residual = model_forward_maybe_tbo(
|
452
468
|
layers=self.layers,
|
@@ -459,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
|
|
459
475
|
)
|
460
476
|
else:
|
461
477
|
for i in range(self.start_layer, self.end_layer):
|
478
|
+
if i in self.layers_to_capture:
|
479
|
+
aux_hidden_states.append(
|
480
|
+
hidden_states + residual
|
481
|
+
if residual is not None
|
482
|
+
else hidden_states
|
483
|
+
)
|
462
484
|
with get_global_expert_distribution_recorder().with_current_layer(i):
|
463
485
|
layer = self.layers[i]
|
464
486
|
hidden_states, residual = layer(
|
@@ -477,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
|
|
477
499
|
hidden_states = self.norm(hidden_states)
|
478
500
|
else:
|
479
501
|
hidden_states, _ = self.norm(hidden_states, residual)
|
480
|
-
|
502
|
+
|
503
|
+
if len(aux_hidden_states) == 0:
|
504
|
+
return hidden_states
|
505
|
+
|
506
|
+
return hidden_states, aux_hidden_states
|
481
507
|
|
482
508
|
|
483
509
|
class Qwen2MoeForCausalLM(nn.Module):
|
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|
479
479
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
480
480
|
|
481
481
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
482
|
-
|
483
|
-
im_token_id: int = mm_inputs.im_token_id
|
484
|
-
|
485
|
-
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
482
|
+
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
486
483
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
487
484
|
|
488
485
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
sglang/srt/models/qwen3.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from functools import partial
|
5
|
-
from typing import Any, Dict, Iterable, Optional, Tuple
|
5
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
6
6
|
|
7
7
|
import torch
|
8
8
|
from torch import nn
|
@@ -11,9 +11,9 @@ from sglang.srt.distributed import (
|
|
11
11
|
get_pp_group,
|
12
12
|
get_tensor_model_parallel_rank,
|
13
13
|
get_tensor_model_parallel_world_size,
|
14
|
-
split_tensor_along_last_dim,
|
15
|
-
tensor_model_parallel_all_gather,
|
16
14
|
)
|
15
|
+
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
16
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
17
17
|
from sglang.srt.layers.layernorm import RMSNorm
|
18
18
|
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
|
19
19
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
@@ -23,15 +23,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|
23
23
|
from sglang.srt.layers.rotary_embedding import get_rope
|
24
24
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
25
25
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
26
|
+
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
26
27
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
27
28
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
28
29
|
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
29
30
|
from sglang.srt.models.qwen2 import Qwen2Model
|
30
|
-
from sglang.srt.utils import add_prefix
|
31
|
+
from sglang.srt.utils import add_prefix, is_cuda
|
31
32
|
|
32
33
|
Qwen3Config = None
|
33
34
|
|
34
35
|
logger = logging.getLogger(__name__)
|
36
|
+
_is_cuda = is_cuda()
|
35
37
|
|
36
38
|
|
37
39
|
class Qwen3Attention(nn.Module):
|
@@ -49,23 +51,27 @@ class Qwen3Attention(nn.Module):
|
|
49
51
|
rms_norm_eps: float = None,
|
50
52
|
attention_bias: bool = False,
|
51
53
|
prefix: str = "",
|
54
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
52
55
|
) -> None:
|
53
56
|
super().__init__()
|
54
57
|
self.hidden_size = hidden_size
|
55
58
|
self.tp_size = get_tensor_model_parallel_world_size()
|
56
59
|
self.total_num_heads = num_heads
|
57
|
-
|
58
|
-
|
60
|
+
attn_tp_rank = get_attention_tp_rank()
|
61
|
+
attn_tp_size = get_attention_tp_size()
|
62
|
+
|
63
|
+
assert self.total_num_heads % attn_tp_size == 0
|
64
|
+
self.num_heads = self.total_num_heads // attn_tp_size
|
59
65
|
self.total_num_kv_heads = num_kv_heads
|
60
|
-
if self.total_num_kv_heads >=
|
66
|
+
if self.total_num_kv_heads >= attn_tp_size:
|
61
67
|
# Number of KV heads is greater than TP size, so we partition
|
62
68
|
# the KV heads across multiple tensor parallel GPUs.
|
63
|
-
assert self.total_num_kv_heads %
|
69
|
+
assert self.total_num_kv_heads % attn_tp_size == 0
|
64
70
|
else:
|
65
71
|
# Number of KV heads is less than TP size, so we replicate
|
66
72
|
# the KV heads across multiple tensor parallel GPUs.
|
67
|
-
assert
|
68
|
-
self.num_kv_heads = max(1, self.total_num_kv_heads //
|
73
|
+
assert attn_tp_size % self.total_num_kv_heads == 0
|
74
|
+
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
69
75
|
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
70
76
|
self.q_size = self.num_heads * self.head_dim
|
71
77
|
self.kv_size = self.num_kv_heads * self.head_dim
|
@@ -84,6 +90,8 @@ class Qwen3Attention(nn.Module):
|
|
84
90
|
self.total_num_kv_heads,
|
85
91
|
bias=attention_bias,
|
86
92
|
quant_config=quant_config,
|
93
|
+
tp_rank=attn_tp_rank,
|
94
|
+
tp_size=attn_tp_size,
|
87
95
|
prefix=add_prefix("qkv_proj", prefix),
|
88
96
|
)
|
89
97
|
self.o_proj = RowParallelLinear(
|
@@ -91,6 +99,9 @@ class Qwen3Attention(nn.Module):
|
|
91
99
|
hidden_size,
|
92
100
|
bias=attention_bias,
|
93
101
|
quant_config=quant_config,
|
102
|
+
tp_rank=attn_tp_rank,
|
103
|
+
tp_size=attn_tp_size,
|
104
|
+
reduce_results=False,
|
94
105
|
prefix=add_prefix("o_proj", prefix),
|
95
106
|
)
|
96
107
|
|
@@ -109,15 +120,27 @@ class Qwen3Attention(nn.Module):
|
|
109
120
|
layer_id=layer_id,
|
110
121
|
prefix=add_prefix("attn", prefix),
|
111
122
|
)
|
123
|
+
self.alt_stream = alt_stream
|
112
124
|
|
113
125
|
def _apply_qk_norm(
|
114
126
|
self, q: torch.Tensor, k: torch.Tensor
|
115
127
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
116
|
-
|
117
|
-
|
128
|
+
# overlap qk norm
|
129
|
+
if self.alt_stream is not None and get_is_capture_mode():
|
130
|
+
current_stream = torch.cuda.current_stream()
|
131
|
+
self.alt_stream.wait_stream(current_stream)
|
132
|
+
q_by_head = q.reshape(-1, self.head_dim)
|
133
|
+
q_by_head = self.q_norm(q_by_head)
|
134
|
+
with torch.cuda.stream(self.alt_stream):
|
135
|
+
k_by_head = k.reshape(-1, self.head_dim)
|
136
|
+
k_by_head = self.k_norm(k_by_head)
|
137
|
+
current_stream.wait_stream(self.alt_stream)
|
138
|
+
else:
|
139
|
+
q_by_head = q.reshape(-1, self.head_dim)
|
140
|
+
q_by_head = self.q_norm(q_by_head)
|
141
|
+
k_by_head = k.reshape(-1, self.head_dim)
|
142
|
+
k_by_head = self.k_norm(k_by_head)
|
118
143
|
q = q_by_head.view(q.shape)
|
119
|
-
k_by_head = k.reshape(-1, self.head_dim)
|
120
|
-
k_by_head = self.k_norm(k_by_head)
|
121
144
|
k = k_by_head.view(k.shape)
|
122
145
|
return q, k
|
123
146
|
|
@@ -143,6 +166,7 @@ class Qwen3DecoderLayer(nn.Module):
|
|
143
166
|
layer_id: int = 0,
|
144
167
|
quant_config: Optional[QuantizationConfig] = None,
|
145
168
|
prefix: str = "",
|
169
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
146
170
|
) -> None:
|
147
171
|
super().__init__()
|
148
172
|
self.hidden_size = config.hidden_size
|
@@ -163,6 +187,7 @@ class Qwen3DecoderLayer(nn.Module):
|
|
163
187
|
rms_norm_eps=config.rms_norm_eps,
|
164
188
|
attention_bias=config.attention_bias,
|
165
189
|
prefix=add_prefix("self_attn", prefix),
|
190
|
+
alt_stream=alt_stream,
|
166
191
|
)
|
167
192
|
self.mlp = Qwen3MLP(
|
168
193
|
hidden_size=self.hidden_size,
|
@@ -176,6 +201,18 @@ class Qwen3DecoderLayer(nn.Module):
|
|
176
201
|
config.hidden_size, eps=config.rms_norm_eps
|
177
202
|
)
|
178
203
|
|
204
|
+
self.layer_scatter_modes = LayerScatterModes.init_new(
|
205
|
+
layer_id=layer_id,
|
206
|
+
num_layers=config.num_hidden_layers,
|
207
|
+
is_layer_sparse=False,
|
208
|
+
is_previous_layer_sparse=False,
|
209
|
+
)
|
210
|
+
self.layer_communicator = LayerCommunicator(
|
211
|
+
layer_scatter_modes=self.layer_scatter_modes,
|
212
|
+
input_layernorm=self.input_layernorm,
|
213
|
+
post_attention_layernorm=self.post_attention_layernorm,
|
214
|
+
)
|
215
|
+
|
179
216
|
def forward(
|
180
217
|
self,
|
181
218
|
positions: torch.Tensor,
|
@@ -184,20 +221,24 @@ class Qwen3DecoderLayer(nn.Module):
|
|
184
221
|
residual: Optional[torch.Tensor],
|
185
222
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
186
223
|
# Self Attention
|
187
|
-
|
188
|
-
residual
|
189
|
-
hidden_states = self.input_layernorm(hidden_states)
|
190
|
-
else:
|
191
|
-
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
192
|
-
hidden_states = self.self_attn(
|
193
|
-
positions=positions,
|
194
|
-
hidden_states=hidden_states,
|
195
|
-
forward_batch=forward_batch,
|
224
|
+
hidden_states, residual = self.layer_communicator.prepare_attn(
|
225
|
+
hidden_states, residual, forward_batch
|
196
226
|
)
|
227
|
+
if hidden_states.shape[0] != 0:
|
228
|
+
hidden_states = self.self_attn(
|
229
|
+
positions=positions,
|
230
|
+
hidden_states=hidden_states,
|
231
|
+
forward_batch=forward_batch,
|
232
|
+
)
|
197
233
|
|
198
234
|
# Fully Connected
|
199
|
-
hidden_states, residual = self.
|
235
|
+
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
236
|
+
hidden_states, residual, forward_batch
|
237
|
+
)
|
200
238
|
hidden_states = self.mlp(hidden_states)
|
239
|
+
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
240
|
+
hidden_states, residual, forward_batch
|
241
|
+
)
|
201
242
|
return hidden_states, residual
|
202
243
|
|
203
244
|
|
@@ -208,11 +249,13 @@ class Qwen3Model(Qwen2Model):
|
|
208
249
|
quant_config: Optional[QuantizationConfig] = None,
|
209
250
|
prefix: str = "",
|
210
251
|
) -> None:
|
252
|
+
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
211
253
|
super().__init__(
|
212
254
|
config=config,
|
213
255
|
quant_config=quant_config,
|
214
256
|
prefix=prefix,
|
215
257
|
decoder_layer_type=Qwen3DecoderLayer,
|
258
|
+
alt_stream=alt_stream,
|
216
259
|
)
|
217
260
|
|
218
261
|
|
@@ -282,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
|
|
282
325
|
self.logits_processor = LogitsProcessor(config)
|
283
326
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
284
327
|
|
328
|
+
# For EAGLE3 support
|
329
|
+
self.capture_aux_hidden_states = False
|
330
|
+
|
285
331
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
286
332
|
return self.model.get_input_embeddings(input_ids)
|
287
333
|
|
@@ -303,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
|
|
303
349
|
pp_proxy_tensors=pp_proxy_tensors,
|
304
350
|
)
|
305
351
|
|
352
|
+
aux_hidden_states = None
|
353
|
+
if self.capture_aux_hidden_states:
|
354
|
+
hidden_states, aux_hidden_states = hidden_states
|
355
|
+
|
306
356
|
if self.pp_group.is_last_rank:
|
307
357
|
if not get_embedding:
|
308
358
|
return self.logits_processor(
|
309
|
-
input_ids,
|
359
|
+
input_ids,
|
360
|
+
hidden_states,
|
361
|
+
self.lm_head,
|
362
|
+
forward_batch,
|
363
|
+
aux_hidden_states,
|
310
364
|
)
|
311
365
|
else:
|
312
366
|
return self.pooler(hidden_states, forward_batch)
|
@@ -404,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
|
|
404
458
|
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
405
459
|
self.model.load_kv_cache_scales(quantization_param_path)
|
406
460
|
|
461
|
+
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
|
462
|
+
if not self.pp_group.is_last_rank:
|
463
|
+
return
|
464
|
+
|
465
|
+
self.capture_aux_hidden_states = True
|
466
|
+
if layer_ids is None:
|
467
|
+
num_layers = self.config.num_hidden_layers
|
468
|
+
self.model.layers_to_capture = [
|
469
|
+
2,
|
470
|
+
num_layers // 2,
|
471
|
+
num_layers - 3,
|
472
|
+
] # Specific layers for EAGLE3 support
|
473
|
+
else:
|
474
|
+
self.model.layers_to_capture = [val + 1 for val in layer_ids]
|
475
|
+
|
407
476
|
|
408
477
|
EntryClass = Qwen3ForCausalLM
|