sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/clip.py
@@ -0,0 +1,563 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py
+
+from functools import partial
+from typing import Iterable, List, Optional, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
+
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
+
+
+class CLIPVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        assert self.image_size % self.patch_size == 0
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+class CLIPTextEmbeddings(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(
+            config.max_position_embeddings, embed_dim
+        )
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = (
+            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        )
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class CLIPMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+class CLIPEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layer_norm1 = norm_layer(config.hidden_size)
+        self.layer_norm2 = norm_layer(config.hidden_size)
+        if attn_implementation == "sdpa":
+            use_context_forward = False
+            softmax_in_single_precision = False
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            use_context_forward = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            use_context_forward = False
+        self.self_attn = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            use_context_forward=use_context_forward,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.mlp = CLIPMLP(
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states)
+        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
+        if attention_mask is not None and causal_attention_mask is not None:
+            attn_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attn_mask = causal_attention_mask
+        else:
+            attn_mask = attention_mask
+        hidden_states = self.self_attn(
+            hidden_states,
+            attention_mask=attn_mask,
+            # causal_attention_mask=causal_attention_mask,
+        )
+
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`CLIPEncoderLayer`].
+
+    Args:
+        config: CLIPConfig
+    """
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        num_hidden_layers = config.num_hidden_layers
+        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [
+                CLIPEncoderLayer(
+                    config=config,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        causal_attention_mask: torch.Tensor = None,
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = [inputs_embeds]
+        hidden_states = inputs_embeds
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, causal_attention_mask
+            )
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
+        return hidden_states
+
+
+class CLIPTextTransformer(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = CLIPTextEmbeddings(config)
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        hidden_states = self.embeddings(input_ids, position_ids)
+        causal_attention_mask = _create_4d_causal_attention_mask(
+            input_ids.shape, hidden_states.dtype, device=hidden_states.device
+        )
+        encoder_outputs = self.encoder(
+            hidden_states, attention_mask, causal_attention_mask
+        )
+        last_hidden_state = self.final_layer_norm(encoder_outputs)
+        return last_hidden_state
+
+
+class CLIPTextModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.text_model = CLIPTextTransformer(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("text_model", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ):
+        return self.text_model(input_ids, position_ids)
+
+
+class CLIPVisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = CLIPVisionEmbeddings(config)
+
+        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
+        # the original transformers code and name of the model weights.
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+
+        hidden_states = self.embeddings(pixel_values.to(self.device))
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        return_all_hidden_states = False
+
+        last_hidden_state = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+class CLIPVisionModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vision_model = CLIPVisionTransformer(
+            config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+
+    def forward(self, pixel_values: torch.Tensor):
+        return self.vision_model(pixel_values)
+
+
+class CLIPModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        if not isinstance(config.text_config, CLIPTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type CLIPTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, CLIPVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+        self.visual_projection = nn.Linear(
+            self.vision_embed_dim, self.projection_dim, bias=False
+        )
+        self.text_projection = nn.Linear(
+            self.text_embed_dim, self.projection_dim, bias=False
+        )
+        self.logit_scale = nn.Parameter(
+            torch.tensor(self.config.logit_scale_init_value)
+        )
+
+        text_model = CLIPTextModel(
+            text_config, quant_config, prefix=add_prefix("text_model", prefix)
+        )
+        vision_model = CLIPVisionModel(
+            vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+        self.text_model = text_model.text_model
+        self.vision_model = vision_model.vision_model
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        monkey_patch_weight_loader()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = True,
+    ):
+        assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
+        image_inputs = None
+        if forward_batch.mm_inputs is not None:
+            image_inputs = forward_batch.mm_inputs
+
+        if image_inputs is not None and image_inputs[0] is not None:
+            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            pooled_output = vision_outputs[:, 0, :]
+            image_embeds = self.visual_projection(pooled_output)
+            image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
+            return EmbeddingPoolerOutput(embeddings=image_embeds)
+
+        else:
+            text_outputs = self.text_model(input_ids, position_ids=positions)
+            pooled_output = self.pooler(text_outputs[0], forward_batch)
+            return EmbeddingPoolerOutput(
+                embeddings=self.text_projection(pooled_output.embeddings)
+            )
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Clip embeddings models handle text/image separately, so we don't need to pad input ids
+        return input_ids
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "position_ids" in name:
+                continue
+            if "out_proj" in name:
+                name = name.replace("out_proj", "proj")
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+# monkey patch weight loader to remove open_clip file
+def monkey_patch_weight_loader():
+    import glob
+    import os
+
+    from sglang.srt.model_loader.loader import DefaultModelLoader
+    from sglang.srt.model_loader.weight_utils import (
+        download_weights_from_hf,
+        filter_files_not_needed_for_inference,
+    )
+
+    def prepare_weights(
+        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+    ) -> Tuple[str, List[str], bool]:
+        model_name_or_path = (
+            self._maybe_download_from_modelscope(model_name_or_path, revision)
+            or model_name_or_path
+        )
+
+        is_local = os.path.isdir(model_name_or_path)
+        use_safetensors = False
+        allow_patterns = ["*.bin"]
+
+        if not is_local:
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
+        else:
+            hf_folder = model_name_or_path
+
+        hf_weights_files: List[str] = []
+        for pattern in allow_patterns:
+            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
+
+        # remove open_clip file
+        hf_weights_files = [
+            file for file in hf_weights_files if "open_clip" not in file
+        ]
+
+        if len(hf_weights_files) == 0:
+            raise RuntimeError(
+                f"Cannot find any model weights with `{model_name_or_path}`"
+            )
+
+        return hf_folder, hf_weights_files, use_safetensors
+
+    setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)
+
+
+EntryClass = CLIPModel
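
Note: the embedding path in CLIPModel.forward above follows the reference Hugging Face CLIP semantics: the vision tower is pooled at the CLS token and projected through visual_projection, while text goes through the pooler and text_projection. For orientation only, a minimal sketch of the same semantics using the upstream transformers API; the checkpoint name, placeholder image, and explicit normalization of the text embedding are illustrative and not taken from this diff.

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Illustrative checkpoint; any standard CLIP checkpoint works the same way.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(
    text=["a photo of a cat"], images=image, return_tensors="pt", padding=True
)

with torch.no_grad():
    # CLS-token pooling + visual_projection, matching the image branch above.
    image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
    # Pooled text state + text_projection, matching the text branch above.
    text_embeds = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )

# The sglang forward additionally L2-normalizes the image embedding before returning it.
image_embeds = torch.nn.functional.normalize(image_embeds, p=2, dim=-1)
text_embeds = torch.nn.functional.normalize(text_embeds, p=2, dim=-1)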
sglang/srt/models/deepseek_janus_pro.py
@@ -252,7 +252,7 @@ def resample_patch_embed(
     try:
         from torch import vmap
     except ImportError:
-        from functorch import vmap
+        from torch.func import vmap
 
     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
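
Note: torch.func.vmap is the in-tree successor to the standalone functorch package's vmap, so the fallback import resolves to the same transform. A small self-contained sketch of the call, independent of this package:

import torch
from torch.func import vmap  # replaces `from functorch import vmap`

def dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (a * b).sum()

x = torch.randn(8, 3)
y = torch.randn(8, 3)
# Map `dot` over the leading batch dimension of both inputs.
batched = vmap(dot)(x, y)
assert batched.shape == (8,)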
@@ -1084,7 +1084,7 @@ def create_siglip_vit(
     )
 
     if ckpt_path:
-        state_dict = torch.load(ckpt_path, map_location="cpu")
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
 
         incompatible_keys = model.load_state_dict(state_dict, strict=False)
         print(
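
Note: with weights_only=True, torch.load uses a restricted unpickler that only rebuilds tensors and other allow-listed types, so loading an untrusted checkpoint cannot execute arbitrary pickled code. A minimal sketch of the behavior; the file path is illustrative.

import torch

# Save a plain state_dict-style checkpoint (illustrative path).
state = {"weight": torch.randn(4, 4), "step": 3}
torch.save(state, "ckpt.pt")

# weights_only=True restricts unpickling to tensors and simple containers,
# which is why it is the safer choice for third-party checkpoints.
restored = torch.load("ckpt.pt", map_location="cpu", weights_only=True)
print(restored["weight"].shape, restored["step"])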