ai-edge-torch-nightly 0.3.0.dev20250213__py3-none-any.whl → 0.3.0.dev20250215__py3-none-any.whl

ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
24
+ from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
+
27
+ _CHECKPOINT_PATH = flags.DEFINE_string(
28
+ 'checkpoint_path',
29
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
30
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
31
+ )
32
+ _OUTPUT_PATH = flags.DEFINE_string(
33
+ 'output_path',
34
+ '/tmp/',
35
+ 'The path to export the tflite model.',
36
+ )
37
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
+ 'output_name_prefix',
39
+ 'qwen_vl',
40
+ 'The prefix of the output tflite model name.',
41
+ )
42
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
43
+ 'prefill_seq_len',
44
+ 1024,
45
+ 'The maximum size of prefill input tensor.',
46
+ )
47
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
+ 'kv_cache_max_len',
49
+ 1280,
50
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
51
+ )
52
+ _IMAGE_HEIGHT = flags.DEFINE_integer(
53
+ 'image_height',
54
+ 34 * 14,
55
+ 'The height of image.',
56
+ )
57
+ _IMAGE_WIDTH = flags.DEFINE_integer(
58
+ 'image_width',
59
+ 46 * 14,
60
+ 'The width of image.',
61
+ )
62
+ _QUANTIZE = flags.DEFINE_bool(
63
+ 'quantize',
64
+ True,
65
+ 'Whether the model should be quantized.',
66
+ )
67
+
68
+
69
+ def main(_):
70
+ pytorch_model = qwen_vl.build_model(
71
+ _CHECKPOINT_PATH.value,
72
+ kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
73
+ image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
74
+ )
75
+
76
+ grid_thw = pytorch_model.image_encoder.get_grid_thw()
77
+ converter.convert_to_tflite(
78
+ pytorch_model,
79
+ output_path=_OUTPUT_PATH.value,
80
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
81
+ prefill_seq_len=_PREFILL_SEQ_LEN.value,
82
+ pixel_values_size=(
83
+ pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
84
+ ),
85
+ quantize=_QUANTIZE.value,
86
+ config=pytorch_model.config.decoder_config,
87
+ export_config=ExportConfig(),
88
+ )
89
+
90
+
91
+ if __name__ == '__main__':
92
+ app.run(main)
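
For reference, here is a standalone sketch (not part of the package) of the tensor-size arithmetic this new script relies on. With the default flags (image_height=34*14, image_width=46*14) and the encoder constants used elsewhere in this diff (patch_size=14, temporal_patch_size=2, channels=3, temporal grid of 1 for still images), the grid is (1, 34, 46) and pixel_values has shape (1564, 1176):

# Standalone illustration only; the constants mirror get_image_encoder_config()
# further below and are assumptions of this sketch, not library API.
PATCH_SIZE = 14
TEMPORAL_PATCH_SIZE = 2
CHANNELS = 3

def pixel_values_size(image_height: int, image_width: int):
  grid_h = image_height // PATCH_SIZE
  grid_w = image_width // PATCH_SIZE
  num_patches = 1 * grid_h * grid_w  # temporal grid size is 1 for still images
  patch_dim = CHANNELS * TEMPORAL_PATCH_SIZE * PATCH_SIZE**2
  return num_patches, patch_dim

print(pixel_values_size(34 * 14, 46 * 14))  # (1564, 1176)
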
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py CHANGED
@@ -16,7 +16,7 @@
16
16
  """Example of building an image encoder of Qwen 2.5 VL model."""
17
17
 
18
18
  import dataclasses
19
- from typing import Optional
19
+ from typing import List, Optional, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.layers import attention
22
22
  from ai_edge_torch.generative.layers import attention_utils
@@ -93,7 +93,7 @@ class QwenVLImageEncoder(nn.Module):
93
93
 
94
94
  # Tensor shape used to reshape pixel_values in forward() and various places.
95
95
  self.kernel_size = (
96
- -1, # batch size
96
+ -1, # pixel_values.size(0)
97
97
  config.image_embedding.channels,
98
98
  config.image_embedding.temporal_patch_size,
99
99
  config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ class QwenVLImageEncoder(nn.Module):
118
118
  )
119
119
  self.merger = QwenVLMerger(config)
120
120
  self.config = config
121
+ self.set_image_size(config.image_embedding.image_size)
121
122
 
122
123
  @torch.inference_mode
123
- def forward(
124
- self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
125
- ) -> torch.Tensor:
126
- # Get window index and sequence lengths to rearrange the input tensor.
127
- window_index, cu_seqlens = self._get_window_index(grid_thw)
124
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
125
+ # Check that the pixel_values size matches the grid size and image config.
126
+ assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)
128
127
 
129
128
  # Embed the image and rearrange the embedding tensor.
130
- pixel_reshaped = pixel_values.view(self.kernel_size)
129
+ pixel_reshaped = pixel_values.reshape(self.kernel_size)
131
130
  x = self.tok_embedding(pixel_reshaped)
132
131
  x = x.view(-1, self.config.embedding_dim)
133
- x = self._rearrange(x, window_index).unsqueeze(0)
132
+ x = self._rearrange(x, self.window_index).unsqueeze(0)
134
133
 
135
- # Get RoPE and attention mask arranged according to the window index.
136
- cos, sin = self._get_rope(grid_thw)
137
- rope = (
138
- self._rearrange(cos, window_index),
139
- self._rearrange(sin, window_index),
140
- )
134
+ rope = self._get_rope(self.grid_thw, self.window_index)
141
135
 
142
- mask = self._get_mask(x.shape[1], cu_seqlens)
136
+ mask = self._get_mask(self.grid_thw, self.cu_seqlens)
143
137
  full_mask = torch.zeros(x.shape[:2])
144
138
  for i, block in enumerate(self.transformer_blocks):
145
139
  x = block(
@@ -150,10 +144,42 @@ class QwenVLImageEncoder(nn.Module):
150
144
 
151
145
  y = self.merger.forward(self.final_norm(x))
152
146
  # Arrange the output back to the original order.
153
- reverse_index = torch.argsort(window_index)
154
- return y[reverse_index, ...]
155
-
156
- def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
147
+ return y[self.reverse_index, ...]
148
+
149
+ def set_image_size(self, image_size: Tuple[int, int]):
150
+ """Set the image size and pre-calculate some values including mask."""
151
+ self.config.image_embedding.image_size = image_size
152
+ self.grid_thw = self.get_grid_thw()
153
+
154
+ # Precalculate the window index which can't be lowered to HLO because of
155
+ # a non-concrete (data-dependent) index in:
156
+ # index_new = index_padded[index_padded != -100]
157
+ self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
158
+
159
+ # Precalculate the reverse index of window_index until "vhlo.sort_v1" op is
160
+ # supported.
161
+ self.reverse_index = torch.argsort(self.window_index)
162
+
163
+ def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
164
+ """Calculate the grid size of the input images based on the image config."""
165
+ height, width = self.config.image_embedding.image_size
166
+ patch_height = height // self.config.image_embedding.patch_size
167
+ patch_width = width // self.config.image_embedding.patch_size
168
+ # Support images only, i.e. the temporal step size is always 1.
169
+ return [(1, patch_height, patch_width)] * num_images
170
+
171
+ def get_pixel_values_size(
172
+ self, grid_thw: List[Tuple[int, int, int]]
173
+ ) -> torch.Size:
174
+ """Calculate the size of pixel values tensor."""
175
+ dim_0 = sum(t * h * w for t, h, w in grid_thw)
176
+ config = self.config.image_embedding
177
+ dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
178
+ return torch.Size((dim_0, dim_1))
179
+
180
+ def _get_rope(
181
+ self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
182
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
157
183
  """Get RoPE for Qwen VL model based on image grid information.
158
184
 
159
185
  It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
@@ -182,16 +208,20 @@ class QwenVLImageEncoder(nn.Module):
182
208
  wpos_ids = wpos_ids.flatten()
183
209
  pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
184
210
  pos_ids = torch.cat(pos_ids, dim=0)
185
- max_grid_size = grid_thw[:, 1:].max()
211
+ # Assume all the heights and widths are the same for all images.
212
+ max_grid_size = max(grid_thw[0][1], grid_thw[0][2])
186
213
 
187
214
  cos, sin = attention_utils.build_rope_cache(
188
215
  max_grid_size,
189
216
  # ROPE parameters for all attn_configs are the same. Take the first one.
190
217
  self.config.block_config(0).attn_config.head_dim // 2,
191
218
  )
192
- return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
219
+ return (
220
+ self._rearrange(cos[pos_ids].flatten(1), window_index),
221
+ self._rearrange(sin[pos_ids].flatten(1), window_index),
222
+ )
193
223
 
194
- def _get_window_index(self, grid_thw: torch.Tensor):
224
+ def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
195
225
  """Get window index for Qwen VL model to rearrange the input tensor.
196
226
 
197
227
  It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ class QwenVLImageEncoder(nn.Module):
207
237
  )
208
238
 
209
239
  for grid_t, grid_h, grid_w in grid_thw:
210
- llm_grid_h, llm_grid_w = (
211
- grid_h // self.config.spatial_merge_size,
212
- grid_w // self.config.spatial_merge_size,
213
- )
214
- index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
215
- grid_t, llm_grid_h, llm_grid_w
216
- )
240
+ llm_grid_h = grid_h // self.config.spatial_merge_size
241
+ llm_grid_w = grid_w // self.config.spatial_merge_size
242
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
243
+ index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
217
244
  pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
218
245
  pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
219
246
  num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
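
A toy illustration of the window ordering computed here (my own sketch, not library code, and simplified to one image with no spatial merge): token indices on a 3x5 grid are padded with -100 to a multiple of a 2x2 merger window, flattened window by window, and the padding dropped. Because the boolean filter produces a data-dependent shape, the result and its argsort inverse are precomputed in set_image_size() above instead of being traced at export time.

import torch
import torch.nn.functional as F

# Toy version of the window-index reordering and its inverse.
window, grid_h, grid_w = 2, 3, 5
index = torch.arange(grid_h * grid_w).reshape(grid_h, grid_w)
pad_h = window - grid_h % window
pad_w = window - grid_w % window
index_padded = F.pad(index, (0, pad_w, 0, pad_h), value=-100)
num_h = (grid_h + pad_h) // window
num_w = (grid_w + pad_w) // window
index_padded = (
    index_padded.reshape(num_h, window, num_w, window)
    .permute(0, 2, 1, 3)
    .reshape(-1)
)
window_index = index_padded[index_padded != -100]  # data-dependent shape
reverse_index = torch.argsort(window_index)
x = torch.randn(grid_h * grid_w, 4)
assert torch.equal(x[window_index][reverse_index], x)  # round trip restores order
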
@@ -236,18 +263,14 @@ class QwenVLImageEncoder(nn.Module):
236
263
  index_padded = index_padded.reshape(-1)
237
264
  index_new = index_padded[index_padded != -100]
238
265
  window_index.append(index_new + window_index_id)
239
- spatial_merge_unit = (
240
- self.config.spatial_merge_size * self.config.spatial_merge_size
241
- )
266
+ spatial_merge_unit = self.config.spatial_merge_size**2
242
267
  cu_seqlens_tmp = (
243
268
  seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
244
269
  )
245
270
  cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
246
- window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
271
+ window_index_id += grid_t * llm_grid_h * llm_grid_w
247
272
 
248
273
  window_index = torch.cat(window_index, dim=0)
249
- cu_window_seqlens = torch.tensor(cu_window_seqlens)
250
- cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
251
274
  return window_index, cu_window_seqlens
252
275
 
253
276
  def _rearrange(
@@ -258,20 +281,20 @@ class QwenVLImageEncoder(nn.Module):
258
281
  It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
259
282
  modified accordingly.
260
283
  """
261
- size = x.shape[0]
262
- spatial_merge_unit = (
263
- self.config.spatial_merge_size * self.config.spatial_merge_size
264
- )
265
- x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
284
+ spatial_merge_unit = self.config.spatial_merge_size**2
285
+ x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
266
286
  x_rearranged = x_reshaped[window_index, ...]
267
- return x_rearranged.view(size, -1)
287
+ return x_rearranged.view(x.shape)
268
288
 
269
- def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
289
+ def _get_mask(
290
+ self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
291
+ ) -> torch.Tensor:
270
292
  """Get attention mask for Qwen VL model.
271
293
 
272
294
  It's copied from Qwen2_5_VLVisionAttention.forward() and modified
273
295
  accordingly.
274
296
  """
297
+ seqlen = self.get_pixel_values_size(grid_thw)[0]
275
298
  mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
276
299
  for i in range(1, len(cu_seqlens)):
277
300
  mask[
@@ -282,7 +305,7 @@ class QwenVLImageEncoder(nn.Module):
282
305
  return mask
283
306
 
284
307
 
285
- def get_image_encoder_config() -> QwenVLImageConfig:
308
+ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
286
309
  """Returns the model config for the image encoder of a Qwen 2.5 VL model.
287
310
 
288
311
  Returns:
@@ -290,7 +313,7 @@ def get_image_encoder_config() -> QwenVLImageConfig:
290
313
  """
291
314
  image_embedding_config = cfg.ImageEmbeddingConfig(
292
315
  channels=3,
293
- image_size=0, # Not used in image encoder.
316
+ image_size=image_size,
294
317
  patch_size=14,
295
318
  temporal_patch_size=2,
296
319
  )
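
As a note on the _get_mask() change a few hunks above: seqlen is now derived from get_pixel_values_size(), and the mask is block-diagonal over cu_seqlens, so positions inside the same window span may attend to each other while everything else stays at -inf. A minimal, generic sketch of that pattern (illustrative only, not the library code):

import torch

def block_diagonal_mask(seqlen, cu_seqlens):
  # Start fully masked, then open a square block for each window span.
  mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
  for i in range(1, len(cu_seqlens)):
    start, end = cu_seqlens[i - 1], cu_seqlens[i]
    mask[..., start:end, start:end] = 0
  return mask

m = block_diagonal_mask(6, [0, 2, 6])
assert torch.isinf(m[0, 0, 0, 3]) and m[0, 0, 3, 4] == 0
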
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
336
359
  window_size=112,
337
360
  spatial_merge_size=2,
338
361
  full_atten_block_indexes=[7, 15, 23, 31],
339
- # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
340
- # enable_hlfb can be set to True. See b/383865404#comment3 for details.
341
- # enable_hlfb=True,
362
+ enable_hlfb=True,
342
363
  )
343
364
  return config
344
365
 
345
366
 
346
367
  def get_fake_image_encoder_config() -> QwenVLImageConfig:
347
- config = get_image_encoder_config()
368
+ config = get_image_encoder_config((8, 12))
348
369
  # Qwen VL fake image encoder has only one block config.
349
370
  config.block_config(0).ff_config.intermediate_size = 128
350
371
  config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
353
374
  return config
354
375
 
355
376
 
356
- def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
357
- config = get_image_encoder_config()
377
+ def build_image_encoder(
378
+ checkpoint_path: str,
379
+ image_size: Tuple[int, int] = (34 * 14, 46 * 14),
380
+ ) -> QwenVLImageEncoder:
381
+ config = get_image_encoder_config(image_size)
358
382
  encoder = QwenVLImageEncoder(config)
359
383
  load_image_encoder(checkpoint_path, encoder)
360
384
  encoder.eval()
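
A hypothetical usage sketch of the reworked encoder API (assumes ai_edge_torch is installed; random weights, no checkpoint, and an arbitrary 280x280 input size): the image size is fixed when the config is built, and set_image_size() can change it later, re-deriving grid_thw, window_index, cu_seqlens, and reverse_index.

import torch
from ai_edge_torch.generative.examples.qwen_vl import image_encoder

config = image_encoder.get_image_encoder_config((20 * 14, 20 * 14))
encoder = image_encoder.QwenVLImageEncoder(config).eval()
pixel_values = torch.zeros(encoder.get_pixel_values_size(encoder.grid_thw))
image_embeds = encoder(pixel_values)  # reordered back to the original patch order
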
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py CHANGED
@@ -61,7 +61,6 @@ class QwenVL(nn.Module):
61
61
  kv_cache: kv_utils.KVCache,
62
62
  mask: Optional[torch.Tensor] = None,
63
63
  pixel_values: torch.Tensor = None,
64
- grid_thw: torch.Tensor = None,
65
64
  export_config: Optional[model_builder.ExportConfig] = None,
66
65
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
67
66
  if pixel_values is None:
@@ -69,14 +68,14 @@ class QwenVL(nn.Module):
69
68
  tokens=tokens,
70
69
  input_pos=input_pos,
71
70
  kv_cache=kv_cache,
72
- mask=mask,
73
- rope=self._build_text_rope(input_pos),
74
71
  input_embeds=None,
72
+ rope=self._build_text_rope(input_pos),
73
+ mask=mask,
75
74
  export_config=export_config,
76
75
  )
77
76
 
78
77
  input_embeds = self.decoder.tok_embedding(tokens)
79
- image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
78
+ image_embeds = self.image_encoder(pixel_values).unsqueeze(0)
80
79
 
81
80
  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
82
81
  # can be done like:
@@ -92,18 +91,19 @@ class QwenVL(nn.Module):
92
91
  (
93
92
  input_embeds[:, :1, :],
94
93
  image_embeds,
95
- input_embeds[:, image_embeds.shape[1] + 1 :, :],
94
+ input_embeds[:, image_embeds.size(1) + 1 :, :],
96
95
  ),
97
96
  dim=1,
98
97
  )
99
98
 
99
+ grid_thw = self.image_encoder.get_grid_thw()
100
100
  return self.decoder(
101
101
  tokens=None,
102
102
  input_pos=input_pos,
103
103
  kv_cache=kv_cache,
104
- mask=mask,
105
104
  input_embeds=input_embeds,
106
105
  rope=self._build_multimodal_rope(input_pos, grid_thw),
106
+ mask=mask,
107
107
  export_config=export_config,
108
108
  )
109
109
 
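The concatenation above splices the image embeddings between the first text embedding and the text embeddings that follow the image-token span. A toy, shape-only sketch of that splice (dimensions are made up for illustration):

import torch

embed_dim, num_image_tokens = 8, 5
input_embeds = torch.randn(1, 12, embed_dim)  # text embeddings with image-token placeholders
image_embeds = torch.randn(1, num_image_tokens, embed_dim)
merged = torch.cat(
    (
        input_embeds[:, :1, :],
        image_embeds,
        input_embeds[:, image_embeds.size(1) + 1 :, :],
    ),
    dim=1,
)
assert merged.shape == (1, 12, embed_dim)
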
@@ -120,9 +120,9 @@ class QwenVL(nn.Module):
120
120
  def _build_text_rope(
121
121
  self, input_pos: torch.Tensor
122
122
  ) -> Tuple[torch.Tensor, torch.Tensor]:
123
- # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
124
- # input_pos[0] = 0.
125
- if input_pos[0] == 0:
123
+ # Reset rope_pos_adjust to 0 when it's prefill, i.e. input has 2 or more
124
+ # tokens.
125
+ if input_pos.numel() > 1:
126
126
  self.rope_pos_adjust = 0
127
127
  return self._build_rope(input_pos + self.rope_pos_adjust)
128
128
 
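The reset condition above changed from input_pos[0] == 0 to input_pos.numel() > 1: prefill passes a whole range of positions while decode passes a single one, so the number of positions distinguishes the two. A tiny sketch of the check (my own illustration):

import torch

def is_prefill(input_pos: torch.Tensor) -> bool:
  return input_pos.numel() > 1

assert is_prefill(torch.arange(0, 16))     # prefill: a range of positions
assert not is_prefill(torch.tensor([16]))  # decode: a single position
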
@@ -178,15 +178,18 @@ class QwenVL(nn.Module):
178
178
  return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
179
179
 
180
180
 
181
- def get_model_config(**kwargs) -> QwenVLConfig:
181
+ def get_model_config(
182
+ kv_cache_max_len: int = 1024,
183
+ image_size: Tuple[int, int] = (34 * 14, 46 * 14),
184
+ ) -> QwenVLConfig:
182
185
  """Returns the model config for a PaliGemma 3B-224 model.
183
186
 
184
187
  Returns:
185
188
  The model config for a Qwen 2.5 VL 3B model.
186
189
  """
187
190
  return QwenVLConfig(
188
- image_encoder_config=image_encoder.get_image_encoder_config(),
189
- decoder_config=decoder.get_decoder_config(**kwargs),
191
+ image_encoder_config=image_encoder.get_image_encoder_config(image_size),
192
+ decoder_config=decoder.get_decoder_config(kv_cache_max_len),
190
193
  image_token_id=151655,
191
194
  mrope_section=[16, 24, 24],
192
195
  )
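
A hypothetical usage sketch of the updated builder (assumes ai_edge_torch is installed; no weights are needed just to construct the config). The argument values mirror the defaults used in convert_to_tflite.py above:

from ai_edge_torch.generative.examples.qwen_vl import qwen_vl

config = qwen_vl.get_model_config(
    kv_cache_max_len=1280,
    image_size=(34 * 14, 46 * 14),
)
print(config.image_token_id, config.mrope_section)  # 151655 [16, 24, 24]
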
@@ -197,6 +200,7 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
197
200
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
198
201
  decoder_config=decoder.get_fake_decoder_config(**kwargs),
199
202
  image_token_id=127,
203
+ mrope_section=[16, 24, 24],
200
204
  )
201
205
 
202
206
 
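The mrope_section=[16, 24, 24] added to the fake config drives the section interleave shown earlier (return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)): the rotary table is split along the last dimension by section, and chunk i is taken from the temporal, height, or width axis in turn. A toy version with made-up sizes (my own sketch, not library code):

import torch

# stacked holds one value per axis (0=temporal, 1=height, 2=width); after the
# split-and-select, the output chunks alternate t, h, w.
mrope_section = [2, 3, 3]
stacked = torch.arange(3).float().reshape(3, 1).expand(3, sum(mrope_section))
split = stacked.split(mrope_section, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
print(merged)  # tensor([0., 0., 1., 1., 1., 2., 2., 2.])
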
ai_edge_torch/generative/examples/qwen_vl/verify.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  import logging
19
19
  import pathlib
20
+
20
21
  from absl import app
21
22
  from absl import flags
22
23
  from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
@@ -47,16 +48,9 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
47
48
  class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
48
49
  """Reauthored Qwen VL model wrapper."""
49
50
 
50
- def __init__(self, model: torch.nn.Module):
51
- super().__init__(model)
52
- self.grid_thw = None
53
-
54
51
  def _init_kv_cache(self):
55
52
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
56
53
 
57
- def _get_extra_args_for_forward(self):
58
- return {"grid_thw": self.grid_thw}
59
-
60
54
 
61
55
  def main(_):
62
56
  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -94,7 +88,11 @@ def main(_):
94
88
 
95
89
  logging.info("Forwarding the reauthored model...")
96
90
  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
97
- wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
91
+ grid_thw = inputs["image_grid_thw"].tolist()
92
+ config = reauthored_model.config.image_encoder_config.image_embedding
93
+ reauthored_model.image_encoder.set_image_size(
94
+ (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
95
+ )
98
96
  outputs_reauthored = wrapped_reauthored_model.forward(
99
97
  tokens=inputs["input_ids"],
100
98
  pixel_values=inputs["pixel_values"],
ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py CHANGED
@@ -64,9 +64,12 @@ def main(_):
64
64
  logging.info("outputs_original: %s", outputs_original)
65
65
 
66
66
  logging.info("Forwarding the reauthored model...")
67
- outputs_reauthored = reauthored_model.forward(
68
- image_input["pixel_values"], image_input["image_grid_thw"]
67
+ grid_thw = image_input["image_grid_thw"].tolist()
68
+ config = reauthored_model.config.image_embedding
69
+ reauthored_model.set_image_size(
70
+ (grid_thw[0][1] * config.patch_size, grid_thw[0][2] * config.patch_size)
69
71
  )
72
+ outputs_reauthored = reauthored_model.forward(image_input["pixel_values"])
70
73
  logging.info("outputs_reauthored: %s", outputs_reauthored)
71
74
 
72
75
  try:
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import dataclasses
19
19
  import enum
20
- from typing import Callable, Optional, Sequence, Union
20
+ from typing import Callable, Optional, Sequence, Tuple, Union
21
21
  from ai_edge_torch.generative.layers import rotary_position_embedding
22
22
 
23
23
  @enum.unique
@@ -174,8 +174,10 @@ class ImageEmbeddingConfig:
174
174
  """Image embedding parameters."""
175
175
 
176
176
  channels: int
177
- # All images should be normalized to the size of [image_size * image_size].
178
- image_size: int
177
+ # All images should be normalized to image_size * image_size if image_size is
178
+ # a single integer, or image_size[0] (height) * image_size[1] (width) if
179
+ # image_size is a tuple of 2 integers.
180
+ image_size: Union[int, Tuple[int, int]]
179
181
  patch_size: int
180
182
  # Meaningful only when image embedding is Conv3d.
181
183
  temporal_patch_size: Optional[int] = None
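
Since image_size may now be either an int (square image) or a (height, width) tuple, callers that need both dimensions can normalize it first. A small helper along these lines (my own illustration, not part of the library):

from typing import Tuple, Union

def normalize_image_size(image_size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
  # Expand a square size into an explicit (height, width) pair.
  if isinstance(image_size, int):
    return image_size, image_size
  height, width = image_size
  return height, width

assert normalize_image_size(224) == (224, 224)
assert normalize_image_size((34 * 14, 46 * 14)) == (476, 644)
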
@@ -205,7 +207,7 @@ class ModelConfig:
205
207
  embedding_use_bias: bool = False
206
208
  # Image embedding parameters.
207
209
  image_embedding: Optional[ImageEmbeddingConfig] = None
208
- # Number of image tokens
210
+ # Number of image tokens
209
211
  num_mm_tokens_per_image: Optional[int] = None
210
212
  # Use bias term within LLM's HEAD.
211
213
  lm_head_use_bias: bool = False
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -28,6 +28,7 @@ from ai_edge_torch.generative.examples.paligemma import paligemma
28
28
  from ai_edge_torch.generative.examples.phi import phi2
29
29
  from ai_edge_torch.generative.examples.phi import phi3
30
30
  from ai_edge_torch.generative.examples.qwen import qwen
31
+ from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
31
32
  from ai_edge_torch.generative.examples.smollm import smollm
32
33
  from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
33
34
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -196,17 +197,15 @@ class TestModelConversion(googletest.TestCase):
196
197
  config = paligemma.get_fake_model_config(decoder_config)
197
198
  pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
198
199
 
199
- image_embedding_config = config.image_encoder_config.image_embedding
200
- num_patches = (
201
- image_embedding_config.image_size // image_embedding_config.patch_size
202
- ) ** 2
200
+ image_config = config.image_encoder_config.image_embedding
201
+ num_patches = (image_config.image_size // image_config.patch_size) ** 2
203
202
 
204
203
  # Make sure the token size is longer than the number of image patches.
205
204
  seq_len = num_patches + 10
206
- tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
205
+ tokens = torch.zeros((1, seq_len), dtype=torch.int)
207
206
  input_pos = torch.arange(0, seq_len, dtype=torch.int)
208
207
  kv = kv_cache.KVCache.from_model_config(config.decoder_config)
209
- pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
208
+ pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)
210
209
 
211
210
  edge_model = ai_edge_torch.signature(
212
211
  "prefill_pixel",
@@ -258,6 +257,55 @@ class TestModelConversion(googletest.TestCase):
258
257
  rtol=1e-5,
259
258
  )
260
259
 
260
+ @googletest.skipIf(
261
+ ai_edge_torch.config.in_oss,
262
+ reason="tests with custom ops are not supported in oss",
263
+ )
264
+ def test_qwen_vl_model(self):
265
+ config = qwen_vl.get_fake_model_config()
266
+ pytorch_model = qwen_vl.QwenVL(config).eval()
267
+
268
+ grid_thw = pytorch_model.image_encoder.get_grid_thw()
269
+ pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(
270
+ grid_thw
271
+ )
272
+
273
+ # Make sure the token size is longer than the number of pixel values.
274
+ seq_len = pixel_values_size[0] + 10
275
+ tokens = torch.zeros((1, seq_len), dtype=torch.int)
276
+ input_pos = torch.arange(0, seq_len, dtype=torch.int)
277
+ kv = kv_cache.KVCache.from_model_config(config.decoder_config)
278
+ pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)
279
+
280
+ edge_model = ai_edge_torch.signature(
281
+ "prefill_pixel",
282
+ pytorch_model,
283
+ sample_kwargs={
284
+ "tokens": tokens,
285
+ "input_pos": input_pos,
286
+ "kv_cache": kv,
287
+ "pixel_values": pixel_values,
288
+ },
289
+ ).convert()
290
+ edge_model.set_interpreter_builder(
291
+ self._interpreter_builder(edge_model.tflite_model())
292
+ )
293
+
294
+ tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
295
+ self.assertTrue(
296
+ test_utils.compare_tflite_torch(
297
+ edge_model,
298
+ pytorch_model,
299
+ tokens,
300
+ input_pos,
301
+ kv,
302
+ pixel_values=pixel_values,
303
+ signature_name="prefill_pixel",
304
+ atol=1e-3,
305
+ rtol=1e-5,
306
+ )
307
+ )
308
+
261
309
  @googletest.skipIf(
262
310
  ai_edge_torch.config.in_oss,
263
311
  reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -170,7 +170,7 @@ def _export_helper(
170
170
 
171
171
  # For export, we create a module that captures any non-exportable
172
172
  # arguments, e.g. the generation config object.
173
- mod = ExportableModule(pytorch_model, export_config=export_config)
173
+ mod = ExportableModule(pytorch_model, export_config=export_config).eval()
174
174
 
175
175
  converter = converter_utils.Converter()
176
176
  for lora in loras:
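
Calling .eval() on the wrapper before conversion is the standard PyTorch idiom for freezing training-only behavior (dropout, batch-norm statistic updates) ahead of export. A minimal, generic sketch of the effect (not the converter code itself):

import torch
import torch.nn as nn

# In eval mode, dropout becomes the identity, so repeated forward passes of the
# exported module are deterministic.
module = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5)).eval()
x = torch.randn(1, 4)
with torch.no_grad():
  assert torch.equal(module(x), module(x))
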
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250213"
16
+ __version__ = "0.3.0.dev20250215"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250213
3
+ Version: 0.3.0.dev20250215
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=MtKoJ3-mpjKq8yijJczUhFrKjlM6jKA--_qBzHJgNRg,706
5
+ ai_edge_torch/version.py,sha256=HRjjQujR7rDiLW1Mt_3LZQYVxZd2h-YktOT8MeVMmTc,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,12 +94,13 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
94
94
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
95
95
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
96
96
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
97
+ ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=MXK75-Upoq_RhCbiXJEl8SKJ-msmvpVivsgfqqy-cfg,2780
97
98
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
98
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
99
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
100
- ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
99
+ ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
100
+ ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=rcYHkpO-NbF4F1Da7q2xNiTng9NHiLx59HyuOgQX5W0,7753
101
+ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=cKinMEDXauR5yKxtNTQk1RvwIHUG8-FOkmAie18sukY,5039
101
102
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
102
- ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
103
+ ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
103
104
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
104
105
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
105
106
  ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -141,7 +142,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
141
142
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
142
143
  ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
143
144
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
144
- ai_edge_torch/generative/layers/model_config.py,sha256=Yqa3wqZLBe0Lj4PPTIaVFaZ--sV6NJ6k8KPjRguDvCc,8095
145
+ ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
145
146
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
146
147
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
147
148
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -167,12 +168,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
167
168
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
168
169
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
169
170
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
170
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
171
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bXJwDxSPgxVKp-_6BsEmMA3TuMUaUNiZoYomNounxco,14416
171
172
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
172
173
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
173
174
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
174
175
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
175
- ai_edge_torch/generative/utilities/converter.py,sha256=K9taR0KY59dvfU_jO1yBe_p7w8lDns1Q3U6oJTTKZzM,8058
176
+ ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
176
177
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
177
178
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
178
179
  ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -229,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
229
230
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
230
231
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
231
232
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
232
- ai_edge_torch_nightly-0.3.0.dev20250213.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
233
- ai_edge_torch_nightly-0.3.0.dev20250213.dist-info/METADATA,sha256=gLUSBS9nUIL1uc3mfWUFYw_lDoXHUCsu4LqFRNxW1IY,1966
234
- ai_edge_torch_nightly-0.3.0.dev20250213.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
235
- ai_edge_torch_nightly-0.3.0.dev20250213.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
236
- ai_edge_torch_nightly-0.3.0.dev20250213.dist-info/RECORD,,
233
+ ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
234
+ ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/METADATA,sha256=pZGTxEsYT2Tx_2xma-wcLYoLDVDqZs3lw-3sAIhUhPs,1966
235
+ ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
236
+ ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
237
+ ai_edge_torch_nightly-0.3.0.dev20250215.dist-info/RECORD,,