ai-edge-torch-nightly 0.3.0.dev20250203__py3-none-any.whl → 0.3.0.dev20250204__py3-none-any.whl
- ai_edge_torch/generative/examples/paligemma/decoder.py +11 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +4 -4
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -0
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +379 -0
- ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py +84 -0
- ai_edge_torch/generative/layers/attention.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/RECORD +13 -11
- {ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -19,7 +19,6 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -59,7 +58,9 @@ class Decoder(model_builder.DecoderOnlyModel):
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
-      return super().forward(
+      return super().forward(
+          tokens, input_pos, kv_cache, mask, export_config=export_config
+      )
 
     assert input_embeds is not None
 
@@ -67,17 +68,22 @@ class Decoder(model_builder.DecoderOnlyModel):
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope =
+    rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
-    embeds_len = input_embeds.shape[1]
     if mask is None:
+      embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
       mask[:, embeds_len:] = float("-inf")
 
     return self._forward_with_embeds(
-        input_embeds,
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
     )
 
 
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -20,7 +20,6 @@ from typing import Optional
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -62,7 +61,7 @@ class Decoder2(gemma2.Gemma2):
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
-      return super().forward(tokens, input_pos, kv_cache)
+      return super().forward(tokens, input_pos, kv_cache, mask, export_config)
 
     assert input_embeds is not None
 
@@ -70,11 +69,12 @@ class Decoder2(gemma2.Gemma2):
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope =
+    rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
 
     if mask is None:
       if called_by_generate:
-        # PaliGemma2 generate()
+        # PaliGemma2 generate() uses a diagonal causal mask even with image
+        # embeds.
         mask = [
             self.get_attention_mask(
                 self.config.block_config(i).attn_config.attn_type, input_pos
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -60,6 +60,7 @@ class SiglipVisionEncoder(nn.Module):
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
         padding=0,
+        bias=config.embedding_use_bias,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py ADDED
@@ -0,0 +1,379 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an image encoder of Qwen 2.5 VL model."""
+
+import dataclasses
+from typing import Optional
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import attention_utils
+from ai_edge_torch.generative.layers import builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="visual.blocks.{}.mlp.up_proj",
+    ff_down_proj="visual.blocks.{}.mlp.down_proj",
+    ff_gate_proj="visual.blocks.{}.mlp.gate_proj",
+    attn_fused_qkv_proj="visual.blocks.{}.attn.qkv",
+    attn_output_proj="visual.blocks.{}.attn.proj",
+    pre_attn_norm="visual.blocks.{}.norm1",
+    post_attn_norm="visual.blocks.{}.norm2",
+    embedding="visual.patch_embed.proj",
+    final_norm="visual.merger.ln_q",
+)
+
+MERGER_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="visual.merger.mlp.0",
+    ff_down_proj="visual.merger.mlp.2",
+)
+
+
+@dataclasses.dataclass
+class QwenVLMergerConfig:
+  """Merger parameters."""
+
+  activation: cfg.ActivationConfig
+  intermediate_size: int
+  out_embedding_dim: int
+  use_bias: bool = False
+
+
+@dataclasses.dataclass
+class QwenVLImageConfig(cfg.ModelConfig):
+  """model config for Qwen 2.5 VL model."""
+
+  merger_config: Optional[QwenVLMergerConfig] = None
+  window_size: Optional[int] = None
+  spatial_merge_size: Optional[int] = None
+  full_atten_block_indexes: Optional[list[int]] = None
+
+
+class QwenVLMerger(nn.Module):
+  """Merger of Qwen 2.5 VL models from the Edge Generative API.
+
+  It's based on Qwen2_5_VLPatchMerger.
+  """
+
+  def __init__(self, config: QwenVLImageConfig):
+    super().__init__()
+    self.intermediate_size = config.merger_config.intermediate_size
+    self.w1 = nn.Linear(self.intermediate_size, self.intermediate_size)
+    self.act = builder.get_activation(config.merger_config.activation)
+    self.w2 = nn.Linear(
+        self.intermediate_size, config.merger_config.out_embedding_dim
+    )
+
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    x_reshaped = x.view(-1, self.intermediate_size)
+    return self.w2(self.act(self.w1(x_reshaped)))
+
+
+class QwenVLImageEncoder(nn.Module):
+  """Image encoder of Qwen 2.5 VL models from the Edge Generative API."""
+
+  def __init__(self, config: QwenVLImageConfig):
+    super().__init__()
+
+    # Tensor shape used to reshape pixel_values in forward() and various places.
+    self.kernel_size = (
+        -1,  # batch size
+        config.image_embedding.channels,
+        config.image_embedding.temporal_patch_size,
+        config.image_embedding.patch_size,
+        config.image_embedding.patch_size,
+    )
+    self.tok_embedding = nn.Conv3d(
+        in_channels=self.kernel_size[1],
+        out_channels=config.embedding_dim,
+        kernel_size=self.kernel_size[2:],
+        stride=self.kernel_size[2:],
+        padding=0,
+        bias=config.embedding_use_bias,
+    )
+
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.merger = QwenVLMerger(config)
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
+  ) -> torch.Tensor:
+    # Get window index and sequence lengths to rearrange the input tensor.
+    window_index, cu_seqlens = self._get_window_index(grid_thw)
+
+    # Embed the image and rearrange the embedding tensor.
+    pixel_reshaped = pixel_values.view(self.kernel_size)
+    x = self.tok_embedding(pixel_reshaped)
+    x = x.view(-1, self.config.embedding_dim)
+    x = self._rearrange(x, window_index).unsqueeze(0)
+
+    # Get RoPE and attention mask arranged according to the window index.
+    cos, sin = self._get_rope(grid_thw)
+    rope = (
+        self._rearrange(cos, window_index),
+        self._rearrange(sin, window_index),
+    )
+
+    mask = self._get_mask(x.shape[1], cu_seqlens)
+    full_mask = torch.zeros(x.shape[:2])
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(
+          x,
+          rope=rope,
+          mask=full_mask if i in self.config.full_atten_block_indexes else mask,
+      )
+
+    y = self.merger.forward(self.final_norm(x))
+    # Arrange the output back to the original order.
+    reverse_index = torch.argsort(window_index)
+    return y[reverse_index, ...]
+
+  def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    """Get RoPE for Qwen VL model based on image grid information.
+
+    It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
+    modified accordingly.
+    """
+    pos_ids = []
+    for t, h, w in grid_thw:
+      hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+      hpos_ids = hpos_ids.reshape(
+          h // self.config.spatial_merge_size,
+          self.config.spatial_merge_size,
+          w // self.config.spatial_merge_size,
+          self.config.spatial_merge_size,
+      )
+      hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+      hpos_ids = hpos_ids.flatten()
+
+      wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+      wpos_ids = wpos_ids.reshape(
+          h // self.config.spatial_merge_size,
+          self.config.spatial_merge_size,
+          w // self.config.spatial_merge_size,
+          self.config.spatial_merge_size,
+      )
+      wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+      wpos_ids = wpos_ids.flatten()
+      pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+    pos_ids = torch.cat(pos_ids, dim=0)
+    max_grid_size = grid_thw[:, 1:].max()
+
+    cos, sin = attention_utils.build_rope_cache(
+        max_grid_size,
+        # ROPE parameters for all attn_configs are the same. Take the first one.
+        self.config.block_config(0).attn_config.head_dim // 2,
+    )
+    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
+
+  def _get_window_index(self, grid_thw: torch.Tensor):
+    """Get window index for Qwen VL model to rearrange the input tensor.
+
+    It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
+    and modified accordingly.
+    """
+    window_index: list = []
+    cu_window_seqlens: list = [0]
+    window_index_id = 0
+    vit_merger_window_size = (
+        self.config.window_size
+        // self.config.spatial_merge_size
+        // self.config.image_embedding.patch_size
+    )
+
+    for grid_t, grid_h, grid_w in grid_thw:
+      llm_grid_h, llm_grid_w = (
+          grid_h // self.config.spatial_merge_size,
+          grid_w // self.config.spatial_merge_size,
+      )
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+          grid_t, llm_grid_h, llm_grid_w
+      )
+      pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+      pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+      num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+      num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+      index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+      index_padded = index_padded.reshape(
+          grid_t,
+          num_windows_h,
+          vit_merger_window_size,
+          num_windows_w,
+          vit_merger_window_size,
+      )
+      index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+          grid_t,
+          num_windows_h * num_windows_w,
+          vit_merger_window_size,
+          vit_merger_window_size,
+      )
+      seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+      index_padded = index_padded.reshape(-1)
+      index_new = index_padded[index_padded != -100]
+      window_index.append(index_new + window_index_id)
+      spatial_merge_unit = (
+          self.config.spatial_merge_size * self.config.spatial_merge_size
+      )
+      cu_seqlens_tmp = (
+          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
+      )
+      cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+
+    window_index = torch.cat(window_index, dim=0)
+    cu_window_seqlens = torch.tensor(cu_window_seqlens)
+    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+    return window_index, cu_window_seqlens
+
+  def _rearrange(
+      self, x: torch.Tensor, window_index: torch.Tensor
+  ) -> torch.Tensor:
+    """Rearrange the tensor according to window_index.
+
+    It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
+    modified accordingly.
+    """
+    size = x.shape[0]
+    spatial_merge_unit = (
+        self.config.spatial_merge_size * self.config.spatial_merge_size
+    )
+    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+    x_rearranged = x_reshaped[window_index, ...]
+    return x_rearranged.view(size, -1)
+
+  def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
+    """Get attention mask for Qwen VL model.
+
+    It's copied from Qwen2_5_VLVisionAttention.forward() and modified
+    accordingly.
+    """
+    mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
+    for i in range(1, len(cu_seqlens)):
+      mask[
+          ...,
+          cu_seqlens[i - 1] : cu_seqlens[i],
+          cu_seqlens[i - 1] : cu_seqlens[i],
+      ] = 0
+    return mask
+
+
+def get_image_encoder_config() -> QwenVLImageConfig:
+  """Returns the model config for the image encoder of a Qwen 2.5 VL model.
+
+  Returns:
+    The model config for the image encoder of a Qwen 2.5 VL model.
+  """
+  image_embedding_config = cfg.ImageEmbeddingConfig(
+      channels=3,
+      image_size=0,  # Not used in image encoder.
+      patch_size=14,
+      temporal_patch_size=2,
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=80,
+      num_query_groups=16,
+      qkv_transpose_before_split=True,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=3420,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  merger_config = QwenVLMergerConfig(
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU),
+      intermediate_size=5120,  # embedding_dim(1280) * spatial_merge_size(2)^2
+      out_embedding_dim=2048,  # embedding_dim of decoder config.
+      use_bias=True,
+  )
+  config = QwenVLImageConfig(
+      vocab_size=0,  # Not used in image encoder.
+      num_layers=32,
+      max_seq_len=0,  # Not used in image encoder.
+      embedding_dim=1280,
+      image_embedding=image_embedding_config,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      merger_config=merger_config,
+      window_size=112,
+      spatial_merge_size=2,
+      full_atten_block_indexes=[7, 15, 23, 31],
+      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
+      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
+      # enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_image_encoder_config() -> QwenVLImageConfig:
+  config = get_image_encoder_config()
+  # PaliGemma image encoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.image_embedding.patch_size = 2
+  config.num_layers = 2
+  config.merger_config.intermediate_size = 128
+  return config
+
+
+def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
+  config = get_image_encoder_config()
+  encoder = QwenVLImageEncoder(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loose the strictness because only image encoder is being loaded.
+  loader.load(encoder, strict=False)
+
+  # Load merger weights.
+  merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
+  state = merger_loader.get_state()
+  w1_state = dict()
+  w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
+  if config.merger_config.use_bias:
+    w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
+  encoder.merger.w1.load_state_dict(w1_state)
+
+  w2_state = dict()
+  w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
+  if config.merger_config.use_bias:
+    w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
+  encoder.merger.w2.load_state_dict(w2_state)
+
+  encoder.eval()
+  return encoder
ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py ADDED
@@ -0,0 +1,84 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored image encoder of Qwen 2.5 VL model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+
+def main(_):
+  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = (
+      transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+          checkpoint
+      )
+  )
+  original_vision_model = original_model.eval().visual
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  image_input = processor(images=image, text="", return_tensors="pt")
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_vision_model.forward(
+      image_input["pixel_values"], image_input["image_grid_thw"]
+  )
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(
+      image_input["pixel_values"], image_input["image_grid_thw"]
+  )
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-05
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -224,7 +224,6 @@ class CausalSelfAttention(nn.Module):
 
     if rope is not None:
       # Compute rotary positional embedding for query and key.
-      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -177,6 +177,8 @@ class ImageEmbeddingConfig:
   # All images should be normalized to the size of [image_size * image_size].
   image_size: int
   patch_size: int
+  # Meaningful only when image embedding is Conv3d.
+  temporal_patch_size: Optional[int] = None
 
 
 @dataclasses.dataclass
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250203
+Version: 0.3.0.dev20250204
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250203.dist-info → ai_edge_torch_nightly-0.3.0.dev20250204.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=4XOGz1x6yfOnkOtBndF7qE1L3Ma12ZMJNwQ7wIWkyEs,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -74,9 +74,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=S_W-0ojRu2Vd5SLNPs1kC-70xHB8AdSWslm-yPxyezk,5478
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=W009ky-yobueTzdaybSCqBAvNyArLXW3jDyp5MarzZU,6376
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -95,7 +95,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=QIPbcturxn5OaVsF5zkRRsyAvCM2Bojyz9XBekHOaro,13405
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
+ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
 ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -131,13 +133,13 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=Pm8FLKh-NnOvUjqQC9oX5oghPbdivZvlPVkgOVTShoU,13703
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=Yqa3wqZLBe0Lj4PPTIaVFaZ--sV6NJ6k8KPjRguDvCc,8095
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -225,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/METADATA,sha256=Rf4w5EMQlNWOoFIuVlXUZPU9vmXlOJW7oB4yPrtgK0c,1966
+ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/RECORD,,