sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py
CHANGED
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -38,9 +45,10 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK =
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
         self.hidden_size = hidden_size
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
-
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads %
-        self.num_heads = self.total_num_heads //
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >=
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads %
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert
-        self.num_kv_heads = max(1, self.total_num_kv_heads //
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = config.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.o_proj = RowParallelLinear(
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
             bias=bias_o_proj,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
         )
         is_neox_style = True
         is_gguf = quant_config and quant_config.get_name() == "gguf"
@@ -223,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
-
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -233,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-
+
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
-            # TODO
-
-
-
-
-
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
@@ -274,6 +297,9 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+        self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         self.self_attn = Llama4Attention(
             config=config,
@@ -316,21 +342,58 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-
-
-
-
-
-
+            # Self Attention
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+            # Gather
+            if get_tensor_model_parallel_world_size() > 1:
+                # all gather and all reduce
+                if self.dp_size != 1:
+                    if self.attn_tp_rank == 0:
+                        hidden_states += residual
+                    hidden_states, local_hidden_states = (
+                        forward_batch.gathered_buffer,
+                        hidden_states,
+                    )
+                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                    dp_scatter(residual, hidden_states, forward_batch)
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.feed_forward(hidden_states)
+
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
 
 
@@ -350,13 +413,14 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Llama4DecoderLayer(
                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
-            prefix="
+            prefix=add_prefix("layers", prefix),
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -385,7 +449,8 @@ class Llama4Model(nn.Module):
                 forward_batch,
                 residual,
             )
-
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
@@ -394,7 +459,6 @@
 
 
 class Llama4ForCausalLM(LlamaForCausalLM):
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -408,6 +472,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     ):
         super().__init__(config, quant_config, prefix)
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def _init_model(
         self,
         config: Llama4TextConfig,
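The `_mul_attn_scale` helper introduced above wraps the NoPE temperature-tuning scale in a `torch.compile` region so the scale computation and the multiply can fuse. A standalone sketch of the same scale function, using illustrative values for `floor_scale` and `attn_scale` (the real values come from the model config, not from this note):

import torch

def nope_attn_scale(positions: torch.Tensor,
                    floor_scale: float = 8192.0,  # illustrative, not read from a real config
                    attn_scale: float = 0.1) -> torch.Tensor:
    # Same formula as _get_attn_scale in the diff: stays near 1.0 for short
    # contexts and grows logarithmically at very long positions.
    floor = torch.floor((positions + 1.0) / floor_scale)
    scale = torch.log(floor + 1.0) * attn_scale + 1.0
    return scale.unsqueeze(-1)

positions = torch.arange(0, 1_000_000, 200_000)
q = torch.randn(positions.numel(), 8)             # toy query tensor
q = (q * nope_attn_scale(positions)).to(q.dtype)  # what _mul_attn_scale applies per NoPE layer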
sglang/srt/models/minicpm.py
CHANGED
sglang/srt/models/minicpm3.py
CHANGED
@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
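Both hunks make the same one-line fix: the layer's `quant_config` is now forwarded into the `RadixAttention` constructor instead of being dropped. A minimal sketch of that constructor-threading pattern, with placeholder classes rather than the real sglang types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantConfig:          # stand-in for sglang's QuantizationConfig
    name: str = "fp8"

class Attention:            # stand-in for RadixAttention
    def __init__(self, layer_id: int, quant_config: Optional[QuantConfig] = None):
        self.quant_config = quant_config

class AttentionBlock:       # stand-in for MiniCPM3Attention / MiniCPM3AttentionMLA
    def __init__(self, layer_id: int, quant_config: Optional[QuantConfig] = None):
        # The post1 change: pass quant_config through instead of omitting it.
        self.attn = Attention(layer_id=layer_id, quant_config=quant_config)

assert AttentionBlock(0, QuantConfig()).attn.quant_config is not None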
sglang/srt/models/mixtral.py
CHANGED
sglang/srt/models/mllama.py
CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config,
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
 
 
 class MllamaVisionModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 
 
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size
 
         self.vision_model = MllamaVisionModel(
-            config.vision_config,
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector =
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )
 
             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
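The new `bitsandbytes_stacked_params_mapping` tells a weight loader which fused parameter a per-shard checkpoint tensor belongs to. A small sketch of how such a mapping is typically consulted (the `remap_checkpoint_name` helper is hypothetical, not part of sglang; the mapping itself is copied from the class attribute above):

# shard_name -> (fused_param_name, shard_index), copied from the diff above
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def remap_checkpoint_name(name: str):
    """Hypothetical helper: map a per-shard name to its fused parameter and shard index."""
    for shard_name, (fused_name, shard_id) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in name:
            return name.replace(shard_name, fused_name), shard_id
    return name, None  # weights such as o_proj / down_proj load unchanged

print(remap_checkpoint_name("layers.0.self_attn.k_proj.weight"))
# -> ('layers.0.self_attn.qkv_proj.weight', 1)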
sglang/srt/models/mllama4.py
CHANGED
@@ -1,13 +1,19 @@
-# TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -16,6 +22,7 @@ from sglang.srt.utils import add_prefix
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
     def __init__(
@@ -28,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
@@ -39,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         self.logits_processor = LogitsProcessor(config.text_config)
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -47,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
-
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
 
     def permute_qk_weight_for_rotary(
         self,
@@ -96,18 +137,27 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         num_experts = self.config.text_config.num_local_experts
 
-        for
-
-
-
-
-
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )
 
-
+        for name, loaded_weight in weights:
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -115,31 +165,54 @@ class Llama4ForConditionalGeneration(nn.Module):
                 break
             else:
                 if ".experts" in name:
-
-
-
-
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
                     ):
-
-
-
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
                             weight_loader(
                                 param,
-                                loaded_weight
+                                loaded_weight,
                                 name,
                                 shard_id=shard_id,
                                 expert_id=expert_id,
                             )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
+                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict: