sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py
ADDED
@@ -0,0 +1,420 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Llama4TextConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
+from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class Llama4MoE(nn.Module):
+
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    @staticmethod
+    def custom_routing_function(
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
+            hidden_states.dtype
+        )
+        return (
+            router_scores_aK.view(-1).reshape(router_scores_aK.shape),
+            router_indices_aK.to(torch.int32),
+        )
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.top_k = config.num_experts_per_tok
+
+        intermediate_size_moe = config.intermediate_size
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("router", prefix),
+        )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+            intermediate_size=intermediate_size_moe,
+            reduce_results=False,
+            renormalize=False,
+            quant_config=quant_config,
+            apply_router_weight_on_input=True,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.shared_expert = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=intermediate_size_moe,
+            hidden_act="silu",
+            quant_config=quant_config,
+            prefix=add_prefix("shared_expert", prefix),
+            reduce_results=False,  # We need to do scatter before reduce
+        )
+
+    def forward(self, hidden_states):
+        # router_scores: [num_tokens, num_experts]
+        router_logits, _ = self.router(hidden_states)
+        shared_out = self.shared_expert(hidden_states)
+        routed_out = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+
+class Llama4Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        layer_id: int,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.use_rope = int((layer_id + 1) % 4 != 0)
+        self.use_qk_norm = config.use_qk_norm and self.use_rope
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.attn_temperature_tuning = config.attn_temperature_tuning
+        self.floor_scale = config.floor_scale
+        self.attn_scale = config.attn_scale
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.n_rep = self.num_heads // self.num_kv_heads
+        self.qk_norm = (
+            RMSNorm(
+                hidden_size=self.head_dim,
+                eps=config.rms_norm_eps,
+            )
+            if self.use_qk_norm
+            else None
+        )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias_o_proj,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type in ["llama", "llama4"]:
+            is_neox_style = False
+
+        self.rotary_emb = (
+            get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=int(rope_theta),
+                rope_scaling=rope_scaling if rope_scaling != "default" else None,
+                is_neox_style=is_neox_style,
+            )
+            if self.use_rope
+            else None
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+            use_irope=self.use_rope,
+        )
+
+    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+        floor = torch.floor((positions + 1.0) / self.floor_scale)
+        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+
+        return attn_scale.unsqueeze(-1)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb is not None:
+            q, k = self.rotary_emb(positions, q, k)
+
+        if self.qk_norm is not None:
+            # TODO: support float
+            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
+            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
+            q = self.qk_norm(q).to(q.dtype)
+            k = self.qk_norm(k).to(k.dtype)
+            q = q.reshape(-1, self.q_size)
+            k = k.reshape(-1, self.kv_size)
+
+        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
+        # the inference-time temperature tuning function is customized to not affect short context
+        # while working at very long context
+        # https://arxiv.org/abs/2501.19399
+        if self.attn_temperature_tuning and not self.use_rope:
+            attn_scale = self._get_attn_scale(positions)
+            q = (q * attn_scale).to(q.dtype)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Llama4DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = config.hidden_size
+        rope_theta = config.rope_theta
+        rope_scaling = config.rope_scaling
+        max_position_embeddings = config.max_position_embeddings
+
+        self.self_attn = Llama4Attention(
+            config=config,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=False,
+            bias_o_proj=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        if is_moe_layer:
+            self.feed_forward = Llama4MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("feed_forward", prefix),
+            )
+        else:
+            self.feed_forward = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=config.intermediate_size_mlp,
+                hidden_act="silu",
+                quant_config=quant_config,
+                prefix=add_prefix("feed_forward", prefix),
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class Llama4Model(nn.Module):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Llama4DecoderLayer(
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            ),
+            prefix="model.layers",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        aux_hidden_states = []
+        for i in range(len(self.layers)):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class Llama4ForCausalLM(LlamaForCausalLM):
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config, quant_config, prefix)
+
+    def _init_model(
+        self,
+        config: Llama4TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return Llama4Model(config, quant_config=quant_config, prefix=prefix)
+
+
+EntryClass = [Llama4ForCausalLM]
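Note on the routing and attention scaling added above: Llama4MoE.custom_routing_function takes the top-k experts on the raw router logits and then applies a sigmoid to the selected logits (rather than a softmax over all experts), and Llama4Attention scales the query in NoPE layers by a slowly growing, position-dependent factor. The standalone sketch below reproduces just that math for readability; it is illustrative only, not part of the diff, and the topk, floor_scale, and attn_scale values are placeholders rather than values taken from any released config.

import torch

def llama4_route(gating_output: torch.Tensor, topk: int):
    # Top-k over the raw router logits, then sigmoid on the selected logits
    # (mirrors Llama4MoE.custom_routing_function above).
    scores, indices = torch.topk(gating_output, topk, dim=-1)
    weights = torch.sigmoid(scores.float())
    return weights, indices.to(torch.int32)

def nope_attn_scale(positions: torch.Tensor, floor_scale: float, attn_scale: float):
    # Mirrors Llama4Attention._get_attn_scale: a logarithmic, position-dependent
    # factor that stays close to 1.0 for short contexts.
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

router_logits = torch.randn(4, 16)  # [num_tokens, num_experts], example shapes
weights, experts = llama4_route(router_logits, topk=1)
scale = nope_attn_scale(torch.arange(4, dtype=torch.float32), 8192.0, 0.1)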
sglang/srt/models/llava.py
CHANGED
@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list
 
 
 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
 
         # hardcode for spatial_unpad + anyres
-        if
-
-
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  #
+            new_image_feature_len = self.image_feature_len  # multi-image
 
         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + [pad_values[image_idx]] * new_image_feature_len
+                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
        for im in image_inputs:
-            if im
-                modalities_list.extend(im.
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
 
         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values =
-
-
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_sizes = [
-
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
             new_image_features = []
             height = width = self.num_patches_per_side
             for image_idx, image_feature in enumerate(image_features):
-                if modalities_list[image_idx] ==
+                if modalities_list[image_idx] == Modality.IMAGE:
                     image_aspect_ratio = (
                         self.config.image_aspect_ratio
                     )  # single image
                 elif (
-                    modalities_list[image_idx] ==
-                    or modalities_list[image_idx] ==
+                    modalities_list[image_idx] == Modality.MULTI_IMAGES
+                    or modalities_list[image_idx] == Modality.VIDEO
                 ):
                     image_aspect_ratio = "pad"  # multi image
                     # image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 if (
                     image_feature.shape[0] > 1
                     and "anyres" in image_aspect_ratio
-                    and modalities_list[image_idx] ==
+                    and modalities_list[image_idx] == Modality.IMAGE
                 ):
                     base_image_feature = image_feature[0]
                     image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     )
                     image_feature = image_feature.unsqueeze(0)
                 else:
-                    if modalities_list[image_idx] ==
+                    if modalities_list[image_idx] == Modality.VIDEO:  # video
                         # 2x2 pooling
                         num_of_frames = image_feature.shape[0]
                         image_feature = image_feature.view(
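Both llava.py hunks above (and the llavavid.py changes that follow) lean on flatten_nested_list, newly imported alongside add_prefix; its implementation is not shown in this diff. As a rough stand-in for reading the call sites, a recursive flatten with the behavior they assume could look like the sketch below (an illustration, not the actual sglang helper):

from typing import Any, List

def flatten_nested_list(nested: List[Any]) -> List[Any]:
    # Flatten arbitrarily nested lists into a single flat list, preserving order.
    flat: List[Any] = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten_nested_list(item))
        else:
            flat.append(item)
    return flat

# One level of nesting, as produced by the per-item list comprehensions above:
assert flatten_nested_list([[1, 2], [3], 4]) == [1, 2, 3, 4]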
sglang/srt/models/llavavid.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
         new_image_feature_len = self.image_feature_len
 
         pad_ids = pad_values * (
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)
 
         if need_vision.any():
-            pixel_values =
-
-
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_offsets = [
-
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
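The last llavavid.py hunk annotates the checkpoint-name map used when loading weights. For readers unfamiliar with the pattern, the sketch below shows how such a map is typically applied to rename projector and vision-tower parameters during loading; it is a generic, hypothetical illustration, not the actual LlavaVidForCausalLM.load_weights logic.

# Hypothetical, minimal renaming pass over checkpoint parameter names.
name_map = {
    "model.mm_projector.0": "multi_modal_projector.linear_1",
    "model.vision_tower.vision_tower": "vision_tower",
    "model.image_newline": "language_model.model.image_newline",
}

def remap_weight_name(name: str) -> str:
    # Replace the first matching checkpoint prefix with its target module name.
    for old, new in name_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

assert (
    remap_weight_name("model.mm_projector.0.weight")
    == "multi_modal_projector.linear_1.weight"
)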