ai-edge-torch-nightly 0.3.0.dev20250214__py3-none-any.whl → 0.3.0.dev20250216__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py ADDED
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen_vl',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_IMAGE_HEIGHT = flags.DEFINE_integer(
+    'image_height',
+    34 * 14,
+    'The height of image.',
+)
+_IMAGE_WIDTH = flags.DEFINE_integer(
+    'image_width',
+    46 * 14,
+    'The width of image.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = qwen_vl.build_model(
+      _CHECKPOINT_PATH.value,
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
+  )
+
+  grid_thw = pytorch_model.image_encoder.get_grid_thw()
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=(
+          pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
+      ),
+      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
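
For orientation, the default flags above pin down a fixed vision-input shape. A back-of-envelope sketch (not library code) of the arithmetic, using patch_size=14, temporal_patch_size=2, channels=3 and spatial_merge_size=2 from the image-encoder config changed later in this diff:

```python
# Rough check of the default flag values; all constants are taken from this diff.
patch_size = 14
height, width = 34 * patch_size, 46 * patch_size     # _IMAGE_HEIGHT / _IMAGE_WIDTH defaults
grid_t, grid_h, grid_w = 1, height // patch_size, width // patch_size  # grid_thw == [(1, 34, 46)]
pixel_rows = grid_t * grid_h * grid_w                 # rows of pixel_values
pixel_cols = 3 * 2 * patch_size**2                    # channels * temporal_patch_size * patch_size**2
image_tokens = pixel_rows // (2 * 2)                  # embeddings left after the 2x2 spatial merge
assert (pixel_rows, pixel_cols) == (1564, 1176)
assert image_tokens < 1024 <= 1280                    # appears to fit the prefill / KV-cache defaults
```
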
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py CHANGED
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -93,7 +93,7 @@ class QwenVLImageEncoder(nn.Module):
 
     # Tensor shape used to reshape pixel_values in forward() and various places.
     self.kernel_size = (
-        -1, # batch size
+        -1, # pixel_values.size(0)
         config.image_embedding.channels,
         config.image_embedding.temporal_patch_size,
         config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ class QwenVLImageEncoder(nn.Module):
     )
     self.merger = QwenVLMerger(config)
     self.config = config
+    self.set_image_size(config.image_embedding.image_size)
 
   @torch.inference_mode
-  def forward(
-      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
-  ) -> torch.Tensor:
-    # Get window index and sequence lengths to rearrange the input tensor.
-    window_index, cu_seqlens = self._get_window_index(grid_thw)
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Check if the pixel value size matches with grid size and image config.
+    assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)
 
     # Embed the image and rearrange the embedding tensor.
-    pixel_reshaped = pixel_values.view(self.kernel_size)
+    pixel_reshaped = pixel_values.reshape(self.kernel_size)
     x = self.tok_embedding(pixel_reshaped)
     x = x.view(-1, self.config.embedding_dim)
-    x = self._rearrange(x, window_index).unsqueeze(0)
+    x = self._rearrange(x, self.window_index).unsqueeze(0)
 
-    # Get RoPE and attention mask arranged according to the window index.
-    cos, sin = self._get_rope(grid_thw)
-    rope = (
-        self._rearrange(cos, window_index),
-        self._rearrange(sin, window_index),
-    )
+    rope = self._get_rope(self.grid_thw, self.window_index)
 
-    mask = self._get_mask(x.shape[1], cu_seqlens)
+    mask = self._get_mask(self.grid_thw, self.cu_seqlens)
     full_mask = torch.zeros(x.shape[:2])
     for i, block in enumerate(self.transformer_blocks):
       x = block(
@@ -150,10 +144,42 @@ class QwenVLImageEncoder(nn.Module):
 
     y = self.merger.forward(self.final_norm(x))
     # Arrange the output back to the original order.
-    reverse_index = torch.argsort(window_index)
-    return y[reverse_index, ...]
-
-  def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    return y[self.reverse_index, ...]
+
+  def set_image_size(self, image_size: Tuple[int, int]):
+    """Set the image size and pre-calculate some values including mask."""
+    self.config.image_embedding.image_size = image_size
+    self.grid_thw = self.get_grid_thw()
+
+    # Precalculate the window index which can't be lowered to HLO because of
+    # inconcrete index in:
+    # index_new = index_padded[index_padded != -100]
+    self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
+
+    # Precalculate the reverse index of window_index until "vhlo.sort_v1" op is
+    # supported.
+    self.reverse_index = torch.argsort(self.window_index)
+
+  def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
+    """Calculate the grid size of the input images based on the image config."""
+    height, width = self.config.image_embedding.image_size
+    patch_height = height // self.config.image_embedding.patch_size
+    patch_width = width // self.config.image_embedding.patch_size
+    # Support only image, i.e. temporal step size is always 1.
+    return [(1, patch_height, patch_width)] * num_images
+
+  def get_pixel_values_size(
+      self, grid_thw: List[Tuple[int, int, int]]
+  ) -> torch.Size:
+    """Calculate the size of pixel values tensor."""
+    dim_0 = sum(t * h * w for t, h, w in grid_thw)
+    config = self.config.image_embedding
+    dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
+    return torch.Size((dim_0, dim_1))
+
+  def _get_rope(
+      self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Get RoPE for Qwen VL model based on image grid information.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
@@ -182,16 +208,20 @@ class QwenVLImageEncoder(nn.Module):
       wpos_ids = wpos_ids.flatten()
       pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
-    max_grid_size = grid_thw[:, 1:].max()
+    # Assume all the heights and widths are the same for all images.
+    max_grid_size = max(grid_thw[0][1], grid_thw[0][2])
 
     cos, sin = attention_utils.build_rope_cache(
        max_grid_size,
        # ROPE parameters for all attn_configs are the same. Take the first one.
        self.config.block_config(0).attn_config.head_dim // 2,
    )
-    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
+    return (
+        self._rearrange(cos[pos_ids].flatten(1), window_index),
+        self._rearrange(sin[pos_ids].flatten(1), window_index),
+    )
 
-  def _get_window_index(self, grid_thw: torch.Tensor):
+  def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
     """Get window index for Qwen VL model to rearrange the input tensor.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ class QwenVLImageEncoder(nn.Module):
     )
 
     for grid_t, grid_h, grid_w in grid_thw:
-      llm_grid_h, llm_grid_w = (
-          grid_h // self.config.spatial_merge_size,
-          grid_w // self.config.spatial_merge_size,
-      )
-      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-          grid_t, llm_grid_h, llm_grid_w
-      )
+      llm_grid_h = grid_h // self.config.spatial_merge_size
+      llm_grid_w = grid_w // self.config.spatial_merge_size
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
+      index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
       pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
       pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
       num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -236,18 +263,14 @@ class QwenVLImageEncoder(nn.Module):
       index_padded = index_padded.reshape(-1)
       index_new = index_padded[index_padded != -100]
       window_index.append(index_new + window_index_id)
-      spatial_merge_unit = (
-          self.config.spatial_merge_size * self.config.spatial_merge_size
-      )
+      spatial_merge_unit = self.config.spatial_merge_size**2
       cu_seqlens_tmp = (
          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
      )
       cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+      window_index_id += grid_t * llm_grid_h * llm_grid_w
 
     window_index = torch.cat(window_index, dim=0)
-    cu_window_seqlens = torch.tensor(cu_window_seqlens)
-    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
     return window_index, cu_window_seqlens
 
   def _rearrange(
@@ -258,20 +281,20 @@ class QwenVLImageEncoder(nn.Module):
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
     modified accordingly.
     """
-    size = x.shape[0]
-    spatial_merge_unit = (
-        self.config.spatial_merge_size * self.config.spatial_merge_size
-    )
-    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+    spatial_merge_unit = self.config.spatial_merge_size**2
+    x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
     x_rearranged = x_reshaped[window_index, ...]
-    return x_rearranged.view(size, -1)
+    return x_rearranged.view(x.shape)
 
-  def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
+  def _get_mask(
+      self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
+  ) -> torch.Tensor:
     """Get attention mask for Qwen VL model.
 
     It's copied from Qwen2_5_VLVisionAttention.forward() and modified
     accordingly.
     """
+    seqlen = self.get_pixel_values_size(grid_thw)[0]
     mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
     for i in range(1, len(cu_seqlens)):
       mask[
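
The `_get_mask` rewrite above derives `seqlen` from the pixel-values size and keeps the block-diagonal structure driven by `cu_seqlens`. A minimal, self-contained sketch of that pattern; the indexing inside the truncated `mask[...]` assignment is assumed here, not copied from the library:

```python
import torch

def window_mask(seqlen: int, cu_seqlens: list) -> torch.Tensor:
  """Block-diagonal mask: tokens attend only within their own window."""
  mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
  for i in range(1, len(cu_seqlens)):
    lo, hi = cu_seqlens[i - 1], cu_seqlens[i]
    mask[..., lo:hi, lo:hi] = 0
  return mask

m = window_mask(8, [0, 4, 8])
assert m[0, 0, 1, 2] == 0          # same window: attention allowed
assert torch.isinf(m[0, 0, 1, 6])  # different window: masked out
```
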
@@ -282,7 +305,7 @@ class QwenVLImageEncoder(nn.Module):
     return mask
 
 
-def get_image_encoder_config() -> QwenVLImageConfig:
+def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
   """Returns the model config for the image encoder of a Qwen 2.5 VL model.
 
   Returns:
@@ -290,7 +313,7 @@ def get_image_encoder_config() -> QwenVLImageConfig:
   """
   image_embedding_config = cfg.ImageEmbeddingConfig(
       channels=3,
-      image_size=0, # Not used in image encoder.
+      image_size=image_size,
       patch_size=14,
       temporal_patch_size=2,
   )
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config
 
 
 def get_fake_image_encoder_config() -> QwenVLImageConfig:
-  config = get_image_encoder_config()
+  config = get_image_encoder_config((8, 12))
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
   return config
 
 
-def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
-  config = get_image_encoder_config()
+def build_image_encoder(
+    checkpoint_path: str,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLImageEncoder:
+  config = get_image_encoder_config(image_size)
   encoder = QwenVLImageEncoder(config)
   load_image_encoder(checkpoint_path, encoder)
   encoder.eval()
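
Taken together, the image-encoder hunks change the calling convention: the resolution is fixed up front (via `build_image_encoder(..., image_size=...)` or `set_image_size()`), and `forward()` then takes `pixel_values` only. A minimal usage sketch; the checkpoint path is a placeholder:

```python
import torch
from ai_edge_torch.generative.examples.qwen_vl import image_encoder

encoder = image_encoder.build_image_encoder(
    "/path/to/qwen-vl-checkpoint",     # hypothetical local checkpoint
    image_size=(34 * 14, 46 * 14),     # the default used elsewhere in this diff
)
pixel_values = torch.zeros(encoder.get_pixel_values_size(encoder.get_grid_thw()))
image_embeds = encoder(pixel_values)   # no grid_thw argument anymore
```
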
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py CHANGED
@@ -61,7 +61,6 @@ class QwenVL(nn.Module):
       kv_cache: kv_utils.KVCache,
       mask: Optional[torch.Tensor] = None,
       pixel_values: torch.Tensor = None,
-      grid_thw: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
@@ -69,14 +68,14 @@
           tokens=tokens,
           input_pos=input_pos,
           kv_cache=kv_cache,
-          mask=mask,
-          rope=self._build_text_rope(input_pos),
           input_embeds=None,
+          rope=self._build_text_rope(input_pos),
+          mask=mask,
           export_config=export_config,
       )
 
     input_embeds = self.decoder.tok_embedding(tokens)
-    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+    image_embeds = self.image_encoder(pixel_values).unsqueeze(0)
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -92,18 +91,19 @@ class QwenVL(nn.Module):
         (
             input_embeds[:, :1, :],
             image_embeds,
-            input_embeds[:, image_embeds.shape[1] + 1 :, :],
+            input_embeds[:, image_embeds.size(1) + 1 :, :],
         ),
         dim=1,
     )
 
+    grid_thw = self.image_encoder.get_grid_thw()
     return self.decoder(
         tokens=None,
         input_pos=input_pos,
         kv_cache=kv_cache,
-        mask=mask,
         input_embeds=input_embeds,
         rope=self._build_multimodal_rope(input_pos, grid_thw),
+        mask=mask,
         export_config=export_config,
     )
 
@@ -120,9 +120,9 @@ class QwenVL(nn.Module):
   def _build_text_rope(
       self, input_pos: torch.Tensor
   ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
-    # input_pos[0] = 0.
-    if input_pos[0] == 0:
+    # Reset rope_pos_adjust to 0 when it's prefill, i.e. input has 2 or more
+    # tokens.
+    if input_pos.numel() > 1:
       self.rope_pos_adjust = 0
     return self._build_rope(input_pos + self.rope_pos_adjust)
 
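
The prefill check above now relies on the shape of `input_pos` rather than its first value. A toy illustration (not library code) of why `numel()` distinguishes the two signatures:

```python
import torch

prefill_pos = torch.arange(0, 8, dtype=torch.int)  # a whole range of positions -> prefill
decode_pos = torch.tensor([8], dtype=torch.int)    # a single position -> decode
assert prefill_pos.numel() > 1                     # resets rope_pos_adjust
assert decode_pos.numel() == 1                     # keeps the accumulated adjustment
```
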
@@ -178,15 +178,18 @@ class QwenVL(nn.Module):
     return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
 
 
-def get_model_config(**kwargs) -> QwenVLConfig:
+def get_model_config(
+    kv_cache_max_len: int = 1024,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
     The model config for a PaliGemma 3B model.
   """
   return QwenVLConfig(
-      image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(**kwargs),
+      image_encoder_config=image_encoder.get_image_encoder_config(image_size),
+      decoder_config=decoder.get_decoder_config(kv_cache_max_len),
       image_token_id=151655,
       mrope_section=[16, 24, 24],
   )
@@ -197,6 +200,7 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(**kwargs),
       image_token_id=127,
+      mrope_section=[16, 24, 24],
   )
 
 
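
With `get_model_config` now taking explicit keyword arguments, building the reauthored model looks roughly like the following (a sketch mirroring the unit test added later in this diff; no checkpoint weights are loaded here):

```python
from ai_edge_torch.generative.examples.qwen_vl import qwen_vl

config = qwen_vl.get_model_config(
    kv_cache_max_len=1280,
    image_size=(34 * 14, 46 * 14),
)
model = qwen_vl.QwenVL(config).eval()  # same construction pattern as the new test
```
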
ai_edge_torch/generative/examples/qwen_vl/verify.py CHANGED
@@ -17,6 +17,7 @@
 
 import logging
 import pathlib
+
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
@@ -47,16 +48,9 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored Qwen VL model wrapper."""
 
-  def __init__(self, model: torch.nn.Module):
-    super().__init__(model)
-    self.grid_thw = None
-
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
-  def _get_extra_args_for_forward(self):
-    return {"grid_thw": self.grid_thw}
-
 
 def main(_):
   checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -94,7 +88,11 @@ def main(_):
 
   logging.info("Forwarding the reauthored model...")
   wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
-  wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
+  grid_thw = inputs["image_grid_thw"].tolist()
+  config = reauthored_model.config.image_encoder_config.image_embedding
+  reauthored_model.image_encoder.set_image_size(
+      (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
+  )
   outputs_reauthored = wrapped_reauthored_model.forward(
       tokens=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py CHANGED
@@ -64,9 +64,12 @@ def main(_):
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  outputs_reauthored = reauthored_model.forward(
-      image_input["pixel_values"], image_input["image_grid_thw"]
+  grid_thw = image_input["image_grid_thw"].tolist()
+  config = reauthored_model.config.image_embedding
+  reauthored_model.set_image_size(
+      (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
   )
+  outputs_reauthored = reauthored_model.forward(image_input["pixel_values"])
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:
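
Both verify scripts now recover the image size from the processor's grid rather than passing `image_grid_thw` through the model: each grid entry is `(t, h_patches, w_patches)`, so the pixel size is simply the patch counts times `patch_size`. A worked example with `patch_size = 14`, the value set in `get_image_encoder_config` in this diff:

```python
patch_size = 14
grid_thw = [[1, 34, 46]]  # hypothetical image_grid_thw.tolist() output
image_size = (grid_thw[0][1] * patch_size, grid_thw[0][2] * patch_size)
assert image_size == (476, 644)
```
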
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,7 +17,7 @@
 
 import dataclasses
 import enum
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Tuple, Union
 from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
@@ -174,8 +174,10 @@ class ImageEmbeddingConfig:
   """Image embedding parameters."""
 
   channels: int
-  # All images should be normalized to the size of [image_size * image_size].
-  image_size: int
+  # All images should be normalized to image_size * image_size if image_size is
+  # a single integer, or image_size[0] (height) * image_size[1] (width) if
+  # image_size is a tuple of 2 integers.
+  image_size: Union[int | Tuple[int, int]]
   patch_size: int
   # Meaningful only when image embedding is Conv3d.
   temporal_patch_size: Optional[int] = None
@@ -205,7 +207,7 @@ class ModelConfig:
   embedding_use_bias: bool = False
   # Image embedding parameters.
   image_embedding: Optional[ImageEmbeddingConfig] = None
-  # Number of image tokens
+  # Number of image tokens
   num_mm_tokens_per_image: Optional[int] = None
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
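
`image_size` is now polymorphic: a single int for square images or a (height, width) tuple. A hypothetical helper, not part of the library, showing how downstream code might normalize the field:

```python
from typing import Tuple, Union

def normalize_image_size(
    image_size: Union[int, Tuple[int, int]]
) -> Tuple[int, int]:
  """Return an explicit (height, width) pair for either accepted form."""
  if isinstance(image_size, int):
    return (image_size, image_size)
  return tuple(image_size)

assert normalize_image_size(224) == (224, 224)
assert normalize_image_size((476, 644)) == (476, 644)
```
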
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -28,6 +28,7 @@ from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -196,17 +197,15 @@ class TestModelConversion(googletest.TestCase):
     config = paligemma.get_fake_model_config(decoder_config)
     pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
-    image_embedding_config = config.image_encoder_config.image_embedding
-    num_patches = (
-        image_embedding_config.image_size // image_embedding_config.patch_size
-    ) ** 2
+    image_config = config.image_encoder_config.image_embedding
+    num_patches = (image_config.image_size // image_config.patch_size) ** 2
 
     # Make sure the token size is longer than the number of image patches.
     seq_len = num_patches + 10
-    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
-    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)
 
     edge_model = ai_edge_torch.signature(
         "prefill_pixel",
@@ -258,6 +257,55 @@ class TestModelConversion(googletest.TestCase):
         rtol=1e-5,
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_qwen_vl_model(self):
+    config = qwen_vl.get_fake_model_config()
+    pytorch_model = qwen_vl.QwenVL(config).eval()
+
+    grid_thw = pytorch_model.image_encoder.get_grid_thw()
+    pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(
+        grid_thw
+    )
+
+    # Make sure the token size is longer than the number of pixel values.
+    seq_len = pixel_values_size[0] + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int)
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -170,7 +170,7 @@ def _export_helper(
 
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
-  mod = ExportableModule(pytorch_model, export_config=export_config)
+  mod = ExportableModule(pytorch_model, export_config=export_config).eval()
 
   converter = converter_utils.Converter()
   for lora in loras:
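
The `.eval()` call added above is generic PyTorch hygiene before export: training-only layers such as dropout are switched to inference behavior, keeping the traced graph deterministic. A small reminder (plain PyTorch, not ai-edge-torch code) of what the flag controls:

```python
import torch.nn as nn

mod = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
assert mod.training      # nn.Module instances start in training mode
mod = mod.eval()
assert not mod.training  # eval() makes Dropout a no-op during export
```
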
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250214"
+__version__ = "0.3.0.dev20250216"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250214
+Version: 0.3.0.dev20250216
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=Gg-N8K4Pfmxd2OLKzGJ1nVBowEkZcjrFj8TYG8TNnWI,706
+ai_edge_torch/version.py,sha256=vklXbqGLRDju4mlU9vpIceTDodbvgQCmd7eyCsV5ckM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,12 +94,13 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=MXK75-Upoq_RhCbiXJEl8SKJ-msmvpVivsgfqqy-cfg,2780
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
-ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
-ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=rcYHkpO-NbF4F1Da7q2xNiTng9NHiLx59HyuOgQX5W0,7753
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=cKinMEDXauR5yKxtNTQk1RvwIHUG8-FOkmAie18sukY,5039
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
-ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
+ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
 ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -141,7 +142,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=Yqa3wqZLBe0Lj4PPTIaVFaZ--sV6NJ6k8KPjRguDvCc,8095
+ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -167,12 +168,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bXJwDxSPgxVKp-_6BsEmMA3TuMUaUNiZoYomNounxco,14416
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=K9taR0KY59dvfU_jO1yBe_p7w8lDns1Q3U6oJTTKZzM,8058
+ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -229,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250214.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250214.dist-info/METADATA,sha256=u-x1rHrzHOUQBPLQbu8r3-HvX0EvMYP1RkZQ1ZZHEKY,1966
-ai_edge_torch_nightly-0.3.0.dev20250214.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250214.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250214.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/METADATA,sha256=uej57gx3UQtqqGHydXxTTrLAlbzRg48u-YmLvdPxrIk,1966
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250216.dist-info/RECORD,,