sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py (new file, +420 -0)
@@ -0,0 +1,420 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
+
+ import logging
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Llama4TextConfig
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
+ from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+
+ logger = logging.getLogger(__name__)
+
+
+ class Llama4MoE(nn.Module):
+
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
+     @staticmethod
+     def custom_routing_function(
+         hidden_states: torch.Tensor,
+         gating_output: torch.Tensor,
+         topk: int,
+         renormalize: bool,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
+             hidden_states.dtype
+         )
+         return (
+             router_scores_aK.view(-1).reshape(router_scores_aK.shape),
+             router_indices_aK.to(torch.int32),
+         )
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.top_k = config.num_experts_per_tok
+
+         intermediate_size_moe = config.intermediate_size
+         self.router = ReplicatedLinear(
+             config.hidden_size,
+             config.num_local_experts,
+             bias=False,
+             quant_config=None,
+             prefix=add_prefix("router", prefix),
+         )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_local_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             custom_routing_function=Llama4MoE.custom_routing_function,
+             intermediate_size=intermediate_size_moe,
+             reduce_results=False,
+             renormalize=False,
+             quant_config=quant_config,
+             apply_router_weight_on_input=True,
+             prefix=add_prefix("experts", prefix),
+         )
+
+         self.shared_expert = LlamaMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=intermediate_size_moe,
+             hidden_act="silu",
+             quant_config=quant_config,
+             prefix=add_prefix("shared_expert", prefix),
+             reduce_results=False,  # We need to do scatter before reduce
+         )
+
+     def forward(self, hidden_states):
+         # router_scores: [num_tokens, num_experts]
+         router_logits, _ = self.router(hidden_states)
+         shared_out = self.shared_expert(hidden_states)
+         routed_out = self.experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+         )
+         out_aD = routed_out + shared_out
+
+         if self.tp_size > 1:
+             out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+         return out_aD
+
+
+ class Llama4Attention(nn.Module):
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         bias_o_proj: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = hidden_size
+         self.use_rope = int((layer_id + 1) % 4 != 0)
+         self.use_qk_norm = config.use_qk_norm and self.use_rope
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = config.head_dim
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.attn_temperature_tuning = config.attn_temperature_tuning
+         self.floor_scale = config.floor_scale
+         self.attn_scale = config.attn_scale
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.n_rep = self.num_heads // self.num_kv_heads
+         self.qk_norm = (
+             RMSNorm(
+                 hidden_size=self.head_dim,
+                 eps=config.rms_norm_eps,
+             )
+             if self.use_qk_norm
+             else None
+         )
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size=hidden_size,
+             head_size=self.head_dim,
+             total_num_heads=self.total_num_heads,
+             total_num_kv_heads=self.total_num_kv_heads,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             input_size=self.total_num_heads * self.head_dim,
+             output_size=hidden_size,
+             bias=bias_o_proj,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+         is_neox_style = True
+         is_gguf = quant_config and quant_config.get_name() == "gguf"
+         if is_gguf and config.model_type in ["llama", "llama4"]:
+             is_neox_style = False
+
+         self.rotary_emb = (
+             get_rope(
+                 self.head_dim,
+                 rotary_dim=self.head_dim,
+                 max_position=max_position_embeddings,
+                 base=int(rope_theta),
+                 rope_scaling=rope_scaling if rope_scaling != "default" else None,
+                 is_neox_style=is_neox_style,
+             )
+             if self.use_rope
+             else None
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+             use_irope=self.use_rope,
+         )
+
+     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+         floor = torch.floor((positions + 1.0) / self.floor_scale)
+         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+
+         return attn_scale.unsqueeze(-1)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+         if self.rotary_emb is not None:
+             q, k = self.rotary_emb(positions, q, k)
+
+         if self.qk_norm is not None:
+             # TODO: support float
+             q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
+             k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
+             q = self.qk_norm(q).to(q.dtype)
+             k = self.qk_norm(k).to(k.dtype)
+             q = q.reshape(-1, self.q_size)
+             k = k.reshape(-1, self.kv_size)
+
+         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
+         # the inference-time temperature tuning function is customized to not affect short context
+         # while working at very long context
+         # https://arxiv.org/abs/2501.19399
+         if self.attn_temperature_tuning and not self.use_rope:
+             attn_scale = self._get_attn_scale(positions)
+             q = (q * attn_scale).to(q.dtype)
+
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Llama4DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = config.hidden_size
+         rope_theta = config.rope_theta
+         rope_scaling = config.rope_scaling
+         max_position_embeddings = config.max_position_embeddings
+
+         self.self_attn = Llama4Attention(
+             config=config,
+             layer_id=layer_id,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             bias=False,
+             bias_o_proj=False,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+         if is_moe_layer:
+             self.feed_forward = Llama4MoE(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         else:
+             self.feed_forward = LlamaMLP(
+                 hidden_size=self.hidden_size,
+                 intermediate_size=config.intermediate_size_mlp,
+                 hidden_act="silu",
+                 quant_config=quant_config,
+                 prefix=add_prefix("feed_forward", prefix),
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.feed_forward(hidden_states)
+         return hidden_states, residual
+
+
+ class Llama4Model(nn.Module):
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("embed_tokens", prefix),
+         )
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Llama4DecoderLayer(
+                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+             ),
+             prefix="model.layers",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.layers_to_capture = []
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         aux_hidden_states = []
+         for i in range(len(self.layers)):
+             if i in self.layers_to_capture:
+                 aux_hidden_states.append(hidden_states + residual)
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+
+         if len(aux_hidden_states) == 0:
+             return hidden_states
+
+         return hidden_states, aux_hidden_states
+
+
+ class Llama4ForCausalLM(LlamaForCausalLM):
+
+     packed_modules_mapping = {
+         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+         "gate_up_proj": ["gate_proj", "up_proj"],
+     }
+
+     def __init__(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__(config, quant_config, prefix)
+
+     def _init_model(
+         self,
+         config: Llama4TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         return Llama4Model(config, quant_config=quant_config, prefix=prefix)
+
+
+ EntryClass = [Llama4ForCausalLM]
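
Note: the routing and attention-scaling math introduced by this file is small enough to check in isolation. Below is a minimal, self-contained sketch (not taken from the package) of the top-k-then-sigmoid routing in Llama4MoE.custom_routing_function and the NoPE query scaling in Llama4Attention._get_attn_scale; the tensor sizes and the floor_scale/attn_scale values are illustrative assumptions.

    import torch

    def llama4_route(gating_output: torch.Tensor, topk: int):
        # Top-k over the router logits, then sigmoid on the selected logits
        # (mirrors Llama4MoE.custom_routing_function in the new file).
        scores, indices = torch.topk(gating_output, topk, dim=-1)
        scores = torch.sigmoid(scores.float()).to(gating_output.dtype)
        return scores, indices.to(torch.int32)

    def nope_attn_scale(positions: torch.Tensor, floor_scale: float, attn_scale: float):
        # Position-dependent query scale applied on NoPE (non-RoPE) layers
        # (mirrors Llama4Attention._get_attn_scale in the new file).
        floor = torch.floor((positions + 1.0) / floor_scale)
        return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

    # Example with made-up sizes: 4 tokens routed over 16 experts, top-1.
    logits = torch.randn(4, 16)
    weights, experts = llama4_route(logits, topk=1)
    scale = nope_attn_scale(torch.arange(4), floor_scale=8192.0, attn_scale=0.1)
    print(weights.shape, experts.shape, scale.shape)  # all torch.Size([4, 1])
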
sglang/srt/models/llava.py (+31 -19)
@@ -31,7 +31,7 @@ from transformers import (
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.managers.schedule_batch import MultimodalInputs
+ from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
  from sglang.srt.mm_utils import (
      get_anyres_image_grid_shape,
      unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM
  from sglang.srt.models.mistral import MistralForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM
- from sglang.srt.utils import add_prefix
+ from sglang.srt.utils import add_prefix, flatten_nested_list
 
 
  class LlavaBaseForCausalLM(nn.Module):
      def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+         image_sizes = flatten_nested_list(
+             [item.image_sizes for item in image_inputs.mm_items]
+         )
+
+         pad_values = [item.pad_value for item in image_inputs.mm_items]
 
          # hardcode for spatial_unpad + anyres
-         if image_inputs.modalities is not None and (
-             "multi-images" in image_inputs.modalities
-             or "video" in image_inputs.modalities
+         if any(
+             item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+             for item in image_inputs.mm_items
          ):
              image_aspect_ratio = "pad"
          else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                      math.ceil(self.image_size / self.patch_size / 2) ** 2
                  )
              else:
-                 new_image_feature_len = self.image_feature_len  # multiimage
+                 new_image_feature_len = self.image_feature_len  # multi-image
 
              height = width = self.num_patches_per_side
              if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
              # old_len + pad_len - 1, because we need to remove image_token_id
              input_ids = (
                  input_ids[:offset]
-                 + [pad_values[image_idx]] * new_image_feature_len
+                 + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                  + input_ids[offset + 1 :]
              )
              offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
          modalities_list = []
          max_image_offset = []
          for im in image_inputs:
-             if im and im.modalities is not None:
-                 modalities_list.extend(im.modalities)
+             if im:
+                 modalities_list.extend([item.modality for item in im.mm_items])
              if im and im.image_offsets:
                  max_image_offset.append(
                      np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
 
          if need_vision.any():
              bs = forward_batch.batch_size
-             pixel_values = [
-                 image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-             ]
+             pixel_values = flatten_nested_list(
+                 [
+                     [item.pixel_values for item in image_inputs[i].mm_items]
+                     for i in range(bs)
+                     if need_vision[i]
+                 ]
+             )
              image_sizes = [
-                 image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                 flatten_nested_list(
+                     [item.image_sizes for item in image_inputs[i].mm_items]
+                 )
+                 for i in range(bs)
+                 if need_vision[i]
              ]
 
              ########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
              new_image_features = []
              height = width = self.num_patches_per_side
              for image_idx, image_feature in enumerate(image_features):
-                 if modalities_list[image_idx] == "image":
+                 if modalities_list[image_idx] == Modality.IMAGE:
                      image_aspect_ratio = (
                          self.config.image_aspect_ratio
                      )  # single image
                  elif (
-                     modalities_list[image_idx] == "multi-images"
-                     or modalities_list[image_idx] == "video"
+                     modalities_list[image_idx] == Modality.MULTI_IMAGES
+                     or modalities_list[image_idx] == Modality.VIDEO
                  ):
                      image_aspect_ratio = "pad"  # multi image
                      # image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                  if (
                      image_feature.shape[0] > 1
                      and "anyres" in image_aspect_ratio
-                     and modalities_list[image_idx] == "image"
+                     and modalities_list[image_idx] == Modality.IMAGE
                  ):
                      base_image_feature = image_feature[0]
                      image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                      )
                      image_feature = image_feature.unsqueeze(0)
                  else:
-                     if modalities_list[image_idx] == "video":  # video
+                     if modalities_list[image_idx] == Modality.VIDEO:  # video
                          # 2x2 pooling
                          num_of_frames = image_feature.shape[0]
                          image_feature = image_feature.view(
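
Note: the llava.py hunks above replace string modality tags ("image", "multi-images", "video") with members of the Modality enum imported from sglang.srt.managers.schedule_batch. A small illustrative sketch of the new check follows; the enum shape is assumed from the member names used in this diff, not copied from the package.

    from enum import Enum, auto

    class Modality(Enum):
        # Assumed shape; only the members referenced in this diff are listed.
        IMAGE = auto()
        MULTI_IMAGES = auto()
        VIDEO = auto()

    def needs_pad_aspect_ratio(modalities):
        # Mirrors the rewritten check in LlavaBaseForCausalLM.pad_input_ids:
        # multi-image and video inputs fall back to the "pad" aspect ratio.
        return any(m in (Modality.MULTI_IMAGES, Modality.VIDEO) for m in modalities)

    print(needs_pad_aspect_ratio([Modality.IMAGE]))                  # False
    print(needs_pad_aspect_ratio([Modality.IMAGE, Modality.VIDEO]))  # True
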
sglang/srt/models/llavavid.py (+16 -7)
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.managers.schedule_batch import MultimodalInputs
+ from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
          )
 
      def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-         pad_values = image_inputs.pad_values
+         pad_values = [item.pad_value for item in image_inputs.mm_items]
          new_image_feature_len = self.image_feature_len
 
          pad_ids = pad_values * (
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
          need_vision = start_positions <= np.array(max_image_offset)
 
          if need_vision.any():
-             pixel_values = [
-                 image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-             ]
+             pixel_values = flatten_nested_list(
+                 [
+                     [item.pixel_values for item in image_inputs[i].mm_items]
+                     for i in range(bs)
+                     if need_vision[i]
+                 ]
+             )
              image_offsets = [
-                 image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                 flatten_nested_list(
+                     [item.image_offsets for item in image_inputs[i].mm_items]
+                 )
+                 for i in range(bs)
+                 if need_vision[i]
              ]
 
              ########## Encode Image ########
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
              "model.mm_projector.2": "multi_modal_projector.linear_2",
              "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
              "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+             "model.vision_tower.vision_tower": "vision_tower",
+             # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
              "model.image_newline": "language_model.model.image_newline",
          }
          params_dict = dict(self.named_parameters())
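
Note: both llava.py and llavavid.py now read per-item fields from MultimodalInputs.mm_items and flatten them with flatten_nested_list, instead of reading batch-level attributes such as pixel_values and pad_values. A rough sketch of that gather-and-flatten pattern follows; MMItem and flatten_one_level are hypothetical stand-ins for illustration, not the package's actual classes or helper.

    from dataclasses import dataclass, field
    from typing import Any, List

    @dataclass
    class MMItem:
        # Hypothetical stand-in for the per-modality items on MultimodalInputs.mm_items.
        pad_value: int = 0
        image_sizes: List[Any] = field(default_factory=list)

    def flatten_one_level(nested):
        # Illustrative stand-in for flatten_nested_list; the real helper may
        # handle arbitrary nesting depth.
        return [x for sub in nested for x in sub]

    items = [
        MMItem(pad_value=7, image_sizes=[(336, 336)]),
        MMItem(pad_value=9, image_sizes=[(672, 336), (336, 672)]),
    ]
    pad_values = [item.pad_value for item in items]
    image_sizes = flatten_one_level([item.image_sizes for item in items])
    print(pad_values)   # [7, 9]
    print(image_sizes)  # [(336, 336), (672, 336), (336, 672)]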