ai-edge-torch-nightly 0.5.0.dev20250417__py3-none-any.whl → 0.5.0.dev20250418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,6 +48,7 @@ def main(_):
48
48
  pixel_values_size=torch.Size(
49
49
  [1, config.channels, config.image_size, config.image_size]
50
50
  ),
51
+ pixel_seq_len=(config.image_size // config.patch_size) ** 2,
51
52
  quantize=flags.FLAGS.quantize,
52
53
  config=pytorch_model.config.decoder_config,
53
54
  export_config=ExportConfig(),
@@ -43,6 +43,9 @@ def main(_):
43
43
  )
44
44
 
45
45
  grid_thw = pytorch_model.image_encoder.get_grid_thw()
46
+ spatial_merge_size = (
47
+ pytorch_model.config.image_encoder_config.spatial_merge_size
48
+ )
46
49
  converter.convert_to_tflite(
47
50
  pytorch_model,
48
51
  output_path=flags.FLAGS.output_path,
@@ -51,6 +54,10 @@ def main(_):
51
54
  pixel_values_size=(
52
55
  pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
53
56
  ),
57
+ pixel_seq_len=(
58
+ (grid_thw[0][1] // spatial_merge_size)
59
+ * (grid_thw[0][2] // spatial_merge_size)
60
+ ),
54
61
  quantize=flags.FLAGS.quantize,
55
62
  config=pytorch_model.config.decoder_config,
56
63
  export_config=ExportConfig(),
@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
57
57
  )
58
58
  flags.DEFINE_string(
59
59
  'output_name_prefix',
60
- f'{model_name}',
60
+ model_name,
61
61
  'The prefix of the output tflite model name.',
62
62
  )
63
63
  flags.DEFINE_multi_integer(
@@ -91,6 +91,7 @@ def convert_to_tflite(
91
91
  output_name_prefix: str,
92
92
  prefill_seq_len: Union[int, list[int]],
93
93
  pixel_values_size: torch.Size = None,
94
+ pixel_seq_len: int = 0,
94
95
  quantize: bool = True,
95
96
  config: cfg.ModelConfig = None,
96
97
  lora_ranks: Optional[list[int]] = None,
@@ -133,12 +134,18 @@ def convert_to_tflite(
133
134
  use. If a list, the model will have multiple prefill signatures.
134
135
  pixel_values_size (torch.Size, optional): The size of pixel values to pass
135
136
  to the model. If None, the model is not expected to take pixel values.
137
+ pixel_seq_len (int, optional): The length of pixel tokens, or pixel
138
+ embeddings generated by the image encoder with pixel values. When pixel
139
+ values are passed, pixel_seq_len is added to each prefill_seq_len to get
140
+ the actual prefill length.
136
141
  quantize (bool, optional): Whether the model should be quantized. Defaults
137
142
  to True.
138
143
  config (cfg.ModelConfig, optional): The model config used to configure KV
139
144
  cache. If None, it uses the config of the pytorch_model.
140
145
  lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
141
146
  no LoRA signatures will be added.
147
+ export_config (ExportConfig, optional): The export configuration. If None,
148
+ it uses the default export configuration.
142
149
  """
143
150
  # pylint: disable=protected-access
144
151
  torch._dynamo.config.cache_size_limit = 64
@@ -173,6 +180,7 @@ def convert_to_tflite(
173
180
  output_file,
174
181
  prefill_seq_lens,
175
182
  pixel_values_size,
183
+ pixel_seq_len,
176
184
  quantize,
177
185
  config,
178
186
  loras,
@@ -185,6 +193,7 @@ def _export_helper(
185
193
  output_file: str,
186
194
  prefill_seq_lens: list[int],
187
195
  pixel_values_size: torch.Size,
196
+ pixel_seq_len: int,
188
197
  quantize: bool,
189
198
  config: cfg.ModelConfig,
190
199
  loras: list[None | lora_utils.LoRA],
@@ -197,11 +206,18 @@ def _export_helper(
197
206
  prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
198
207
  prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
199
208
 
200
- prefill_pixel_values = (
201
- torch.full(pixel_values_size, 0, dtype=torch.float32)
202
- if pixel_values_size
203
- else None
204
- )
209
+ prefill_pixel_values = None
210
+ prefill_tokens_list_with_pixel = []
211
+ prefill_input_pos_list_with_pixel = []
212
+ if pixel_values_size is not None:
213
+ prefill_pixel_values = torch.full(pixel_values_size, 0, dtype=torch.float32)
214
+ for seq_len in prefill_seq_lens:
215
+ prefill_tokens_list_with_pixel.append(
216
+ torch.full((1, seq_len + pixel_seq_len), 0, dtype=torch.int)
217
+ )
218
+ prefill_input_pos_list_with_pixel.append(
219
+ torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
220
+ )
205
221
 
206
222
  if export_config.prefill_mask is None:
207
223
  prefill_masks = None
@@ -238,13 +254,11 @@ def _export_helper(
238
254
  for lora in loras:
239
255
  for i in range(len(prefill_seq_lens)):
240
256
  prefill_seq_len = prefill_seq_lens[i]
241
- prefill_tokens = prefill_tokens_list[i]
242
- prefill_input_pos = prefill_input_pos_list[i]
243
257
  prefill_signature_name = f'prefill_{prefill_seq_len}'
244
258
 
245
259
  sample_kwargs = {
246
- 'tokens': prefill_tokens,
247
- 'input_pos': prefill_input_pos,
260
+ 'tokens': prefill_tokens_list[i],
261
+ 'input_pos': prefill_input_pos_list[i],
248
262
  'kv_cache': prefill_kv,
249
263
  }
250
264
  if prefill_masks is not None:
@@ -261,13 +275,13 @@ def _export_helper(
261
275
  )
262
276
 
263
277
  if prefill_pixel_values is not None:
278
+ sample_kwargs['tokens'] = prefill_tokens_list_with_pixel[i]
279
+ sample_kwargs['input_pos'] = prefill_input_pos_list_with_pixel[i]
280
+ sample_kwargs['pixel_values'] = prefill_pixel_values
264
281
  converter.add_signature(
265
282
  prefill_signature_name + '_pixel',
266
283
  mod,
267
- sample_kwargs={
268
- **sample_kwargs,
269
- 'pixel_values': prefill_pixel_values,
270
- },
284
+ sample_kwargs=sample_kwargs,
271
285
  )
272
286
 
273
287
  sample_kwargs = {
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250417"
16
+ __version__ = "0.5.0.dev20250418"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250417
3
+ Version: 0.5.0.dev20250418
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=kwb6M7GEr85K6sLrsbI9sNCggXojl_5TX9GeVCyP9OI,706
5
+ ai_edge_torch/version.py,sha256=RijgVzFc62dFf_A0Pr5jwDo3wtofAoxUcfBieTnHvIw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -84,7 +84,7 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=Hgp31zIQdJ
84
84
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
85
85
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
86
86
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
87
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=2GLE4empjc8IANssR02ECFUqhdUNJV_OVHCf1UXKL8Y,1956
87
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
88
88
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=G1dwtWp_v77AI3uyIY-8g6qRP2tRH3CIKjJTeYNqFPU,5511
89
89
  ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-SKdb0dd8uWT1d-FRwFx5-tJEqpdrQwiIZnFRhOtVo,6060
90
90
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
@@ -108,7 +108,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87P
108
108
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
109
109
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
110
110
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
111
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=OcE2-8lqAukoK5hM1sqdgfXU37kxWQ84racweNAdjyk,1995
111
+ ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
112
112
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
113
113
  ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
114
114
  ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
@@ -186,7 +186,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
186
186
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
187
187
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
188
188
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
189
- ai_edge_torch/generative/utilities/converter.py,sha256=swtz69oyMOxSaCEYST_Gzd5sjGZ1qOBAfd_0xl207Nk,9766
189
+ ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
190
190
  ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
191
191
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
192
192
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -245,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
245
245
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
246
246
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
247
247
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
248
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/METADATA,sha256=ovMriaKRgveZtN2i-cTOM2_8BuNvgf-SYNITAte1wjs,2051
250
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/RECORD,,
248
+ ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
+ ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/METADATA,sha256=cmJPI-Zen5YAjRWQPoVwvcAwVZJGz3Jz0OuNAHs5498,2051
250
+ ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
+ ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
+ ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/RECORD,,