ai-edge-torch-nightly 0.3.0.dev20250214__py3-none-any.whl → 0.3.0.dev20250216__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +75 -51
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +16 -12
- ai_edge_torch/generative/examples/qwen_vl/verify.py +6 -8
- ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py +5 -2
- ai_edge_torch/generative/layers/model_config.py +6 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +54 -6
- ai_edge_torch/generative/utilities/converter.py +1 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/RECORD +14 -13
- {ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
ADDED
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen_vl',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_IMAGE_HEIGHT = flags.DEFINE_integer(
+    'image_height',
+    34 * 14,
+    'The height of image.',
+)
+_IMAGE_WIDTH = flags.DEFINE_integer(
+    'image_width',
+    46 * 14,
+    'The width of image.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = qwen_vl.build_model(
+      _CHECKPOINT_PATH.value,
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
+  )
+
+  grid_thw = pytorch_model.image_encoder.get_grid_thw()
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=(
+          pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
+      ),
+      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)

ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
CHANGED
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -93,7 +93,7 @@ class QwenVLImageEncoder(nn.Module):
 
     # Tensor shape used to reshape pixel_values in forward() and various places.
     self.kernel_size = (
-        -1,  #
+        -1,  # pixel_values.size(0)
         config.image_embedding.channels,
         config.image_embedding.temporal_patch_size,
         config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ class QwenVLImageEncoder(nn.Module):
     )
     self.merger = QwenVLMerger(config)
     self.config = config
+    self.set_image_size(config.image_embedding.image_size)
 
   @torch.inference_mode
-  def forward(
-
-
-    # Get window index and sequence lengths to rearrange the input tensor.
-    window_index, cu_seqlens = self._get_window_index(grid_thw)
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Check if the pixel value size matches with grid size and image config.
+    assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)
 
     # Embed the image and rearrange the embedding tensor.
-    pixel_reshaped = pixel_values.
+    pixel_reshaped = pixel_values.reshape(self.kernel_size)
     x = self.tok_embedding(pixel_reshaped)
     x = x.view(-1, self.config.embedding_dim)
-    x = self._rearrange(x, window_index).unsqueeze(0)
+    x = self._rearrange(x, self.window_index).unsqueeze(0)
 
-
-    cos, sin = self._get_rope(grid_thw)
-    rope = (
-        self._rearrange(cos, window_index),
-        self._rearrange(sin, window_index),
-    )
+    rope = self._get_rope(self.grid_thw, self.window_index)
 
-    mask = self._get_mask(
+    mask = self._get_mask(self.grid_thw, self.cu_seqlens)
     full_mask = torch.zeros(x.shape[:2])
     for i, block in enumerate(self.transformer_blocks):
       x = block(
@@ -150,10 +144,42 @@ class QwenVLImageEncoder(nn.Module):
 
     y = self.merger.forward(self.final_norm(x))
     # Arrange the output back to the original order.
-    reverse_index
-
-
-
+    return y[self.reverse_index, ...]
+
+  def set_image_size(self, image_size: Tuple[int, int]):
+    """Set the image size and pre-calculate some values including mask."""
+    self.config.image_embedding.image_size = image_size
+    self.grid_thw = self.get_grid_thw()
+
+    # Precalculate the window index which can't be lowered to HLO because of
+    # inconcrete index in:
+    # index_new = index_padded[index_padded != -100]
+    self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
+
+    # Precalculate the reverse index of window_index until "vhlo.sort_v1" op is
+    # supported.
+    self.reverse_index = torch.argsort(self.window_index)
+
+  def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
+    """Calculate the grid size of the input images based on the image config."""
+    height, width = self.config.image_embedding.image_size
+    patch_height = height // self.config.image_embedding.patch_size
+    patch_width = width // self.config.image_embedding.patch_size
+    # Support only image, i.e. temporal step size is always 1.
+    return [(1, patch_height, patch_width)] * num_images
+
+  def get_pixel_values_size(
+      self, grid_thw: List[Tuple[int, int, int]]
+  ) -> torch.Size:
+    """Calculate the size of pixel values tensor."""
+    dim_0 = sum(t * h * w for t, h, w in grid_thw)
+    config = self.config.image_embedding
+    dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
+    return torch.Size((dim_0, dim_1))
+
+  def _get_rope(
+      self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Get RoPE for Qwen VL model based on image grid information.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
@@ -182,16 +208,20 @@ class QwenVLImageEncoder(nn.Module):
       wpos_ids = wpos_ids.flatten()
       pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
-
+    # Assume all the heights and widths are the same for all images.
+    max_grid_size = max(grid_thw[0][1], grid_thw[0][2])
 
     cos, sin = attention_utils.build_rope_cache(
         max_grid_size,
        # ROPE parameters for all attn_configs are the same. Take the first one.
        self.config.block_config(0).attn_config.head_dim // 2,
     )
-    return
+    return (
+        self._rearrange(cos[pos_ids].flatten(1), window_index),
+        self._rearrange(sin[pos_ids].flatten(1), window_index),
+    )
 
-  def _get_window_index(self, grid_thw:
+  def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
     """Get window index for Qwen VL model to rearrange the input tensor.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ class QwenVLImageEncoder(nn.Module):
     )
 
     for grid_t, grid_h, grid_w in grid_thw:
-      llm_grid_h
-
-
-      )
-      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-          grid_t, llm_grid_h, llm_grid_w
-      )
+      llm_grid_h = grid_h // self.config.spatial_merge_size
+      llm_grid_w = grid_w // self.config.spatial_merge_size
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
+      index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
       pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
       pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
       num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -236,18 +263,14 @@ class QwenVLImageEncoder(nn.Module):
       index_padded = index_padded.reshape(-1)
       index_new = index_padded[index_padded != -100]
       window_index.append(index_new + window_index_id)
-      spatial_merge_unit =
-          self.config.spatial_merge_size * self.config.spatial_merge_size
-      )
+      spatial_merge_unit = self.config.spatial_merge_size**2
      cu_seqlens_tmp = (
          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
      )
      cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-      window_index_id +=
+      window_index_id += grid_t * llm_grid_h * llm_grid_w
 
     window_index = torch.cat(window_index, dim=0)
-    cu_window_seqlens = torch.tensor(cu_window_seqlens)
-    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
     return window_index, cu_window_seqlens
 
   def _rearrange(
@@ -258,20 +281,20 @@ class QwenVLImageEncoder(nn.Module):
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
     modified accordingly.
     """
-
-
-        self.config.spatial_merge_size * self.config.spatial_merge_size
-    )
-    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+    spatial_merge_unit = self.config.spatial_merge_size**2
+    x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
     x_rearranged = x_reshaped[window_index, ...]
-    return x_rearranged.view(
+    return x_rearranged.view(x.shape)
 
-  def _get_mask(
+  def _get_mask(
+      self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
+  ) -> torch.Tensor:
     """Get attention mask for Qwen VL model.
 
     It's copied from Qwen2_5_VLVisionAttention.forward() and modified
     accordingly.
     """
+    seqlen = self.get_pixel_values_size(grid_thw)[0]
     mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
     for i in range(1, len(cu_seqlens)):
       mask[
@@ -282,7 +305,7 @@ class QwenVLImageEncoder(nn.Module):
     return mask
 
 
-def get_image_encoder_config() -> QwenVLImageConfig:
+def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
   """Returns the model config for the image encoder of a Qwen 2.5 VL model.
 
   Returns:
@@ -290,7 +313,7 @@ def get_image_encoder_config() -> QwenVLImageConfig:
   """
   image_embedding_config = cfg.ImageEmbeddingConfig(
       channels=3,
-      image_size=
+      image_size=image_size,
      patch_size=14,
      temporal_patch_size=2,
   )
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
      window_size=112,
      spatial_merge_size=2,
      full_atten_block_indexes=[7, 15, 23, 31],
-
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config
 
 
 def get_fake_image_encoder_config() -> QwenVLImageConfig:
-  config = get_image_encoder_config()
+  config = get_image_encoder_config((8, 12))
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
   return config
 
 
-def build_image_encoder(
-
+def build_image_encoder(
+    checkpoint_path: str,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLImageEncoder:
+  config = get_image_encoder_config(image_size)
   encoder = QwenVLImageEncoder(config)
   load_image_encoder(checkpoint_path, encoder)
   encoder.eval()

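As a quick sanity check on the new helpers above, the grid and pixel-value sizes reduce to simple patch arithmetic. Below is a minimal standalone sketch (not part of the package) using the defaults from get_image_encoder_config() (patch_size=14, temporal_patch_size=2, channels=3) and the default 34*14 x 46*14 image; the constant and function names here are illustrative only.

# Illustrative re-derivation of get_grid_thw() / get_pixel_values_size().
# Names are hypothetical; values mirror get_image_encoder_config().
PATCH_SIZE = 14          # patch_size
TEMPORAL_PATCH_SIZE = 2  # temporal_patch_size
CHANNELS = 3             # channels


def grid_thw(height, width, num_images=1):
  # Images only, so the temporal dimension is always 1.
  return [(1, height // PATCH_SIZE, width // PATCH_SIZE)] * num_images


def pixel_values_size(grid):
  dim_0 = sum(t * h * w for t, h, w in grid)
  dim_1 = CHANNELS * TEMPORAL_PATCH_SIZE * PATCH_SIZE**2
  return (dim_0, dim_1)


grid = grid_thw(34 * 14, 46 * 14)   # [(1, 34, 46)]
print(pixel_values_size(grid))      # (1564, 1176)

Under the default image_height/image_width flags in convert_to_tflite.py, the exported prefill_pixel signature should therefore see a (1564, 1176) pixel_values tensor.
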
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
CHANGED
@@ -61,7 +61,6 @@ class QwenVL(nn.Module):
      kv_cache: kv_utils.KVCache,
      mask: Optional[torch.Tensor] = None,
      pixel_values: torch.Tensor = None,
-      grid_thw: torch.Tensor = None,
      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
@@ -69,14 +68,14 @@ class QwenVL(nn.Module):
          tokens=tokens,
          input_pos=input_pos,
          kv_cache=kv_cache,
-          mask=mask,
-          rope=self._build_text_rope(input_pos),
          input_embeds=None,
+          rope=self._build_text_rope(input_pos),
+          mask=mask,
          export_config=export_config,
      )
 
     input_embeds = self.decoder.tok_embedding(tokens)
-    image_embeds = self.image_encoder(pixel_values
+    image_embeds = self.image_encoder(pixel_values).unsqueeze(0)
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -92,18 +91,19 @@ class QwenVL(nn.Module):
        (
            input_embeds[:, :1, :],
            image_embeds,
-            input_embeds[:, image_embeds.
+            input_embeds[:, image_embeds.size(1) + 1 :, :],
        ),
        dim=1,
    )
 
+    grid_thw = self.image_encoder.get_grid_thw()
     return self.decoder(
        tokens=None,
        input_pos=input_pos,
        kv_cache=kv_cache,
-        mask=mask,
        input_embeds=input_embeds,
        rope=self._build_multimodal_rope(input_pos, grid_thw),
+        mask=mask,
        export_config=export_config,
    )
 
@@ -120,9 +120,9 @@ class QwenVL(nn.Module):
   def _build_text_rope(
      self, input_pos: torch.Tensor
   ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Reset rope_pos_adjust to 0 when
-    #
-    if input_pos
+    # Reset rope_pos_adjust to 0 when it's prefill, i.e. input has 2 or more
+    # tokens.
+    if input_pos.numel() > 1:
      self.rope_pos_adjust = 0
     return self._build_rope(input_pos + self.rope_pos_adjust)
 
@@ -178,15 +178,18 @@ class QwenVL(nn.Module):
     return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
 
 
-def get_model_config(
+def get_model_config(
+    kv_cache_max_len: int = 1024,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
     The model config for a PaliGemma 3B model.
   """
   return QwenVLConfig(
-      image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(
+      image_encoder_config=image_encoder.get_image_encoder_config(image_size),
+      decoder_config=decoder.get_decoder_config(kv_cache_max_len),
      image_token_id=151655,
      mrope_section=[16, 24, 24],
   )
@@ -197,6 +200,7 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
      decoder_config=decoder.get_fake_decoder_config(**kwargs),
      image_token_id=127,
+      mrope_section=[16, 24, 24],
   )
 
 

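The decoder-side forward keeps the PaliGemma-style splice of image embeddings into the token embeddings (the torch.cat shown above). A minimal standalone sketch of that splice, with illustrative shapes rather than the real model dimensions:

import torch

# Batch 1, 16 token embeddings of width 8; 6 image embeddings replace the
# tokens immediately after position 0. Shapes are illustrative only.
input_embeds = torch.zeros((1, 16, 8))
image_embeds = torch.ones((1, 6, 8))

merged = torch.cat(
    (
        input_embeds[:, :1, :],                          # leading token
        image_embeds,                                    # image embeddings
        input_embeds[:, image_embeds.size(1) + 1 :, :],  # remaining text tokens
    ),
    dim=1,
)
assert merged.shape == input_embeds.shape  # sequence length is unchanged
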
ai_edge_torch/generative/examples/qwen_vl/verify.py
CHANGED
@@ -17,6 +17,7 @@
 
 import logging
 import pathlib
+
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
@@ -47,16 +48,9 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored Qwen VL model wrapper."""
 
-  def __init__(self, model: torch.nn.Module):
-    super().__init__(model)
-    self.grid_thw = None
-
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
-  def _get_extra_args_for_forward(self):
-    return {"grid_thw": self.grid_thw}
-
 
 def main(_):
   checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -94,7 +88,11 @@ def main(_):
 
   logging.info("Forwarding the reauthored model...")
   wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
-
+  grid_thw = inputs["image_grid_thw"].tolist()
+  config = reauthored_model.config.image_encoder_config.image_embedding
+  reauthored_model.image_encoder.set_image_size(
+      (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
+  )
   outputs_reauthored = wrapped_reauthored_model.forward(
      tokens=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],

ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py
CHANGED
@@ -64,9 +64,12 @@ def main(_):
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-
-
+  grid_thw = image_input["image_grid_thw"].tolist()
+  config = reauthored_model.config.image_embedding
+  reauthored_model.set_image_size(
+      (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
   )
+  outputs_reauthored = reauthored_model.forward(image_input["pixel_values"])
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:

ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -17,7 +17,7 @@
 
 import dataclasses
 import enum
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Tuple, Union
 from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
@@ -174,8 +174,10 @@ class ImageEmbeddingConfig:
   """Image embedding parameters."""
 
   channels: int
-  # All images should be normalized to
-  image_size
+  # All images should be normalized to image_size * image_size if image_size is
+  # a single integer, or image_size[0] (height) * image_size[1] (width) if
+  # image_size is a tuple of 2 integers.
+  image_size: Union[int | Tuple[int, int]]
   patch_size: int
   # Meaningful only when image embedding is Conv3d.
   temporal_patch_size: Optional[int] = None
@@ -205,7 +207,7 @@ class ModelConfig:
   embedding_use_bias: bool = False
   # Image embedding parameters.
   image_embedding: Optional[ImageEmbeddingConfig] = None
-  # Number of image tokens
+  # Number of image tokens
   num_mm_tokens_per_image: Optional[int] = None
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False

ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -28,6 +28,7 @@ from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -196,17 +197,15 @@ class TestModelConversion(googletest.TestCase):
     config = paligemma.get_fake_model_config(decoder_config)
     pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
-
-    num_patches = (
-        image_embedding_config.image_size // image_embedding_config.patch_size
-    ) ** 2
+    image_config = config.image_encoder_config.image_embedding
+    num_patches = (image_config.image_size // image_config.patch_size) ** 2
 
     # Make sure the token size is longer than the number of image patches.
     seq_len = num_patches + 10
-    tokens = torch.zeros((1, seq_len), dtype=torch.int
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
-    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)
 
     edge_model = ai_edge_torch.signature(
        "prefill_pixel",
@@ -258,6 +257,55 @@ class TestModelConversion(googletest.TestCase):
            rtol=1e-5,
        )
    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_qwen_vl_model(self):
+    config = qwen_vl.get_fake_model_config()
+    pytorch_model = qwen_vl.QwenVL(config).eval()
+
+    grid_thw = pytorch_model.image_encoder.get_grid_thw()
+    pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(
+        grid_thw
+    )
+
+    # Make sure the token size is longer than the number of pixel values.
+    seq_len = pixel_values_size[0] + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
      ai_edge_torch.config.in_oss,
      reason="tests with custom ops are not supported in oss",

ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -170,7 +170,7 @@ def _export_helper(
 
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
-  mod = ExportableModule(pytorch_model, export_config=export_config)
+  mod = ExportableModule(pytorch_model, export_config=export_config).eval()
 
   converter = converter_utils.Converter()
   for lora in loras:

ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20250216
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20250214.dist-info → ai_edge_torch_nightly-0.3.0.dev20250216.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=vklXbqGLRDju4mlU9vpIceTDodbvgQCmd7eyCsV5ckM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,12 +94,13 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=MXK75-Upoq_RhCbiXJEl8SKJ-msmvpVivsgfqqy-cfg,2780
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=rcYHkpO-NbF4F1Da7q2xNiTng9NHiLx59HyuOgQX5W0,7753
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=cKinMEDXauR5yKxtNTQk1RvwIHUG8-FOkmAie18sukY,5039
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
-ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
 ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -141,7 +142,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -167,12 +168,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bXJwDxSPgxVKp-_6BsEmMA3TuMUaUNiZoYomNounxco,14416
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -229,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/METADATA,sha256=uej57gx3UQtqqGHydXxTTrLAlbzRg48u-YmLvdPxrIk,1966
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/RECORD,,

File without changes
File without changes