xinference 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (64)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +16 -11
  5. xinference/client/restful/restful_client.py +12 -2
  6. xinference/conftest.py +13 -2
  7. xinference/constants.py +2 -0
  8. xinference/core/supervisor.py +32 -1
  9. xinference/core/worker.py +139 -20
  10. xinference/deploy/cmdline.py +119 -20
  11. xinference/model/llm/__init__.py +6 -0
  12. xinference/model/llm/llm_family.json +711 -10
  13. xinference/model/llm/llm_family_modelscope.json +557 -7
  14. xinference/model/llm/pytorch/chatglm.py +2 -1
  15. xinference/model/llm/pytorch/core.py +2 -0
  16. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  17. xinference/model/llm/pytorch/internlm2.py +2 -1
  18. xinference/model/llm/pytorch/omnilmm.py +153 -0
  19. xinference/model/llm/sglang/__init__.py +13 -0
  20. xinference/model/llm/sglang/core.py +365 -0
  21. xinference/model/llm/utils.py +46 -13
  22. xinference/model/llm/vllm/core.py +10 -0
  23. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  24. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  25. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  26. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  27. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  28. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  29. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  30. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  31. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  32. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  33. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  34. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  35. xinference/thirdparty/omnilmm/__init__.py +0 -0
  36. xinference/thirdparty/omnilmm/chat.py +216 -0
  37. xinference/thirdparty/omnilmm/constants.py +4 -0
  38. xinference/thirdparty/omnilmm/conversation.py +332 -0
  39. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  40. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  41. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  42. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  43. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  44. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  45. xinference/thirdparty/omnilmm/utils.py +134 -0
  46. xinference/web/ui/build/asset-manifest.json +3 -3
  47. xinference/web/ui/build/index.html +1 -1
  48. xinference/web/ui/build/static/js/main.98516614.js +3 -0
  49. xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
  54. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/METADATA +21 -5
  55. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/RECORD +60 -31
  56. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  57. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  60. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
  61. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
  62. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
  63. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
  64. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,681 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import (
26
+ Callable,
27
+ Dict,
28
+ Final,
29
+ List,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from timm.layers import (
43
+ AttentionPoolLatent,
44
+ DropPath,
45
+ LayerType,
46
+ Mlp,
47
+ PatchDropout,
48
+ PatchEmbed,
49
+ resample_abs_pos_embed,
50
+ )
51
+ from timm.models._manipulate import checkpoint_seq, named_apply
52
+
53
+
54
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Overwrite ``tensor`` in place with truncated-normal samples and return it.

    Follows the helper PyTorch uses for ``nn.init.trunc_normal_``. It draws
    uniformly between the normal-CDF values of the two cut-offs, then maps back
    through the inverse CDF (via ``erfinv_``). Finally it rescales/shifts to
    ``mean``/``std`` and clamps to ``[a, b]``.

    Args:
        tensor: tensor to fill in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cut-off.
        b: upper cut-off.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Standard normal CDF written in terms of the error function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    too_low = mean < a - 2 * std
    too_high = mean > b + 2 * std
    if too_low or too_high:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values of the standardised cut-offs.
        cdf_lo = norm_cdf((a - mean) / std)  # noqa: E741
        cdf_hi = norm_cdf((b - mean) / std)

        # Uniform on [2*cdf_lo - 1, 2*cdf_hi - 1]; erfinv then yields a truncated
        # standard normal divided by sqrt(2).
        tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1)
        tensor.erfinv_()

        # Undo the sqrt(2) and move to the requested mean / std.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Guard against numerical spill-over outside the bounds.
        tensor.clamp_(min=a, max=b)
        return tensor
90
+
91
+
92
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r"""Fill ``tensor`` in place with a truncated normal distribution (dtype-safe).

    The original ``timm.models.layers.weight_init.trunc_normal_`` cannot handle
    bfloat16. Sampling is therefore done on a float32 copy, which is then
    converted back to the tensor's original dtype and copied into ``tensor``.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. They are produced by inverse-CDF sampling and then clamped
    to the bounds; nothing is redrawn. The method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, modified in place.
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        ``tensor`` itself. This matches the type comment above and the
        convention of ``torch.nn.init.trunc_normal_``.

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.bfloat16)
        >>> _ = trunc_normal_(w, std=0.02)
    """
    with torch.no_grad():
        dtype = tensor.dtype
        # For float32 inputs ``.float()`` returns ``tensor`` itself, so the copy
        # below is then a self-copy; for other dtypes it is a real float32 copy.
        tensor_fp32 = tensor.float()
        tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
        tensor_dtype = tensor_fp32.to(dtype=dtype)
        tensor.copy_(tensor_dtype)
    return tensor
