sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,670 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==========================582====================================================
|
14
|
+
|
15
|
+
from typing import Iterable, List, Optional, Tuple, Union
|
16
|
+
|
17
|
+
import torch
|
18
|
+
|
19
|
+
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
|
20
|
+
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
|
21
|
+
import torch.nn.functional as F
|
22
|
+
from einops import rearrange, repeat
|
23
|
+
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
24
|
+
from torch import nn
|
25
|
+
from transformers import PretrainedConfig, PreTrainedModel
|
26
|
+
from transformers.activations import ACT2FN
|
27
|
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
28
|
+
|
29
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
30
|
+
from sglang.srt.managers.mm_utils import (
|
31
|
+
MultiModalityDataPaddingPatternTokenPairs,
|
32
|
+
general_mm_embed_routine,
|
33
|
+
)
|
34
|
+
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
35
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
36
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
37
|
+
from sglang.srt.models.deepseek_janus_pro import DropPath
|
38
|
+
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
39
|
+
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
40
|
+
from sglang.utils import logger
|
41
|
+
|
42
|
+
|
43
|
+
class FlashAttention(nn.Module):
    """Self-attention over a packed qkv tensor, computed by ``flash_attn_varlen_func``.

    Arguments
    ---------
        softmax_scale: Multiplier applied to the attention logits before the
            softmax. It is forwarded unchanged to the kernel (``None`` included).
        attention_dropout: Stored as ``self.dropout_p``; ``forward`` does not
            pass it to the kernel.
        device: Accepted for signature compatibility; not used.
        dtype: Accepted for signature compatibility; not used.
    """

    def __init__(
        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
    ):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(
        self,
        qkv,
        causal=False,
        max_s=None,
    ):
        """Attend over equal-length sequences packed as ``(B, S, 3, H, D)``.

        Arguments
        ---------
            qkv: float16/bfloat16 CUDA tensor of shape ``(B, S, 3, H, D)``; the
                third axis holds query, key and value. Only this padded 5-D
                layout is accepted (the shape is unpacked into five names).
            causal: Whether the kernel applies a causal mask.
            max_s: Ignored. The maximum sequence length is always ``S``.

        Returns
        -------
            ``(output, None)`` where ``output`` has shape ``(B, S, H, D)``.
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        batch_size, seqlen, _, nheads, head_dim = qkv.shape
        if batch_size == 0 or seqlen == 0:
            # Degenerate input: skip the kernel and return zeros of the
            # shape the caller expects.
            zeros = torch.zeros(
                (batch_size, seqlen, nheads, head_dim),
                dtype=qkv.dtype,
                device=qkv.device,
            )
            return zeros, None

        # Merge the batch and sequence axes into one token axis, then split
        # the packed projections into separate q, k and v tensors.
        tokens = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
        q, k, v = tokens.unbind(1)

        # All sequences share the length S, so the cumulative offsets are
        # 0, S, 2S, ..., B*S.
        offsets = torch.arange(
            0,
            (batch_size + 1) * seqlen,
            step=seqlen,
            dtype=torch.int32,
            device=qkv.device,
        )
        attended = flash_attn_varlen_func(
            q,
            k,
            v,
            offsets,
            offsets,
            seqlen,
            seqlen,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )
        result = rearrange(attended, "(b s) h d -> b s h d", b=batch_size)
        return result, None
|
108
|
+
|
109
|
+
|
110
|
+
class InternAttention(nn.Module):
    """Multi-head self-attention used by the InternViT encoder layers.

    One fused linear layer produces q, k and v. When
    ``config.qk_normalization`` is set, q and k are RMS-normalized over the
    full embedding width before attention is computed by ``FlashAttention``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim**-0.5

        # Submodules are created in a fixed order; keep it stable so parameter
        # initialisation and registration order do not change.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization
        if self.qk_normalization:
            norm_eps = config.layer_norm_eps
            self.q_norm = InternRMSNorm(self.embed_dim, eps=norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=norm_eps)

        self.inner_attn = FlashAttention(softmax_scale=self.scale)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _flash_attn(
        self,
        x,
    ):
        """Project ``x`` to q/k/v, run attention, and apply the output projection."""
        packed = rearrange(
            self.qkv(x), "b s (three h d) -> b s three h d", three=3, h=self.num_heads
        )

        if self.qk_normalization:
            q, k, v = packed.unbind(2)
            # The norms span all heads at once: merge the head axes, normalize,
            # then restore the per-head layout.
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            packed = torch.stack([q, k, v], dim=2)

        attended = self.inner_attn(packed)[0]
        merged = rearrange(attended, "b s h d -> b s (h d)")
        return self.proj_drop(self.proj(merged))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` (rank 3: batch, seq, embed)."""
        return self._flash_attn(hidden_states)
|
157
|
+
|
158
|
+
|
159
|
+
class InternVisionEmbeddings(nn.Module):
    """Patch embeddings plus a class token, with resolution-adaptive position embeddings.

    Learned position embeddings cover a square ``image_size // patch_size``
    grid (plus one class-token slot). At run time they are bicubically
    resampled to whatever patch grid the input produces.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Parameters are created in the original order so random
        # initialisation is unchanged.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra slot for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim)
        )

    def _get_pos_embed(self, pos_embed, H, W):
        """Resample patch position embeddings onto an ``H x W`` grid.

        ``pos_embed`` holds one embedding per cell of the configured square
        grid (class-token slot excluded). Interpolation runs in float32. The
        result has shape ``(1, H * W, dim)`` and ``pos_embed``'s dtype.
        """
        orig_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        grid = pos_embed.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(H, W), mode="bicubic", align_corners=False)
        flat = resized.reshape(1, -1, H * W).permute(0, 2, 1)
        return flat.to(orig_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed images into ``(batch, 1 + grid_h * grid_w, embed_dim)`` tokens."""
        weight_dtype = self.patch_embedding.weight.dtype
        # Conv output layout: (batch, embed_dim, grid_h, grid_w).
        patches = self.patch_embedding(pixel_values)
        batch_size, _, height, width = patches.shape
        patch_tokens = patches.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        cls_pos = self.position_embedding[:, :1, :]
        grid_pos = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        all_pos = torch.cat([cls_pos, grid_pos], dim=1)
        return tokens + all_pos.to(weight_dtype)
|
223
|
+
|
224
|
+
|
225
|
+
class InternRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Statistics are computed in float32. The normalized tensor is cast back to
    the input dtype before it is multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
|
237
|
+
|
238
|
+
|
239
|
+
class InternMLP(nn.Module):
    """Feed-forward block of the vision encoder: ``fc2(act(fc1(x)))``.

    The activation is looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        # Registration order (act, fc1, fc2) matches the original module layout.
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(hidden_states)))
|
252
|
+
|
253
|
+
|
254
|
+
# Maps the ``config.norm_type`` string to the normalization layer class that
# the vision encoder layers instantiate.
NORM2FN = dict(
    rms_norm=InternRMSNorm,
    layer_norm=nn.LayerNorm,
)
|
258
|
+
|
259
|
+
|
260
|
+
class InternVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block of the InternViT encoder.

    Each sub-layer (attention, then MLP) reads a normalized copy of the
    residual stream. Its output is scaled per channel by a learned factor
    (``ls1`` / ``ls2``), passed through ``DropPath`` when
    ``drop_path_rate > 0`` (identity otherwise), and added back to the
    residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        drop_path_rate: float,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type
        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Per-channel layer-scale factors, initialised to config.initializer_factor.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )
        self.drop_path2 = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            hidden_states (`torch.Tensor`): input of shape `(batch, seq_len, embed_dim)`.

        Returns:
            `torch.Tensor` with the same shape as `hidden_states`.
        """
        # Norm outputs are cast back to the residual dtype before each sub-layer.
        hidden_states = hidden_states + self.drop_path1(
            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
        )

        hidden_states = hidden_states + self.drop_path2(
            self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
        )

        return hidden_states
|
307
|
+
|
308
|
+
|
309
|
+
class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is an
    [`InternVisionEncoderLayer`].

    Args:
        config (`PretrainedConfig`):
            The vision configuration. Besides the fields read by the layers, it must provide
            `drop_path_rate`, `num_hidden_layers`, `output_hidden_states` and `use_return_dict`.
        quant_config (`QuantizationConfig`, *optional*):
            Forwarded unchanged to every encoder layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Stochastic depth decay rule: per-layer drop-path rates increase
        # linearly from 0 to config.drop_path_rate.
        dpr = [
            x.item()
            for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(
            [
                InternVisionEncoderLayer(config, dpr[idx], quant_config)
                for idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether to also return the hidden states. When enabled, `hidden_states` holds
                `num_hidden_layers + 1` tensors: the input embeddings followed by each layer's output.
                Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
                Defaults to `config.use_return_dict`.

        Returns:
            `BaseModelOutput`, or the tuple `(last_hidden_state[, hidden_states])`.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            hidden_states = encoder_layer(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )
|
382
|
+
|
383
|
+
|
384
|
+
class InternVisionModel(PreTrainedModel):
    """InternViT vision tower: patch/position embeddings followed by the transformer encoder."""

    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True
    config_class = PretrainedConfig
    _no_split_modules = ["InternVisionEncoderLayer"]

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(
            config,
        )
        self.encoder = InternVisionEncoder(config, quant_config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Resample the learned position embeddings for a new input resolution.

        The patch slots (everything after the class-token slot) are viewed as
        an ``old_size // patch_size`` square grid and bicubically interpolated
        in float32 to ``new_size // patch_size`` per side. They are then
        re-joined with the untouched class-token embedding. The result replaces
        ``embeddings.position_embedding``, and ``embeddings.image_size`` is set
        to ``new_size``.
        """
        pos_emb = self.embeddings.position_embedding
        _, _, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        old_grid = old_size // patch_size
        pos_emb = (
            pos_emb[:, 1:, :]
            .reshape(1, old_grid, old_grid, -1)
            .permute(0, 3, 1, 2)
        )
        pos_emb = F.interpolate(
            pos_emb.float(),
            size=new_size // patch_size,
            mode="bicubic",
            align_corners=False,
        )
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        logger.info(
            "Resized position embeddings from {} to {}".format(old_size, new_size)
        )

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images with the vision tower.

        Args:
            pixel_values: 4-D image batch for the patch embedding. It is cast to
                the model's device and dtype first. Ignored when
                ``pixel_embeds`` is given.
            output_hidden_states: Whether to also return every layer's hidden
                states. Defaults to ``config.output_hidden_states``.
            return_dict: Return a ``BaseModelOutputWithPooling`` instead of a
                tuple. Defaults to ``config.use_return_dict``.
            pixel_embeds: Precomputed embeddings that bypass ``self.embeddings``.

        Returns:
            ``last_hidden_state`` and ``pooler_output`` (the first token of the
            last hidden state). If requested, also the per-layer
            ``hidden_states``. These come as a ``BaseModelOutputWithPooling``
            or as a tuple.

        Raises:
            ValueError: If neither input is given, or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            # Cast only on this path: pixel_values may legitimately be None
            # when pixel_embeds is supplied.
            pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
            if len(pixel_values.shape) != 4:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
            hidden_states = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The encoder returns a plain tuple when return_dict is False, so use
        # positional access. Element 0 is the last hidden state in both forms.
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
473
|
+
|
474
|
+
|
475
|
+
class InternVLChatModel(nn.Module):
    """InternVL chat model: InternViT vision tower, MLP projector and a language model.

    Image tiles are encoded by ``vision_model``, spatially compressed by
    :meth:`pixel_shuffle` and projected into the LLM embedding space by
    ``mlp1``. ``general_mm_embed_routine`` then merges them with the text
    embeddings. Supported language backbones are ``Qwen2ForCausalLM`` and
    ``InternLM2ForCausalLM``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        use_flash_attn=True,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Tokens per image tile after pixel_shuffle shrinks each side of the
        # patch grid by downsample_ratio.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = (
            "flash_attention_2" if use_flash_attn else "eager"
        )

        logger.info(f"num_image_token: {self.num_image_token}")
        logger.info(f"ps_version: {self.ps_version}")

        self.vision_model = InternVisionModel(config.vision_config)
        llm_arch = config.llm_config.architectures[0]
        if llm_arch == "Qwen2ForCausalLM":
            self.language_model = Qwen2ForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        elif llm_arch == "InternLM2ForCausalLM":
            self.language_model = InternLM2ForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        else:
            raise NotImplementedError(f"{llm_arch} is not implemented.")

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size
        # pixel_shuffle folds int(1 / downsample_ratio) ** 2 neighbouring patches
        # into the channel axis, widening each token by that factor.
        shuffled_dim = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(shuffled_dim),
            nn.Linear(shuffled_dim, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth.

        Args:
            x: Tensor of shape ``(N, W, H, C)``.
            scale_factor: Factor ``s`` applied to each spatial side. Channels
                grow by ``1 / s**2``.

        Returns:
            ``(N, H * s, W * s, C / s**2)`` when ``ps_version == "v1"`` (the
            spatial axes are left transposed). Otherwise
            ``(N, W * s, H * s, C / s**2)``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version == "v1":
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                "In ps_version 'v1', the height and width have not been swapped back, "
                "which results in a transposed image."
            )
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode image tiles and project them into the LLM embedding space.

        Uses the final encoder output when ``select_layer == -1``, otherwise
        the hidden state at index ``select_layer``. The class token is dropped
        and the patch tokens are laid out on a square grid. They are then
        pixel-shuffled by ``downsample_ratio`` and passed through ``mlp1``.
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
            ).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
            ).hidden_states[self.select_layer]
        # Drop the class token; only patch tokens are projected.
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def get_image_feature(self, items: List[MultimodalDataItem]):
        """
        Projects the vision model's features into language model space.

        The ``pixel_values`` of all items are concatenated along dim 0 and
        encoded in one batch.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        pixel_values = torch.cat([item.pixel_values for item in items])
        image_features = self.extract_feature(pixel_values)
        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the language model with image features merged into the token embeddings.

        ``input_embeds`` is accepted for interface compatibility but not used.
        """
        hs = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )

        return hs

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad the image spans delimited by the image start/end tokens in ``input_ids``."""
        # Get all special token IDs
        im_start_id: int = mm_inputs.im_start_id
        im_end_id: int = mm_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return helper.pad_input_tokens(input_ids, mm_inputs)

    def _load_internlm2_wqkv(self, param, loaded_weight):
        """Split InternLM2's fused ``wqkv`` weight into q/k/v shards and load them."""
        # Head counts and hidden size are language-model fields, so read them
        # from llm_config (as __init__ does for hidden_size).
        llm_config = self.config.llm_config
        kv_groups = llm_config.num_attention_heads // llm_config.num_key_value_heads
        head_dim = llm_config.hidden_size // llm_config.num_attention_heads
        # Rows are grouped per KV head as [kv_groups query heads, 1 key, 1 value].
        loaded_weight = loaded_weight.view(
            -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
        )
        wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
        wq = wq.reshape(-1, wq.shape[-1])
        wk = wk.reshape(-1, wk.shape[-1])
        wv = wv.reshape(-1, wv.shape[-1])
        weight_loader = param.weight_loader
        weight_loader(param, wq, "q")
        weight_loader(param, wk, "k")
        weight_loader(param, wv, "v")

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, fusing per-projection weights into stacked params."""
        architectures = self.config.llm_config.architectures
        if "InternLM2ForCausalLM" in architectures:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("gate_up_proj", "w1", 0),
                ("gate_up_proj", "w3", 1),
            ]
        elif "Qwen2ForCausalLM" in architectures:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        else:
            # __init__ rejects other backbones; nothing to fuse.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    self._load_internlm2_wqkv(param, loaded_weight)
                else:
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
|
668
|
+
|
669
|
+
|
670
|
+
# Module-level alias for the model class this file exports.
# NOTE(review): presumably read by sglang's model registry to discover this
# architecture — TODO confirm against the registry loader.
EntryClass = InternVLChatModel
|
sglang/srt/models/llama.py
CHANGED