xinference 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -0
- xinference/core/model.py +23 -3
- xinference/core/supervisor.py +6 -0
- xinference/core/worker.py +54 -11
- xinference/model/llm/__init__.py +4 -2
- xinference/model/llm/core.py +1 -0
- xinference/model/llm/llama_cpp/core.py +6 -1
- xinference/model/llm/llm_family.json +117 -1
- xinference/model/llm/llm_family_modelscope.json +125 -1
- xinference/model/llm/reasoning_parser.py +3 -3
- xinference/model/llm/sglang/core.py +111 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/utils.py +26 -14
- xinference/model/llm/vllm/core.py +149 -8
- xinference/model/llm/vllm/distributed_executor.py +314 -0
- xinference/model/rerank/core.py +16 -11
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.3cea968e.js → main.5ca4eea1.js} +3 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +1 -0
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/METADATA +4 -4
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/RECORD +56 -31
- xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
- /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.5ca4eea1.js.LICENSE.txt} +0 -0
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/LICENSE +0 -0
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/WHEEL +0 -0
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,661 @@
|
|
|
1
|
+
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from typing import Final, Optional, Callable, Union, Tuple, List, Set, Dict, Type, Literal, Sequence
|
|
8
|
+
import math
|
|
9
|
+
import warnings
|
|
10
|
+
from timm.layers import (
|
|
11
|
+
PatchEmbed, Mlp, DropPath,
|
|
12
|
+
AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
|
|
13
|
+
)
|
|
14
|
+
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
|
15
|
+
from transformers.modeling_utils import is_flash_attn_2_available
|
|
16
|
+
from functools import partial
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if is_flash_attn_2_available():
|
|
20
|
+
from flash_attn import flash_attn_qkvpacked_func
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
|
24
|
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
|
25
|
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
|
26
|
+
def norm_cdf(x):
|
|
27
|
+
# Computes standard normal cumulative distribution function
|
|
28
|
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
|
29
|
+
|
|
30
|
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
31
|
+
warnings.warn(
|
|
32
|
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
|
33
|
+
"The distribution of values may be incorrect.",
|
|
34
|
+
stacklevel=2,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
with torch.no_grad():
|
|
38
|
+
# Values are generated by using a truncated uniform distribution and
|
|
39
|
+
# then using the inverse CDF for the normal distribution.
|
|
40
|
+
# Get upper and lower cdf values
|
|
41
|
+
l = norm_cdf((a - mean) / std) # noqa: E741
|
|
42
|
+
u = norm_cdf((b - mean) / std)
|
|
43
|
+
|
|
44
|
+
# Uniformly fill tensor with values from [l, u], then translate to
|
|
45
|
+
# [2l-1, 2u-1].
|
|
46
|
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
|
47
|
+
|
|
48
|
+
# Use inverse cdf transform for normal distribution to get truncated
|
|
49
|
+
# standard normal
|
|
50
|
+
tensor.erfinv_()
|
|
51
|
+
|
|
52
|
+
# Transform to proper mean, std
|
|
53
|
+
tensor.mul_(std * math.sqrt(2.0))
|
|
54
|
+
tensor.add_(mean)
|
|
55
|
+
|
|
56
|
+
# Clamp to ensure it's in the proper range
|
|
57
|
+
tensor.clamp_(min=a, max=b)
|
|
58
|
+
return tensor
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r"""Dtype-safe truncated-normal initialisation, filling ``tensor`` in place.

    The original ``timm.models.layers.weight_init.trunc_normal_`` cannot handle bfloat16
    yet, so the tensor is first converted to float32, ``_no_grad_trunc_normal_`` is applied
    in float32, and the result is converted back to the original dtype and copied into
    ``tensor``.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]`
    redrawn until they are within the bounds. The method used for generating the random
    values works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place.
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, as declared by the type comment and as
        ``torch.nn.init.trunc_normal_`` does.

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.bfloat16)
        >>> _ = trunc_normal_(w)
    """
    with torch.no_grad():
        dtype = tensor.dtype
        # ``.float()`` returns ``tensor`` itself when it is already float32; the final
        # copy_ is then a harmless self-copy.
        tensor_fp32 = tensor.float()
        tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
        tensor.copy_(tensor_fp32.to(dtype=dtype))
    return tensor
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def init_weights(self):
    """Initialise the learnable tensors of an attention-pooling module.

    ``VisionTransformer.__init__`` assigns this function to
    ``AttentionPoolLatent.init_weights`` when ``global_pool == 'map'``, so ``self`` is
    that pooling module. ``pos_embed`` (when not None) gets a truncated normal with std
    ``pos_embed.shape[1] ** -0.5``. ``latent`` gets std ``latent_dim ** -0.5``.
    """
    pos_embed = self.pos_embed
    if pos_embed is not None:
        pos_std = pos_embed.shape[1] ** -0.5
        trunc_normal_(pos_embed, std=pos_std)
    latent_std = self.latent_dim ** -0.5
    trunc_normal_(self.latent, std=latent_std)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
    """ViT weight initialisation following the original timm implementation.

    ``nn.Linear`` layers get a truncated-normal weight (std 0.02) and, if present, a
    zeroed bias. Any other module that defines ``init_weights`` has that method called.
    ``name`` is accepted so the function can be used with ``named_apply``; it is unused.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, 'init_weights'):
            module.init_weights()
        return

    trunc_normal_(module.weight, std=.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Attention(nn.Module):
    """Multi-head self-attention with several interchangeable kernels.

    Without ``qk_norm`` the packed qkv tensor goes to one of two kernels:
    ``flash_attn_qkvpacked_func``, when flash-attn 2 is available and ``head_dim`` is a
    multiple of 32, otherwise xformers' ``memory_efficient_attention``.

    With ``qk_norm`` the query/key normalisation is applied and attention runs through
    ``F.scaled_dot_product_attention`` (``fused_attn``) or an explicit softmax.
    """

    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            deterministic: bool = False,
    ) -> None:
        """
        Args:
            dim: input/output channel count; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            qkv_bias: add a bias to the fused qkv projection.
            qk_norm: normalise queries and keys per head with ``norm_layer``.
            attn_drop: attention-probability dropout rate (only applied in training).
            proj_drop: dropout after the output projection.
            norm_layer: normalisation layer used for ``qk_norm``.
            deterministic: forwarded to flash-attn's ``deterministic`` flag.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qk_norm = qk_norm
        self.fused_attn = True
        self.deterministic = deterministic

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns the same shape."""
        B, N, C = x.shape
        # (B, N, 3, num_heads, head_dim): the packed layout flash-attn consumes directly.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        dropout_p = self.attn_drop.p if self.training else 0.

        if not self.qk_norm:
            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                # flashattn must have head_dim as a multiple of 32
                x = flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p,
                                              deterministic=self.deterministic)
            else:
                # Imported lazily so xformers is only required when this fallback is
                # actually taken (not for the flash-attn or qk_norm paths).
                from xformers.ops import memory_efficient_attention

                q, k, v = qkv.unbind(2)
                x = memory_efficient_attention(q, k, v, p=dropout_p)
            x = x.reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

        # (3, B, num_heads, N, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # Use the context manager to force the flash-attention SDPA backend (math and
            # memory-efficient backends disabled).
            # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer PyTorch
            # in favour of torch.nn.attention.sdpa_kernel; inputs the flash kernel cannot
            # handle may fail here.
            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False):
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=dropout_p,
                )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension: ``x * gamma``.

    ``gamma`` is a parameter of length ``dim``, every entry initialised to ``init_values``.
    With ``inplace=True`` the input tensor itself is multiplied and returned.
    """

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        initial_scale = torch.ones(dim)
        self.gamma = nn.Parameter(init_values * initial_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Both sub-layers (attention, then MLP) are applied as
    ``x + drop_path(layer_scale(sublayer(norm(x))))``. Layer scale is used only when
    ``init_values`` is truthy, and drop-path only when ``drop_path > 0``.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
            deterministic: bool = False,
    ) -> None:
        super().__init__()

        def _make_layer_scale():
            # A falsy init_values (None or 0) disables layer scale entirely.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        def _make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Submodules are registered in the same order as before so state_dict keys and
        # parameter-initialisation RNG consumption are unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            deterministic=deterministic,
        )
        self.ls1 = _make_layer_scale()
        self.drop_path1 = _make_drop_path()

        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = _make_layer_scale()
        self.drop_path2 = _make_drop_path()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + self.drop_path1(attn_branch)
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + self.drop_path2(mlp_branch)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    SigLIP variant: ``LayerNorm(eps=1e-6)`` and tanh-approximated GELU are always used,
    whatever ``norm_layer`` / ``act_layer`` are passed. The model can be cut into
    pipeline stages with ``to_pipeline``.
    """
    dynamic_img_size: Final[bool]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            ignore_head: bool = False,
            deterministic: bool = False,
            num_recomputing_layers: int = 0
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Normalise queries/keys per head inside attention.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            pre_norm: Apply a norm before the blocks (and drop the patch-embed bias).
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            dynamic_img_size: Accept non-default input sizes; position embeddings are resampled.
            dynamic_img_pad: Forwarded to the patch embedding layer.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch dropout rate (PatchDropout enabled when > 0).
            proj_drop_rate: Projection / MLP dropout rate inside blocks.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme ('skip' leaves weights as constructed).
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer (ignored: LayerNorm(eps=1e-6) is always used).
            act_layer: MLP activation layer (ignored: GELU(approximate='tanh') is always used).
            block_fn: Transformer block layer.
            mlp_layer: MLP layer used inside each block.
            ignore_head: When True, ``forward`` returns features and skips ``forward_head``.
            deterministic: Forwarded to each block's attention (flash-attn determinism).
            num_recomputing_layers: Used to compute ``skip_last`` for grad checkpointing.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        # act_layer = get_act_layer(act_layer) or nn.GELU
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # siglip use PytorchGELUTanh() rather than the vanilla nn.GELU()
        # https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/siglip/configuration_siglip.py#L191
        act_layer = partial(nn.GELU, approximate='tanh')

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head
        self.num_recomputing_layers = num_recomputing_layers

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
                deterministic=deterministic,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == 'map':
            # Patch the pool's initialiser with the module-level dtype-safe init_weights.
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
        """Initialise weights with the original timm ViT scheme.

        ``mode`` is validated, but every accepted mode currently applies the same
        initialisation: truncated-normal ``pos_embed``, a tiny normal ``cls_token``, and
        ``init_weights_vit_timm`` applied to every submodule.
        """
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
        """Replace the classification head and optionally change the pooling mode.

        Switching to ``'map'`` pooling is unsupported when no attention pool exists.
        Switching to any other pooling mode removes an existing attention pool.

        Raises:
            AssertionError: on an unknown ``global_pool`` or an unsupported switch to 'map'.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token', 'map')
            if global_pool == 'map' and self.attn_pool is None:
                # Explicit raise so the check survives `python -O`.
                raise AssertionError("Cannot currently add attention pooling in reset_classifier().")
            elif global_pool != 'map' and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add position embeddings and prepend class/register tokens, then apply pos dropout."""
        if self.dynamic_img_size:
            # Patch embedding kept the NHWC grid; resample pos_embed to it, then flatten.
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        """Run the full stem + blocks, collecting the outputs of the selected blocks."""
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if in is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            # NOTE(review): uses the static patch grid; likely wrong with dynamic_img_size.
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return tuple(zip(outputs, prefix_tokens))
        return tuple(outputs)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Stem (first pipeline stage only), blocks, and final norm (last stage only).

        ``is_first_stage`` / ``is_last_stage`` are set by ``to_pipeline``; without it both
        default to True.
        """
        if getattr(self, "is_first_stage", True):
            x = self.patch_embed(x)
            x = self._pos_embed(x)
            x = self.patch_drop(x)
            x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # NOTE(review): timm's checkpoint_seq treats ``skip_last`` as a boolean
            # (skip checkpointing only the final block) in the versions reviewed; confirm
            # the intended effect of num_recomputing_layers against the installed timm.
            skip_last = max(1, len(self.blocks) - self.num_recomputing_layers)
            x = checkpoint_seq(self.blocks, x, skip_last=skip_last)
        else:
            x = self.blocks(x)
        if getattr(self, "is_last_stage", True):
            x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Pool, normalise, and classify features (identity on non-last pipeline stages)."""
        if not getattr(self, "is_last_stage", True):
            return x
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == 'avg':
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x

    def to_pipeline(self, pp_size, pp_rank, pp_splits: Optional[List[int]] = None):
        """Turn this model into pipeline stage ``pp_rank`` of ``pp_size`` (in place).

        Non-first stages drop the stem modules, non-last stages drop the head modules,
        and if ``pp_splits`` is given only this stage's slice of ``blocks`` is kept.
        Returns ``self``.
        """
        self.is_first_stage = pp_rank == 0
        self.is_last_stage = pp_rank == pp_size - 1
        if not self.is_first_stage and hasattr(self, "patch_embed"):
            del self.patch_embed, self.cls_token, self.reg_token, self.pos_embed, self.pos_drop, self.patch_drop, self.norm_pre
        if not self.is_last_stage and hasattr(self, "norm"):
            del self.norm, self.attn_pool, self.fc_norm, self.head_drop, self.head
        if pp_splits is not None:
            assert len(self.blocks) == sum(pp_splits)
            splits = np.cumsum([0] + pp_splits)
            self.blocks = self.blocks[splits[pp_rank]:splits[pp_rank + 1]]
        return self
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
@dataclass
class SigLIPVisionCfg:
    """Architecture hyper-parameters for a SigLIP vision tower.

    ``create_siglip_vit`` builds one from a ``SigLIP_MODEL_CONFIG`` entry and reads
    ``patch_size``, ``width``, ``layers``, ``heads``, ``mlp_ratio``, ``class_token`` and
    ``global_pool`` from it.
    """
    width: int = 1152  # transformer embedding dimension (passed as embed_dim)
    layers: Union[Tuple[int, int, int, int], int] = 27  # block count; create_siglip_vit does int arithmetic on it
    heads: int = 16  # number of attention heads
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336  # not read by create_siglip_vit (it takes image_size as an argument)
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0  # not read by create_siglip_vit, which always passes num_classes=0
    use_checkpoint: bool = False  # NOTE(review): not read by create_siglip_vit
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
# Named SigLIP vision-tower presets, consumed by create_siglip_vit via SigLIPVisionCfg(**preset).
SigLIP_MODEL_CONFIG = {
    "siglip_so400m_patch14_384": dict(
        image_size=384,
        patch_size=14,
        width=1152,
        layers=27,
        heads=16,
        mlp_ratio=3.7362,
        global_pool="map",
        use_checkpoint=False,
    ),
    "siglip_so400m_patch14_224": dict(
        image_size=224,
        patch_size=14,
        width=1152,
        layers=27,
        heads=16,
        mlp_ratio=3.7362,
        global_pool="map",
        use_checkpoint=False,
    ),
    "siglip_large_patch16_384": dict(
        image_size=384,
        patch_size=16,
        width=1024,
        layers=24,
        heads=16,
        mlp_ratio=4,
        global_pool="map",
        use_checkpoint=False,
    ),
}
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def create_siglip_vit(
        model_name: str = "siglip_so400m_patch14_384",
        image_size: int = 384,
        select_layer: int = -1,
        ckpt_path: str = "",
        **kwargs
):
    """Build a SigLIP ``VisionTransformer`` from a named preset.

    Args:
        model_name: key into ``SigLIP_MODEL_CONFIG``.
        image_size: input resolution; used instead of the preset's ``image_size``.
        select_layer: how many blocks to keep, capped at the preset depth. Values ``<= 0``
            count from the end (``0`` and ``-1`` keep every block, ``-2`` drops the last
            one, ...). A positive value keeps the first ``select_layer`` blocks.
        ckpt_path: optional checkpoint path. When non-empty, it is loaded with
            ``torch.load`` on CPU and applied with ``load_state_dict(strict=False)``.
        **kwargs: optional overrides forwarded to ``VisionTransformer``:
            ``ignore_head`` (default True), ``weight_init`` (default "skip"),
            ``deterministic`` (default False), ``num_recomputing_layers`` (default 0).

    Returns:
        The constructed ``VisionTransformer``.
    """
    assert model_name in SigLIP_MODEL_CONFIG, f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"

    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])

    if select_layer <= 0:
        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
    else:
        layers = min(vision_cfg.layers, select_layer)

    model = VisionTransformer(
        img_size=image_size,
        patch_size=vision_cfg.patch_size,
        embed_dim=vision_cfg.width,
        depth=layers,
        num_heads=vision_cfg.heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        class_token=vision_cfg.class_token,
        global_pool=vision_cfg.global_pool,
        ignore_head=kwargs.get("ignore_head", True),
        weight_init=kwargs.get("weight_init", "skip"),
        num_classes=0,
        deterministic=kwargs.get("deterministic", False),
        num_recomputing_layers=kwargs.get("num_recomputing_layers", 0)
    )

    if ckpt_path:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True where the installed PyTorch supports it).
        state_dict = torch.load(ckpt_path, map_location="cpu")

        incompatible_keys = model.load_state_dict(state_dict, strict=False)
        print(f"SigLIP-ViT restores from {ckpt_path},\n"
              f"\tincompatible_keys: {incompatible_keys}.")

    return model
|
|
File without changes
|
|
File without changes
|