118
+
119
+
120
def init_weights(self):
    """Initialise an attention-pool head's position embedding and latent query.

    ``VisionTransformer.__init__`` assigns this function to
    ``AttentionPoolLatent.init_weights``. The pooling head is therefore
    initialised with this module's ``trunc_normal_`` (which samples in float32)
    instead of timm's version.

    The std of each truncated-normal draw is ``1/sqrt(size)``. For
    ``pos_embed`` the size is its second dimension; for ``latent`` it is
    ``latent_dim``.
    """
    pos = self.pos_embed
    if pos is not None:
        trunc_normal_(pos, std=pos.shape[1] ** -0.5)
    latent_std = self.latent_dim**-0.5
    trunc_normal_(self.latent, std=latent_std)
124
+
125
+
126
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
    """ViT weight initialization, original timm impl (for reproducibility).

    Intended to be applied to every submodule via ``named_apply``. ``name`` is
    accepted for that callback signature but is not used.

    * ``nn.Linear``: the weight is drawn from a truncated normal (std 0.02) and
      the bias, if any, is zeroed.
    * Any other module exposing ``init_weights()``: that method is called.
    * Everything else is left untouched.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, "init_weights"):
            module.init_weights()
        return

    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
134
+
135
+
136
class Attention(nn.Module):
    """Multi-head self-attention as used inside the ViT ``Block``.

    A single linear layer produces q, k and v. Optionally q and k are
    normalised per head. Attention is applied and the result is projected back
    to ``dim`` channels.

    Args:
        dim: input/output channel count; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused qkv projection has a bias.
        qk_norm: if True, apply ``norm_layer(head_dim)`` to q and k.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout after the output projection (``nn.Identity`` when 0).
        norm_layer: norm class/factory, called as ``norm_layer(head_dim)``.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        # Use PyTorch's fused kernel when this torch build provides it. Otherwise
        # fall back to the explicit softmax(q @ k^T) path in ``forward``, which
        # computes the same result, instead of raising AttributeError.
        self.fused_attn = hasattr(F, "scaled_dot_product_attention")

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over the token axis.

        Args:
            x: tensor of shape ``(B, N, C)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # The default scale of scaled_dot_product_attention is
            # 1/sqrt(head_dim), i.e. ``self.scale``.
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, heads, N, head_dim) -> (B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
192
+
193
+
194
class LayerScale(nn.Module):
    """Learnable per-channel scaling, ``x * gamma``.

    Args:
        dim: length of the learnable vector ``gamma``. It broadcasts against
            the last dimension of the input.
        init_values: value every entry of ``gamma`` starts at.
        inplace: if True, scale the input tensor in place (``mul_``) and return
            it, instead of allocating a new tensor.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            # Mutates and returns the caller's tensor.
            return x.mul_(self.gamma)
        return x * self.gamma
207
+
208
+
209
class Block(nn.Module):
    """Pre-norm transformer encoder block with two residual branches.

    ``forward`` computes::

        x = x + drop_path1(ls1(attn(norm1(x))))
        x = x + drop_path2(ls2(mlp(norm2(x))))

    ``ls1``/``ls2`` are ``LayerScale`` only when ``init_values`` is truthy.
    ``drop_path1``/``drop_path2`` are ``DropPath`` only when ``drop_path > 0``.
    Otherwise each of these slots is an ``nn.Identity``. The MLP hidden width is
    ``int(dim * mlp_ratio)``, and ``proj_drop`` is used both for the attention
    output projection and for the MLP dropout.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        use_layer_scale = bool(init_values)
        use_drop_path = drop_path > 0.0

        # Attention branch. Submodules are created in the original order, since
        # that order fixes RNG consumption and parameter registration.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        if use_layer_scale:
            self.ls1 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        if use_drop_path:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # MLP branch.
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=proj_drop,
        )
        if use_layer_scale:
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls2 = nn.Identity()
        if use_drop_path:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_branch = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + attn_branch
        mlp_branch = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x + mlp_branch
