sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the published package contents, not any intermediate development state.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py
CHANGED
@@ -1,17 +1,24 @@
 import json as json_lib
 import logging
+import math
 import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionConfig
 from transformers.models.llama4.modeling_llama4 import (
     Llama4MultiModalProjector,
-
+    vision_apply_rotary_emb,
 )

+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
@@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.
-from sglang.srt.utils import add_prefix, is_cpu
+from sglang.srt.utils import is_cpu

 _is_cpu = is_cpu()
+
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
@@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix
 logger = logging.getLogger(__name__)


+class Llama4VisionMLP(nn.Module):
+
+    def __init__(
+        self,
+        input_size: int,
+        intermediate_size: int,
+        output_size: int,
+        bias: bool,
+        output_activation: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
+        self.fc1 = cls_fc1(
+            input_size=input_size,
+            output_size=intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
+        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
+        self.fc2 = cls_fc2(
+            input_size=intermediate_size,
+            output_size=output_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
+        self.activation_fn = nn.GELU()
+        self.output_activation = output_activation
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+        if self.output_activation:
+            return self.activation_fn(hidden_states)
+        return hidden_states
+
+
+def pixel_shuffle(input_tensor, shuffle_ratio):
+    # input_tensor: [batch_size, num_patches, channels]
+    batch_size, num_patches, channels = input_tensor.shape
+    patch_size = int(math.sqrt(num_patches))
+
+    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
+    batch_size, height, width, channels = input_tensor.size()
+
+    reshaped_tensor = input_tensor.view(
+        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+    )
+    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+    reshaped_tensor = reshaped_tensor.view(
+        batch_size,
+        int(height * shuffle_ratio),
+        int(width * shuffle_ratio),
+        int(channels / (shuffle_ratio**2)),
+    )
+    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
+    return output_tensor
+
+
+class Llama4VisionPixelShuffleMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+        self.mlp = Llama4VisionMLP(
+            input_size=config.intermediate_size,
+            intermediate_size=config.projector_input_dim,
+            output_size=config.projector_output_dim,
+            bias=config.multi_modal_projector_bias,
+            output_activation=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
+
+    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
+        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+        return self.mlp(encoded_patches)
+
+
+def apply_position_embedding(q, k, freqs_ci, shape):
+    # [batch_size_times_num_tiles, num_channels]
+    input_shape = shape[:2]
+    # [batch_size_times_num_tiles, num_channels, num_heads, head_dim]
+    hidden_shape = (*input_shape, *q.shape[-2:])
+    q = q.view(hidden_shape)
+    k = k.view(hidden_shape)
+    q, k = vision_apply_rotary_emb(q, k, freqs_ci)
+    return q, k
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4VisionConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.intermediate_size = config.intermediate_size
+
+        self.self_attn = VisionAttention(
+            self.hidden_size,
+            self.num_attention_heads,
+            self.hidden_size,
+            use_qkv_parallel=True,
+            # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
+            quant_config=None,
+            dropout=0.0,
+            qkv_backend="sdpa",
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=add_prefix("self_attn", prefix),
+            qkv_bias=True,
+            customized_position_embedding_applier=apply_position_embedding,
+        )
+        self.mlp = Llama4VisionMLP(
+            input_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=True,
+            output_activation=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        freqs_ci: torch.Tensor,
+    ):
+        # Self Attention
+        residual = hidden_state
+        hidden_state = self.input_layernorm(hidden_state)
+        hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci)
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = residual + hidden_state
+
+        outputs = hidden_state
+        return outputs
+
+
+class Llama4VisionEncoder(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4VisionConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        freqs_ci: torch.Tensor,  # TODO: move this to an attribute instead of keeping it around
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            hidden_states (`torch.FloatTensor` of shape
+                `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to
+                directly pass an embedded representation. This is useful if you
+                want more control over how to convert `input_ids` indices into
+                associated vectors than the model's internal embedding
+                lookup matrix.
+        """

+        for encoder_layer in self.layers:
+            layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci)
+            hidden_states = layer_outputs
+
+        return hidden_states
+
+
+class Llama4UnfoldConvolution(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        kernel_size = config.patch_size
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
+        params = {
+            "input_size": config.num_channels * kernel_size[0] * kernel_size[1],
+            "output_size": config.hidden_size,
+            "bias": False,
+            "quant_config": quant_config,
+            "prefix": f"{prefix}.linear",
+        }
+        if use_data_parallel:
+            cls = ReplicatedLinear
+        else:
+            cls = ColumnParallelLinear
+            params["gather_output"] = True
+        self.linear = cls(**params)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.unfold(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states, _ = self.linear(hidden_states)
+        return hidden_states
+
+
+class Llama4VisionRotaryEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        idx = config.image_size // config.patch_size
+        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
+        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
+        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
+        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
+        freq_dim = config.hidden_size // config.num_attention_heads // 2
+        rope_freq = 1.0 / (
+            config.rope_theta
+            ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
+        )
+        freqs_x = (
+            (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs_y = (
+            (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+        freq_cis = torch.view_as_complex(
+            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+        )
+        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
+
+    def forward(self, hidden_states):
+        return self.freqs_ci.to(hidden_states.device)
+
+
+class Llama4VisionModel(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.patch_embedding = Llama4UnfoldConvolution(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.patch_embedding",
+        )
+
+        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
+        self.positional_embedding_vlm = nn.Parameter(
+            self.scale * torch.randn(self.num_patches, self.hidden_size)
+        )
+
+        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
+
+        # encoders
+        self.model = Llama4VisionEncoder(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.model",
+        )
+        self.vision_adapter = Llama4VisionPixelShuffleMLP(
+            config,
+            quant_config,
+            prefix=f"{prefix}.vision_adapter",
+        )
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        # Patch embedding
+        hidden_state = self.patch_embedding(pixel_values)
+        num_tiles, num_patches, hidden_dim = hidden_state.shape
+
+        # Add cls token
+        class_embedding = self.class_embedding.expand(
+            hidden_state.shape[0], 1, hidden_state.shape[-1]
+        )
+        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(
+            num_tiles,
+            1,
+            num_patches,
+            hidden_dim,
+        )
+        positional_embedding = self.positional_embedding_vlm.to(
+            dtype=hidden_state.dtype, device=hidden_state.device
+        )
+        hidden_state = hidden_state + positional_embedding
+        hidden_state = self.layernorm_pre(hidden_state)
+        hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
+        freqs_ci = self.rotary_embedding(pixel_values)
+        # Apply encoder
+        hidden_state = self.model(hidden_state, freqs_ci=freqs_ci)
+        hidden_state = self.layernorm_post(hidden_state)
+
+        # Remove CLS token output
+        hidden_state = hidden_state[:, :-1, :]
+
+        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+        hidden_state = self.vision_adapter(hidden_state)
+
+        return hidden_state
+
+
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module):
         if not self.has_vision_weights:
             logger.warning(
                 "No vision weights found in checkpoint. Model will run in text-only mode. "
-                "Multimodal capabilities (
+                "Multimodal capabilities (vision understanding) will be unavailable. "
+                "Please not that this warning might be inaccurate if the weights haven't been fully downloaded"
             )

         self.has_vision = (
@@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         if self.has_vision:
-            self.vision_model = Llama4VisionModel(
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                prefix=add_prefix("vision_model", prefix),
+            )
+
             self.multi_modal_projector = Llama4MultiModalProjector(config)
         else:
             self.vision_model = None
@@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module):
                 filename="model.safetensors.index.json",
                 cache_dir=None,
             )
-
            if index_file_path and os.path.exists(index_file_path):
                return self._check_vision_weights_in_index(index_file_path)

@@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             # If we can't access the cache, fall back to config-based detection
             pass

-        # Fallback
+        # Fallback, assume text-only
         return False

     def _check_vision_weights_in_index(self, index_file: str) -> bool:
@@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module):

         vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
         weight_names = index_data.get("weight_map", {}).keys()
-
         return any(
             pattern in weight_name
             for weight_name in weight_names
@@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         # For text-only models, return None or raise an error
         if not self.has_vision or self.vision_model is None:
             raise ValueError("Vision model not available for text-only checkpoint")
-
         pixel_values = (
             torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )
+        image_features = self.vision_model(pixel_values)

-        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
-        image_features = image_outputs.last_hidden_state
         vision_flat = image_features.view(-1, image_features.size(-1))
+
         projected_vision_flat = self.multi_modal_projector(vision_flat)
+
         return projected_vision_flat

     def forward(
@@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module):
             num_experts=num_experts,
         )

+        loaded_params = set()
+
         for name, loaded_weight in weights:
             if self._should_skip_weight(name):
                 continue

             name = self._transform_weight_name(name)

-            if "vision"
+            if "vision" in name:
+                name = name.replace(".self_attn.o_proj", ".self_attn.proj")
+            else:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )

             if self._handle_scale_remapping(name, params_dict):
+                loaded_params.add(name)
                 continue

             if self._handle_stacked_params(
-                name, loaded_weight, stacked_params_mapping, params_dict
+                name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
             ):
                 continue

             if self._handle_expert_weights(
-                name,
+                name,
+                loaded_weight,
+                expert_params_mapping,
+                params_dict,
+                num_experts,
+                loaded_params,
             ):
                 continue

+            loaded_params.add(name)
             self._handle_default_weight(name, loaded_weight, params_dict)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            logger.warning(
+                f"Some weights are not initialized from checkpoints {unloaded_params}"
+            )

     def _should_skip_weight(self, name: str) -> bool:
         """Check if we should skip loading this weight."""
@@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         stacked_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
     ) -> bool:
         """Handle stacked parameter loading. Returns True if handled."""
         for param_name, weight_name, shard_id in stacked_params_mapping:
-            if weight_name in name
+            if weight_name in name:
                 transformed_name = name.replace(weight_name, param_name)
+                loaded_params.add(transformed_name)
                 param = params_dict[transformed_name]
                 param.weight_loader(param, loaded_weight, shard_id)
                 return True
@@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         expert_params_mapping: list,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle expert weight loading for MoE (Mixture of Experts) layers.

@@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module):

         if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
             return self._handle_other_expert_params(
-                name, loaded_weight, expert_params_mapping, params_dict
+                name, loaded_weight, expert_params_mapping, params_dict, loaded_params
             )

         if "scale" in name:
             return self._handle_expert_scale_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )
         else:
             return self._handle_expert_weight_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )

     def _handle_other_expert_params(
@@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         expert_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
     ) -> bool:
         """Handle expert parameters that are not gate_up_proj or down_proj weights.

@@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: The weight tensor to be loaded
             expert_params_mapping: List of tuples mapping checkpoint names to model parameters
             params_dict: Dictionary of model parameters
+            loaded_params: Set of loaded parameter names

         Returns:
             bool: True if parameter was found and handled, False otherwise
@@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module):
                 param.weight_loader(
                     param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
                 )
+                loaded_params.add(transformed_name)
                 return True
         return False

@@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle quantization scale parameters for expert weights.

@@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Scale tensor to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for broadcast operations
+            loaded_params: Set of loaded parameter names

         Returns:
             bool: True (always handles scale parameters)
@@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             # Load the same scale for all experts
             for expert_id in range(num_experts):
                 param.data[expert_id] = loaded_weight
+            loaded_params.add(transformed_name)

         return True

@@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

@@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Weight tensor(s) to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for tensor distribution
+            loaded_params: Set of loaded parameter names

         Returns:
             bool: True (always handles weight parameters)
@@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module):

             param = params_dict[param_name]
             weight_loader = param.weight_loader
+            loaded_params.add(param_name)

             # Handle the case where loaded_weight might be a single tensor for all experts
             if weight_chunk.dim() == 2:
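The largest addition in this file is the in-tree Llama 4 vision stack; its pixel_shuffle helper trades spatial tokens for channel width before the projector MLP. The sketch below copies that helper from the diff and runs it on dummy tensors to show the shape change; the batch size, 8x8 patch grid, channel count, and the 0.5 ratio used in the driver are illustrative assumptions, not values read from the package.

    # Minimal, self-contained sketch: pixel_shuffle logic copied from the diff,
    # driven with illustrative tensor sizes (not taken from any Llama 4 config).
    import math
    import torch

    def pixel_shuffle(input_tensor, shuffle_ratio):
        # input_tensor: [batch_size, num_patches, channels]
        batch_size, num_patches, channels = input_tensor.shape
        patch_size = int(math.sqrt(num_patches))

        input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
        batch_size, height, width, channels = input_tensor.size()

        reshaped_tensor = input_tensor.view(
            batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
        )
        reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

        reshaped_tensor = reshaped_tensor.view(
            batch_size,
            int(height * shuffle_ratio),
            int(width * shuffle_ratio),
            int(channels / (shuffle_ratio**2)),
        )
        reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

        return reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])

    # With shuffle_ratio=0.5, an 8x8 grid of 64 patch tokens with 1024 channels
    # becomes 16 tokens with 4096 channels: token count / 4, channel dim x 4.
    patches = torch.randn(2, 64, 1024)
    out = pixel_shuffle(patches, shuffle_ratio=0.5)
    print(patches.shape, "->", out.shape)  # [2, 64, 1024] -> [2, 16, 4096]

The adapter then feeds these fewer, wider tokens through the projector MLP, which is why the diff wires Llama4VisionPixelShuffleMLP between the encoder output and the multimodal projector.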
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -148,7 +148,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             **(
                 dict(
                     enable_flashinfer_cutlass_moe=True,
-                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
                 if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
@@ -616,9 +615,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

-
-
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
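Both model files now build their expert weight mapping through FusedMoE.make_expert_params_mapping with the checkpoint projection names shown above. As a rough illustration of the idea only (a simplified stand-in, not sglang's actual implementation or return format), such a mapping pairs each per-expert checkpoint weight name with the fused parameter it loads into, plus the expert id and shard id, so the weight loader can route gate/up/down projections into the fused MoE tensors:

    # Simplified stand-in (assumption, not sglang's FusedMoE.make_expert_params_mapping):
    # pair per-expert checkpoint weight names with a fused target parameter,
    # the expert id, and a shard id for the loader to dispatch on.
    from typing import List, Tuple

    def make_expert_params_mapping_sketch(
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        mapping = []
        for expert_id in range(num_experts):
            # gate_proj and up_proj load into a fused "w13" weight as two shards
            mapping.append(
                ("experts.w13_weight", f"experts.{expert_id}.{ckpt_gate_proj_name}.weight", expert_id, "w1")
            )
            mapping.append(
                ("experts.w13_weight", f"experts.{expert_id}.{ckpt_up_proj_name}.weight", expert_id, "w3")
            )
            # down_proj loads into the "w2" weight
            mapping.append(
                ("experts.w2_weight", f"experts.{expert_id}.{ckpt_down_proj_name}.weight", expert_id, "w2")
            )
        return mapping

    for entry in make_expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", num_experts=2)[:3]:
        print(entry)

Calling the classmethod directly on FusedMoE (instead of a dispatch-selected MoEImpl) matches the removal of the enable_ep_moe branch in the same file.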