sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3.py
CHANGED
@@ -2,7 +2,7 @@
 
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -11,9 +11,9 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -23,15 +23,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3Attention(nn.Module):
@@ -49,23 +51,27 @@ class Qwen3Attention(nn.Module):
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -84,6 +90,8 @@ class Qwen3Attention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -91,6 +99,9 @@ class Qwen3Attention(nn.Module):
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -109,15 +120,27 @@ class Qwen3Attention(nn.Module):
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
@@ -143,6 +166,7 @@ class Qwen3DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,6 +187,7 @@ class Qwen3DecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -176,6 +201,18 @@ class Qwen3DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=False,
+            is_previous_layer_sparse=False,
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -184,20 +221,24 @@ class Qwen3DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
         hidden_states = self.mlp(hidden_states)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
         return hidden_states, residual
 
 
@@ -208,11 +249,13 @@ class Qwen3Model(Qwen2Model):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )
 
 
@@ -282,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -303,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -404,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
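
The headline change in qwen3.py is the alt_stream plumbing: when a second CUDA stream exists and a CUDA graph is being captured (get_is_capture_mode()), _apply_qk_norm runs the q-norm on the current stream and the k-norm on the alternate stream, since the two RMSNorm kernels have no data dependency. Below is a minimal self-contained sketch of the pattern; the toy RMSNorm and the name overlapped_qk_norm are illustrative assumptions, not sglang's API.

import torch
from torch import nn


class RMSNorm(nn.Module):
    """Toy RMSNorm standing in for sglang.srt.layers.layernorm.RMSNorm."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, alt_stream=None):
    if alt_stream is not None:
        current_stream = torch.cuda.current_stream()
        # Fork: the alternate stream must observe q/k before reading them.
        alt_stream.wait_stream(current_stream)
        q_by_head = q_norm(q.reshape(-1, head_dim))
        with torch.cuda.stream(alt_stream):
            # k-norm is enqueued on the second stream and may run concurrently.
            k_by_head = k_norm(k.reshape(-1, head_dim))
        # Join: later kernels on the main stream must wait for the k-norm.
        current_stream.wait_stream(alt_stream)
    else:
        q_by_head = q_norm(q.reshape(-1, head_dim))
        k_by_head = k_norm(k.reshape(-1, head_dim))
    return q_by_head.view(q.shape), k_by_head.view(k.shape)


if torch.cuda.is_available():
    head_dim = 128
    norm = RMSNorm(head_dim).cuda()
    q = torch.randn(4, 32 * head_dim, device="cuda")
    k = torch.randn(4, 8 * head_dim, device="cuda")
    q2, k2 = overlapped_qk_norm(q, k, norm, norm, head_dim, torch.cuda.Stream())

The two wait_stream calls are the only synchronization needed, and sglang gates the dual-stream branch behind get_is_capture_mode(), so the fork/join is recorded into CUDA graphs rather than paid as per-step launch overhead in eager mode.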
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -32,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -63,12 +66,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -78,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -117,6 +117,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -220,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -231,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         return final_hidden_states
 
@@ -284,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -316,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )
 
     def op_combine_a(self, state):
@@ -325,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -354,6 +363,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -423,15 +433,27 @@ class Qwen3MoeAttention(nn.Module):
 
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
@@ -491,6 +513,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -516,6 +539,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.layer_id = layer_id
@@ -623,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
 
     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
@@ -659,11 +681,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
         )
 
 
@@ -691,6 +715,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -709,9 +734,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -724,6 +753,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
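
Both qwen3.py and qwen3_moe.py gain the same EAGLE3 speculative-decoding hooks: capture_aux_hidden_states makes forward unpack (hidden_states, aux_hidden_states) and hand the auxiliary states to the logits processor, while set_eagle3_layers_to_capture selects which decoder layers to tap. A small sketch of that selection logic; the standalone helper name is an illustrative assumption, but the index arithmetic mirrors the diff.

from typing import List, Optional


def eagle3_capture_layers(num_hidden_layers: int,
                          layer_ids: Optional[List[int]] = None) -> List[int]:
    if layer_ids is None:
        # Default from the diff: one shallow, one middle, one deep layer.
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    # Caller-supplied ids are shifted by one, matching
    # `[val + 1 for val in layer_ids]` in the diff.
    return [val + 1 for val in layer_ids]


# e.g. for a 48-layer model: capture after layers 2, 24, and 45
assert eagle3_capture_layers(48) == [2, 24, 45]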
sglang/srt/models/vila.py
CHANGED
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
             weight_loader(param, loaded_weight)
 
     def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-            [image_inputs.im_token_id]
-        )
-
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     ##### BEGIN COPY modeling_vila.py #####
 
sglang/srt/{mm_utils.py → multimodal/mm_utils.py}
CHANGED
@@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
 
 """
 import ast
-import base64
 import math
 import re
 from io import BytesIO
 
 import numpy as np
+import pybase64
 from PIL import Image
 
 from sglang.srt.utils import flatten_nested_list
@@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
 
 
 def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
+    return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))
 
 
 def expand2square(pil_img, background_color):
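
The base64 → pybase64 swap in mm_utils is a drop-in replacement: pybase64 exposes the stdlib interface with SIMD acceleration, and validate=True makes characters outside the base64 alphabet raise an error instead of being silently discarded (the stdlib behavior when validate is left False). A quick sketch:

import pybase64

encoded = pybase64.b64encode(b"hello")  # b'aGVsbG8='
assert pybase64.b64decode(encoded, validate=True) == b"hello"

try:
    pybase64.b64decode(b"aGVs\nbG8=", validate=True)  # newline is not base64
except Exception:
    pass  # rejected rather than silently skipped, as desired for untrusted payloads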