257
+
258
+
259
class VisionTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Adapted from timm. Differences visible in this implementation:
      * The ``norm_layer`` / ``act_layer`` constructor arguments are accepted but
        ignored. ``LayerNorm(eps=1e-6)`` and ``GELU`` are always used.
      * ``ignore_head=True`` makes ``forward`` return the ``forward_features``
        output, without pooling or the classifier head.
      * ``init_weights`` validates ``mode`` but applies the same scheme for
        every mode; only ``weight_init="skip"`` disables it.
    """

    dynamic_img_size: Final[bool]

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: Optional[float] = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = Block,
        mlp_layer: Type[nn.Module] = Mlp,
        ignore_head: bool = False,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head (0 -> no head).
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Passed to each block to normalize q and k.
            init_values: Layer-scale init values, passed to each block.
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            pre_norm: Apply a norm before the blocks; also disables the patch-embed bias.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            dynamic_img_size: Resample position embeddings to the runtime patch grid.
            dynamic_img_pad: Forwarded to ``embed_layer``.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch dropout rate (``PatchDropout`` used when > 0).
            proj_drop_rate: Projection / MLP dropout rate inside blocks.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate, increased linearly across blocks.
            weight_init: Weight initialization scheme; "skip" keeps construction-time weights.
            embed_layer: Patch embedding layer.
            norm_layer: Ignored; LayerNorm(eps=1e-6) is always used.
            act_layer: Ignored; GELU is always used.
            block_fn: Transformer block layer.
            mlp_layer: MLP layer class used inside each block.
            ignore_head: If True, ``forward`` returns un-pooled token features.
        """
        super().__init__()
        assert global_pool in ("", "avg", "token", "map")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        # timm resolved these via get_norm_layer/get_act_layer; this port
        # deliberately overrides the constructor arguments.
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        # don't embed prefix positions (includes reg)
        self.no_embed_class = no_embed_class
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        # NOTE: parameter/module creation order below is kept as-is; it fixes
        # RNG consumption (e.g. torch.randn for pos_embed) and registration order.
        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )
        self.reg_token = (
            nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        )
        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    init_values=init_values,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == "map":
            # NOTE: this rebinds the method on timm's AttentionPoolLatent class
            # itself, so it affects every AttentionPoolLatent in the process.
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        if weight_init != "skip":
            self.init_weights(weight_init)

    def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
        """Re-initialise weights.

        ``mode`` is only validated; the same scheme runs for every accepted value.
        ``pos_embed`` ~ truncated normal(std 0.02), ``cls_token`` ~ normal(std 1e-6),
        then ``init_weights_vit_timm`` is applied to every submodule.
        """
        assert mode in ("jax", "jax_nlhb", "moco", "")
        # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        """Parameter names to exclude from weight decay."""
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        """Regex groups (stem / blocks) for layer-wise parameter grouping."""
        return dict(
            stem=r"^cls_token|pos_embed|patch_embed",  # stem and embed
            blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Toggle activation checkpointing of ``self.blocks`` in ``forward_features``."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
        """Replace the classifier head and optionally change the pooling mode.

        Args:
            num_classes: new class count (0 -> ``nn.Identity`` head).
            global_pool: new pooling mode, or None to keep the current one.
                Switching away from ``"map"`` drops the attention pool.

        Raises:
            AssertionError: unknown ``global_pool``, or switching to ``"map"``
                on a model that was built without an attention pool.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ("", "avg", "token", "map")
            if global_pool == "map" and self.attn_pool is None:
                # Explicit raise so the check survives ``python -O``.
                raise AssertionError(
                    "Cannot currently add attention pooling in reset_classifier()."
                )
            elif global_pool != "map" and self.attn_pool is not None:
                # Fixed: the original compared against "map " (trailing space),
                # which also discarded the pool when re-selecting "map".
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add position embeddings and prepend class/register tokens."""
        if self.dynamic_img_size:
            # Patch embed emitted NHWC (see __init__); resample the position
            # table to this (H, W) grid, then flatten to a token sequence.
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        """Run the trunk and collect block outputs (last ``n`` if int, else those indices)."""
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(
            range(num_blocks - n, num_blocks) if isinstance(n, int) else n
        )

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
        reshape: bool = False,
        return_prefix_tokens: bool = False,
        norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface

        Prefix (class/register) tokens are split off from every output. With
        ``reshape=True`` the patch tokens are laid out as
        ``(B, C, grid_h, grid_w)`` using ``patch_embed.grid_size``.
        """
        # take last n blocks if n is an int, if in is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens :] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return tuple(zip(outputs, prefix_tokens))
        return tuple(outputs)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Patch-embed, add positions, run all blocks and apply the final norm."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Pool token features (attn pool / avg / class token) and apply the head."""
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == "avg":
            x = x[:, self.num_prefix_tokens :].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x
590
+
591
+
592
@dataclass
class SigLIPVisionCfg:
    """Hyper-parameters of a SigLIP vision tower.

    Field order is part of the interface: the dataclass ``__init__`` accepts
    these fields positionally in the order declared below. Instances are
    typically built from a ``SigLIP_MODEL_CONFIG`` entry.
    """

    # Transformer width (embedding dimension).
    width: int = 1152
    # Number of transformer blocks.
    layers: Union[Tuple[int, int, int, int], int] = 27
    # Attention heads per block.
    heads: int = 16
    # Side length of a square patch, in pixels.
    patch_size: int = 14
    # Nominal input resolution; create_siglip_vit in this file takes its own
    # ``image_size`` argument instead of reading this field.
    image_size: Union[Tuple[int, int], int] = 336
    # Pooling mode for the ViT head ("map" = attention pooling).
    global_pool: str = "map"
    # MLP hidden width as a multiple of ``width``.
    mlp_ratio: float = 3.7362
    # Whether a class token is prepended.
    class_token: bool = False
    # Classifier size; create_siglip_vit in this file always passes 0.
    num_classes: int = 0
    # Not read by create_siglip_vit in this file.
    use_checkpoint: bool = False
604
+
605
+
606
# Named SigLIP vision-tower presets; each value holds keyword arguments for
# ``SigLIPVisionCfg``. The two so400m entries share every field except
# ``image_size``. The "_384" so400m preset stores image_size=336 here; callers
# choose the actual input resolution separately.
SigLIP_MODEL_CONFIG = {
    "siglip_so400m_patch14_384": dict(
        image_size=336,
        patch_size=14,
        width=1152,
        layers=27,
        heads=16,
        mlp_ratio=3.7362,
        global_pool="map",
        use_checkpoint=False,
    ),
    "siglip_so400m_patch14_224": dict(
        image_size=224,
        patch_size=14,
        width=1152,
        layers=27,
        heads=16,
        mlp_ratio=3.7362,
        global_pool="map",
        use_checkpoint=False,
    ),
    "siglip_large_patch16_384": dict(
        image_size=384,
        patch_size=16,
        width=1024,
        layers=24,
        heads=16,
        mlp_ratio=4,
        global_pool="map",
        use_checkpoint=False,
    ),
}
638
+
639
+
640
def create_siglip_vit(
    model_name: str = "siglip_so400m_patch14_384",
    image_size: int = 384,
    select_layer: int = -1,
    ckpt_path: str = "",
    **kwargs,
):
    """Build a SigLIP ``VisionTransformer`` backbone, optionally loading weights.

    Args:
        model_name: key into ``SigLIP_MODEL_CONFIG``.
        image_size: input resolution given to the ViT. The preset's own
            ``image_size`` entry is not used here.
        select_layer: how many transformer blocks to build.
            * ``<= 0`` counts from the end: ``-1`` (and ``0``) keep all blocks
              and ``-2`` drops the last one.
            * A positive value keeps that many blocks, capped at the preset depth.
        ckpt_path: optional path to a state dict, loaded with ``strict=False``.
        **kwargs: only ``ignore_head`` (default True) and ``weight_init``
            (default ``"skip"``) are read; other keys are ignored.

    Returns:
        The constructed ``VisionTransformer``. It is built with ``num_classes=0``,
        so it has no linear classifier.

    Raises:
        AssertionError: if ``model_name`` is not a known preset.
    """
    assert (
        model_name in SigLIP_MODEL_CONFIG.keys()
    ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"

    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])

    if select_layer <= 0:
        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
    else:
        layers = min(vision_cfg.layers, select_layer)

    model = VisionTransformer(
        img_size=image_size,
        patch_size=vision_cfg.patch_size,
        embed_dim=vision_cfg.width,
        depth=layers,
        num_heads=vision_cfg.heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        class_token=vision_cfg.class_token,
        global_pool=vision_cfg.global_pool,
        ignore_head=kwargs.get("ignore_head", True),
        weight_init=kwargs.get("weight_init", "skip"),
        num_classes=0,
    )

    if ckpt_path:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (weights_only=True would restrict this on torch versions
        # that support it).
        state_dict = torch.load(ckpt_path, map_location="cpu")

        incompatible_keys = model.load_state_dict(state_dict, strict=False)
        # Fixed message: the original printed a stray "', " after the colon.
        print(
            f"SigLIP-ViT restores from {ckpt_path},\n"
            f"\tincompatible_keys: {incompatible_keys}."
        )

    return model
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.