sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +23 -3
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +5 -16
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +218 -79
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/topk.py +30 -3
- sglang/srt/layers/quantization/__init__.py +134 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +12 -0
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +25 -19
- sglang/srt/managers/tokenizer_manager.py +0 -1
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -8
- sglang/srt/model_executor/model_runner.py +9 -6
- sglang/srt/model_loader/loader.py +11 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +151 -26
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +6 -0
- sglang/srt/openai_api/adapter.py +88 -87
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/server_args.py +21 -11
- sglang/srt/speculative/eagle_worker.py +1 -1
- sglang/srt/utils.py +33 -0
- sglang/test/runners.py +27 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,563 @@
# Adapted from
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py

from functools import partial
from typing import Iterable, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask

from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix


class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


class CLIPTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class CLIPMLP(nn.Module):

    def __init__(
        self,
        config,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class CLIPEncoderLayer(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layer_norm1 = norm_layer(config.hidden_size)
        self.layer_norm2 = norm_layer(config.hidden_size)
        if attn_implementation == "sdpa":
            use_context_forward = False
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            softmax_in_single_precision = False
            use_context_forward = True
        elif attn_implementation == "eager":
            softmax_in_single_precision = True
            use_context_forward = False
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            use_qkv_parallel=True,
            use_context_forward=use_context_forward,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = CLIPMLP(
            config,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ) -> torch.Tensor:

        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask
        hidden_states = self.self_attn(
            hidden_states,
            attention_mask=attn_mask,
            # causal_attention_mask=causal_attention_mask,
        )

        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = config.num_hidden_layers
        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    config=config,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        return_all_hidden_states: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states, attention_mask, causal_attention_mask
            )
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class CLIPTextTransformer(nn.Module):
    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = self.embeddings(input_ids, position_ids)
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_ids.shape, hidden_states.dtype, device=hidden_states.device
        )
        encoder_outputs = self.encoder(
            hidden_states, attention_mask, causal_attention_mask
        )
        last_hidden_state = self.final_layer_norm(encoder_outputs)
        return last_hidden_state


class CLIPTextModel(nn.Module):
    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.text_model = CLIPTextTransformer(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("text_model", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ):
        return self.text_model(input_ids, position_ids)


class CLIPVisionTransformer(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        hidden_states = self.embeddings(pixel_values.to(self.device))
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = False

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class CLIPVisionModel(nn.Module):
    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config, quant_config, prefix=add_prefix("vision_model", prefix)
        )

    def forward(self, pixel_values: torch.Tensor):
        return self.vision_model(pixel_values)


class CLIPModel(nn.Module):
    def __init__(
        self,
        config: CLIPConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        if not isinstance(config.text_config, CLIPTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size
        self.visual_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim, self.projection_dim, bias=False
        )
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        text_model = CLIPTextModel(
            text_config, quant_config, prefix=add_prefix("text_model", prefix)
        )
        vision_model = CLIPVisionModel(
            vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
        )
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        monkey_patch_weight_loader()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = True,
    ):
        assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
        image_inputs = None
        if forward_batch.mm_inputs is not None:
            image_inputs = forward_batch.mm_inputs

        if image_inputs is not None and image_inputs[0] is not None:
            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
            pooled_output = vision_outputs[:, 0, :]
            image_embeds = self.visual_projection(pooled_output)
            image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
            return EmbeddingPoolerOutput(embeddings=image_embeds)

        else:
            text_outputs = self.text_model(input_ids, position_ids=positions)
            pooled_output = self.pooler(text_outputs[0], forward_batch)
            return EmbeddingPoolerOutput(
                embeddings=self.text_projection(pooled_output.embeddings)
            )

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        # Clip embeddings models handle text/image separately, so we don't need to pad input ids
        return input_ids

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "position_ids" in name:
                continue
            if "out_proj" in name:
                name = name.replace("out_proj", "proj")
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# monkey patch weight loader to remove open_clip file
def monkey_patch_weight_loader():
    import glob
    import os

    from sglang.srt.model_loader.loader import DefaultModelLoader
    from sglang.srt.model_loader.weight_utils import (
        download_weights_from_hf,
        filter_files_not_needed_for_inference,
    )

    def prepare_weights(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Tuple[str, List[str], bool]:
        model_name_or_path = (
            self._maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        use_safetensors = False
        allow_patterns = ["*.bin"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))

        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        # remove open_clip file
        hf_weights_files = [
            file for file in hf_weights_files if "open_clip" not in file
        ]

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)


EntryClass = CLIPModel
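For context, EntryClass = CLIPModel exposes this file as an embedding-only model: image requests pool the CLS token of the vision tower, project it with visual_projection, and L2-normalize the result, while text requests go through the last-token pooler and text_projection. Below is a minimal standalone sketch of the image-embedding math only, with illustrative shapes (50 tokens, hidden size 768, projection dim 512) standing in for the real config.

import torch
import torch.nn.functional as F

# Illustrative shapes: 2 images, 50 tokens (CLS + 49 patches), hidden 768, projection 512.
last_hidden_state = torch.randn(2, 50, 768)           # stands in for CLIPVisionTransformer output
visual_projection = torch.nn.Linear(768, 512, bias=False)

pooled_output = last_hidden_state[:, 0, :]             # CLS token, as in CLIPModel.forward above
image_embeds = visual_projection(pooled_output)        # project into the shared CLIP space
image_embeds = F.normalize(image_embeds, p=2, dim=1)   # unit-norm embeddings
print(image_embeds.shape)                              # torch.Size([2, 512])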
@@ -252,7 +252,7 @@ def resample_patch_embed(
     try:
         from torch import vmap
     except ImportError:
-        from
+        from torch.func import vmap

     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
@@ -1084,7 +1084,7 @@ def create_siglip_vit(
     )

     if ckpt_path:
-        state_dict = torch.load(ckpt_path, map_location="cpu")
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

         incompatible_keys = model.load_state_dict(state_dict, strict=False)
         print(