ai-edge-torch-nightly 0.3.0.dev20250213__py3-none-any.whl → 0.3.0.dev20250215__py3-none-any.whl
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +75 -51
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +16 -12
- ai_edge_torch/generative/examples/qwen_vl/verify.py +6 -8
- ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py +5 -2
- ai_edge_torch/generative/layers/model_config.py +6 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +54 -6
- ai_edge_torch/generative/utilities/converter.py +1 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/RECORD +14 -13
- {ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py ADDED
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen_vl',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_IMAGE_HEIGHT = flags.DEFINE_integer(
+    'image_height',
+    34 * 14,
+    'The height of image.',
+)
+_IMAGE_WIDTH = flags.DEFINE_integer(
+    'image_width',
+    46 * 14,
+    'The width of image.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = qwen_vl.build_model(
+      _CHECKPOINT_PATH.value,
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
+  )
+
+  grid_thw = pytorch_model.image_encoder.get_grid_thw()
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=(
+          pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
+      ),
+      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
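For reference, a minimal worked sketch (not part of the package) of the tensor shapes that the default flags above imply, mirroring the get_grid_thw() and get_pixel_values_size() helpers added to image_encoder.py below; all constants are the flag defaults and the patch configuration from this diff:

# Illustrative only: shapes implied by the default flags above.
patch_size = 14
temporal_patch_size = 2
channels = 3
height, width = 34 * 14, 46 * 14                             # --image_height, --image_width defaults
grid_thw = [(1, height // patch_size, width // patch_size)]  # [(1, 34, 46)]
dim_0 = sum(t * h * w for t, h, w in grid_thw)               # 1564 patches in total
dim_1 = channels * temporal_patch_size * patch_size**2       # 1176 values per patch
pixel_values_size = (dim_0, dim_1)                           # (1564, 1176) passed to convert_to_tflite()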
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py CHANGED
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""

 import dataclasses
-from typing import Optional
+from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -93,7 +93,7 @@ class QwenVLImageEncoder(nn.Module):

     # Tensor shape used to reshape pixel_values in forward() and various places.
     self.kernel_size = (
-        -1,  #
+        -1,  # pixel_values.size(0)
         config.image_embedding.channels,
         config.image_embedding.temporal_patch_size,
         config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ class QwenVLImageEncoder(nn.Module):
     )
     self.merger = QwenVLMerger(config)
     self.config = config
+    self.set_image_size(config.image_embedding.image_size)

   @torch.inference_mode
-  def forward(
-
-
-    # Get window index and sequence lengths to rearrange the input tensor.
-    window_index, cu_seqlens = self._get_window_index(grid_thw)
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Check if the pixel value size matches with grid size and image config.
+    assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)

     # Embed the image and rearrange the embedding tensor.
-    pixel_reshaped = pixel_values.
+    pixel_reshaped = pixel_values.reshape(self.kernel_size)
     x = self.tok_embedding(pixel_reshaped)
     x = x.view(-1, self.config.embedding_dim)
-    x = self._rearrange(x, window_index).unsqueeze(0)
+    x = self._rearrange(x, self.window_index).unsqueeze(0)

-
-    cos, sin = self._get_rope(grid_thw)
-    rope = (
-        self._rearrange(cos, window_index),
-        self._rearrange(sin, window_index),
-    )
+    rope = self._get_rope(self.grid_thw, self.window_index)

-    mask = self._get_mask(
+    mask = self._get_mask(self.grid_thw, self.cu_seqlens)
     full_mask = torch.zeros(x.shape[:2])
     for i, block in enumerate(self.transformer_blocks):
       x = block(
@@ -150,10 +144,42 @@ class QwenVLImageEncoder(nn.Module):

     y = self.merger.forward(self.final_norm(x))
     # Arrange the output back to the original order.
-    reverse_index
-
-
-
+    return y[self.reverse_index, ...]
+
+  def set_image_size(self, image_size: Tuple[int, int]):
+    """Set the image size and pre-calculate some values including mask."""
+    self.config.image_embedding.image_size = image_size
+    self.grid_thw = self.get_grid_thw()
+
+    # Precalculate the window index which can't be lowered to HLO because of
+    # inconcrete index in:
+    #   index_new = index_padded[index_padded != -100]
+    self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
+
+    # Precalculate the reverse index of window_index until "vhlo.sort_v1" op is
+    # supported.
+    self.reverse_index = torch.argsort(self.window_index)
+
+  def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
+    """Calculate the grid size of the input images based on the image config."""
+    height, width = self.config.image_embedding.image_size
+    patch_height = height // self.config.image_embedding.patch_size
+    patch_width = width // self.config.image_embedding.patch_size
+    # Support only image, i.e. temporal step size is always 1.
+    return [(1, patch_height, patch_width)] * num_images
+
+  def get_pixel_values_size(
+      self, grid_thw: List[Tuple[int, int, int]]
+  ) -> torch.Size:
+    """Calculate the size of pixel values tensor."""
+    dim_0 = sum(t * h * w for t, h, w in grid_thw)
+    config = self.config.image_embedding
+    dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
+    return torch.Size((dim_0, dim_1))
+
+  def _get_rope(
+      self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Get RoPE for Qwen VL model based on image grid information.

     It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
@@ -182,16 +208,20 @@ class QwenVLImageEncoder(nn.Module):
       wpos_ids = wpos_ids.flatten()
       pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
-
+    # Assume all the heights and widths are the same for all images.
+    max_grid_size = max(grid_thw[0][1], grid_thw[0][2])

     cos, sin = attention_utils.build_rope_cache(
         max_grid_size,
         # ROPE parameters for all attn_configs are the same. Take the first one.
         self.config.block_config(0).attn_config.head_dim // 2,
     )
-    return
+    return (
+        self._rearrange(cos[pos_ids].flatten(1), window_index),
+        self._rearrange(sin[pos_ids].flatten(1), window_index),
+    )

-  def _get_window_index(self, grid_thw:
+  def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
     """Get window index for Qwen VL model to rearrange the input tensor.

     It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ class QwenVLImageEncoder(nn.Module):
     )

     for grid_t, grid_h, grid_w in grid_thw:
-      llm_grid_h
-
-
-      )
-      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-          grid_t, llm_grid_h, llm_grid_w
-      )
+      llm_grid_h = grid_h // self.config.spatial_merge_size
+      llm_grid_w = grid_w // self.config.spatial_merge_size
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
+      index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
       pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
       pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
       num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -236,18 +263,14 @@ class QwenVLImageEncoder(nn.Module):
       index_padded = index_padded.reshape(-1)
       index_new = index_padded[index_padded != -100]
       window_index.append(index_new + window_index_id)
-      spatial_merge_unit =
-          self.config.spatial_merge_size * self.config.spatial_merge_size
-      )
+      spatial_merge_unit = self.config.spatial_merge_size**2
       cu_seqlens_tmp = (
           seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
       )
       cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-      window_index_id +=
+      window_index_id += grid_t * llm_grid_h * llm_grid_w

     window_index = torch.cat(window_index, dim=0)
-    cu_window_seqlens = torch.tensor(cu_window_seqlens)
-    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
     return window_index, cu_window_seqlens

   def _rearrange(
@@ -258,20 +281,20 @@ class QwenVLImageEncoder(nn.Module):
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
     modified accordingly.
     """
-
-
-        self.config.spatial_merge_size * self.config.spatial_merge_size
-    )
-    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+    spatial_merge_unit = self.config.spatial_merge_size**2
+    x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
     x_rearranged = x_reshaped[window_index, ...]
-    return x_rearranged.view(
+    return x_rearranged.view(x.shape)

-  def _get_mask(
+  def _get_mask(
+      self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
+  ) -> torch.Tensor:
     """Get attention mask for Qwen VL model.

     It's copied from Qwen2_5_VLVisionAttention.forward() and modified
     accordingly.
     """
+    seqlen = self.get_pixel_values_size(grid_thw)[0]
     mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
     for i in range(1, len(cu_seqlens)):
       mask[
@@ -282,7 +305,7 @@ class QwenVLImageEncoder(nn.Module):
     return mask


-def get_image_encoder_config() -> QwenVLImageConfig:
+def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
   """Returns the model config for the image encoder of a Qwen 2.5 VL model.

   Returns:
@@ -290,7 +313,7 @@ def get_image_encoder_config() -> QwenVLImageConfig:
   """
   image_embedding_config = cfg.ImageEmbeddingConfig(
       channels=3,
-      image_size=
+      image_size=image_size,
       patch_size=14,
       temporal_patch_size=2,
   )
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config


 def get_fake_image_encoder_config() -> QwenVLImageConfig:
-  config = get_image_encoder_config()
+  config = get_image_encoder_config((8, 12))
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
   return config


-def build_image_encoder(
-
+def build_image_encoder(
+    checkpoint_path: str,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLImageEncoder:
+  config = get_image_encoder_config(image_size)
   encoder = QwenVLImageEncoder(config)
   load_image_encoder(checkpoint_path, encoder)
   encoder.eval()
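The net effect of the image_encoder.py changes is that the window index, cumulative sequence lengths, reverse index, and attention mask depend only on the configured image size, so they are precomputed in set_image_size() and forward() now takes just pixel_values. A hedged usage sketch, using only names from the diff above; the image sizes are illustrative and no checkpoint is loaded (build_image_encoder() would be the checkpoint-loading path):

# Usage sketch only; values are illustrative, weights are randomly initialized.
import torch

from ai_edge_torch.generative.examples.qwen_vl import image_encoder

config = image_encoder.get_image_encoder_config((24 * 14, 32 * 14))
encoder = image_encoder.QwenVLImageEncoder(config).eval()

# set_image_size() already runs in __init__ and precomputes grid_thw,
# window_index, cu_seqlens and reverse_index; call it again only when the
# input resolution changes.
encoder.set_image_size((34 * 14, 46 * 14))

pixel_values = torch.zeros(encoder.get_pixel_values_size(encoder.grid_thw))
image_embeds = encoder(pixel_values)  # forward() now takes only pixel_values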
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py CHANGED
@@ -61,7 +61,6 @@ class QwenVL(nn.Module):
       kv_cache: kv_utils.KVCache,
       mask: Optional[torch.Tensor] = None,
       pixel_values: torch.Tensor = None,
-      grid_thw: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
@@ -69,14 +68,14 @@ class QwenVL(nn.Module):
           tokens=tokens,
           input_pos=input_pos,
           kv_cache=kv_cache,
-          mask=mask,
-          rope=self._build_text_rope(input_pos),
           input_embeds=None,
+          rope=self._build_text_rope(input_pos),
+          mask=mask,
           export_config=export_config,
       )

     input_embeds = self.decoder.tok_embedding(tokens)
-    image_embeds = self.image_encoder(pixel_values
+    image_embeds = self.image_encoder(pixel_values).unsqueeze(0)

     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -92,18 +91,19 @@ class QwenVL(nn.Module):
         (
             input_embeds[:, :1, :],
             image_embeds,
-            input_embeds[:, image_embeds.
+            input_embeds[:, image_embeds.size(1) + 1 :, :],
         ),
         dim=1,
     )

+    grid_thw = self.image_encoder.get_grid_thw()
     return self.decoder(
         tokens=None,
         input_pos=input_pos,
         kv_cache=kv_cache,
-        mask=mask,
         input_embeds=input_embeds,
         rope=self._build_multimodal_rope(input_pos, grid_thw),
+        mask=mask,
         export_config=export_config,
     )

@@ -120,9 +120,9 @@ class QwenVL(nn.Module):
   def _build_text_rope(
       self, input_pos: torch.Tensor
   ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Reset rope_pos_adjust to 0 when
-    #
-    if input_pos
+    # Reset rope_pos_adjust to 0 when it's prefill, i.e. input has 2 or more
+    # tokens.
+    if input_pos.numel() > 1:
       self.rope_pos_adjust = 0
     return self._build_rope(input_pos + self.rope_pos_adjust)

@@ -178,15 +178,18 @@ class QwenVL(nn.Module):
     return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)


-def get_model_config(
+def get_model_config(
+    kv_cache_max_len: int = 1024,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
     The model config for a PaliGemma 3B model.
   """
   return QwenVLConfig(
-      image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(
+      image_encoder_config=image_encoder.get_image_encoder_config(image_size),
+      decoder_config=decoder.get_decoder_config(kv_cache_max_len),
       image_token_id=151655,
       mrope_section=[16, 24, 24],
   )
@@ -197,6 +200,7 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(**kwargs),
       image_token_id=127,
+      mrope_section=[16, 24, 24],
   )

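One detail worth calling out in qwen_vl.py: grid_thw is no longer a forward() argument (it is derived from the image encoder), and the image embeddings are spliced in right after the first token, replacing exactly image_embeds.size(1) placeholder positions. An illustrative sketch of that splice (shapes are made up; the slicing mirrors the diff above):

# Illustrative only: assumes a prompt layout of
# [1 leading token][N image placeholder tokens][text tokens ...].
import torch

input_embeds = torch.zeros(1, 12, 8)  # 1 leading + 8 image placeholders + 3 text
image_embeds = torch.zeros(1, 8, 8)   # output of the image encoder

merged = torch.cat(
    (
        input_embeds[:, :1, :],                          # leading token
        image_embeds,                                    # replaces the placeholders
        input_embeds[:, image_embeds.size(1) + 1 :, :],  # remaining text tokens
    ),
    dim=1,
)
assert merged.shape == input_embeds.shape  # (1, 12, 8)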
ai_edge_torch/generative/examples/qwen_vl/verify.py CHANGED
@@ -17,6 +17,7 @@

 import logging
 import pathlib
+
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
@@ -47,16 +48,9 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored Qwen VL model wrapper."""

-  def __init__(self, model: torch.nn.Module):
-    super().__init__(model)
-    self.grid_thw = None
-
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)

-  def _get_extra_args_for_forward(self):
-    return {"grid_thw": self.grid_thw}
-

 def main(_):
   checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -94,7 +88,11 @@ def main(_):

   logging.info("Forwarding the reauthored model...")
   wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
-
+  grid_thw = inputs["image_grid_thw"].tolist()
+  config = reauthored_model.config.image_encoder_config.image_embedding
+  reauthored_model.image_encoder.set_image_size(
+      (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
+  )
   outputs_reauthored = wrapped_reauthored_model.forward(
       tokens=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
|
@@ -64,9 +64,12 @@ def main(_):
|
|
64
64
|
logging.info("outputs_original: %s", outputs_original)
|
65
65
|
|
66
66
|
logging.info("Forwarding the reauthored model...")
|
67
|
-
|
68
|
-
|
67
|
+
grid_thw = image_input["image_grid_thw"].tolist()
|
68
|
+
config = reauthored_model.config.image_embedding
|
69
|
+
reauthored_model.set_image_size(
|
70
|
+
(grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
|
69
71
|
)
|
72
|
+
outputs_reauthored = reauthored_model.forward(image_input["pixel_values"])
|
70
73
|
logging.info("outputs_reauthored: %s", outputs_reauthored)
|
71
74
|
|
72
75
|
try:
|
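Both verify scripts recover the pixel resolution from the Hugging Face processor's image_grid_thw, which is given in patch units, before calling set_image_size(). A small sketch with made-up grid values:

# Example values only; image_grid_thw comes from the HF processor in practice.
patch_size = 14
image_grid_thw = [[1, 36, 52]]  # (temporal, height, width) in patch units

image_size = (
    image_grid_thw[0][1] * patch_size,  # height in pixels: 504
    image_grid_thw[0][2] * patch_size,  # width in pixels: 728
)
# ...then passed to set_image_size(image_size) on the reauthored encoder.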
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,7 +17,7 @@

 import dataclasses
 import enum
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Tuple, Union
 from ai_edge_torch.generative.layers import rotary_position_embedding

 @enum.unique
@@ -174,8 +174,10 @@ class ImageEmbeddingConfig:
   """Image embedding parameters."""

   channels: int
-  # All images should be normalized to
-  image_size
+  # All images should be normalized to image_size * image_size if image_size is
+  # a single integer, or image_size[0] (height) * image_size[1] (width) if
+  # image_size is a tuple of 2 integers.
+  image_size: Union[int | Tuple[int, int]]
   patch_size: int
   # Meaningful only when image embedding is Conv3d.
   temporal_patch_size: Optional[int] = None
@@ -205,7 +207,7 @@ class ModelConfig:
   embedding_use_bias: bool = False
   # Image embedding parameters.
   image_embedding: Optional[ImageEmbeddingConfig] = None
-  # Number of image tokens
+  # Number of image tokens
   num_mm_tokens_per_image: Optional[int] = None
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
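With this change, ImageEmbeddingConfig.image_size accepts either a single int (square images, as before) or a (height, width) tuple. A minimal sketch of both forms; the alias cfg is an assumption that follows the model_config import convention used in image_encoder.py above:

# Sketch only; field names are from the dataclass above, values illustrative.
from ai_edge_torch.generative.layers import model_config as cfg

square = cfg.ImageEmbeddingConfig(
    channels=3,
    image_size=224,                 # square input: 224 x 224
    patch_size=14,
)
rectangular = cfg.ImageEmbeddingConfig(
    channels=3,
    image_size=(34 * 14, 46 * 14),  # (height, width) for non-square input
    patch_size=14,
    temporal_patch_size=2,
)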
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -28,6 +28,7 @@ from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -196,17 +197,15 @@ class TestModelConversion(googletest.TestCase):
     config = paligemma.get_fake_model_config(decoder_config)
     pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()

-
-    num_patches = (
-        image_embedding_config.image_size // image_embedding_config.patch_size
-    ) ** 2
+    image_config = config.image_encoder_config.image_embedding
+    num_patches = (image_config.image_size // image_config.patch_size) ** 2

     # Make sure the token size is longer than the number of image patches.
     seq_len = num_patches + 10
-    tokens = torch.zeros((1, seq_len), dtype=torch.int
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
-    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)

     edge_model = ai_edge_torch.signature(
         "prefill_pixel",
@@ -258,6 +257,55 @@ class TestModelConversion(googletest.TestCase):
             rtol=1e-5,
     )

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_qwen_vl_model(self):
+    config = qwen_vl.get_fake_model_config()
+    pytorch_model = qwen_vl.QwenVL(config).eval()
+
+    grid_thw = pytorch_model.image_encoder.get_grid_thw()
+    pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(
+        grid_thw
+    )
+
+    # Make sure the token size is longer than the number of pixel values.
+    seq_len = pixel_values_size[0] + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -170,7 +170,7 @@ def _export_helper(

   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
-  mod = ExportableModule(pytorch_model, export_config=export_config)
+  mod = ExportableModule(pytorch_model, export_config=export_config).eval()

   converter = converter_utils.Converter()
   for lora in loras:
ai_edge_torch/version.py CHANGED
{ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250213
+Version: 0.3.0.dev20250215
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250213.dist-info → ai_edge_torch_nightly-0.3.0.dev20250215.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=HRjjQujR7rDiLW1Mt_3LZQYVxZd2h-YktOT8MeVMmTc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,12 +94,13 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=MXK75-Upoq_RhCbiXJEl8SKJ-msmvpVivsgfqqy-cfg,2780
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=rcYHkpO-NbF4F1Da7q2xNiTng9NHiLx59HyuOgQX5W0,7753
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=cKinMEDXauR5yKxtNTQk1RvwIHUG8-FOkmAie18sukY,5039
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
-ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
 ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -141,7 +142,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -167,12 +168,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bXJwDxSPgxVKp-_6BsEmMA3TuMUaUNiZoYomNounxco,14416
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -229,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/METADATA,sha256=pZGTxEsYT2Tx_2xma-wcLYoLDVDqZs3lw-3sAIhUhPs,1966
+ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/RECORD,,