ai-edge-torch-nightly 0.5.0.dev20250417__py3-none-any.whl → 0.5.0.dev20250419__py3-none-any.whl

This diff represents the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects those changes as the package versions appear in their respective public registries.
@@ -35,14 +35,12 @@ def _run_convert_passes(
35
35
  )
36
36
 
37
37
  passes = [
38
+ fx_passes.CastInputsBf16ToF32Pass(),
38
39
  fx_passes.BuildInterpolateCompositePass(),
39
- fx_passes.CanonicalizePass(),
40
40
  fx_passes.OptimizeLayoutTransposesPass(),
41
41
  fx_passes.CanonicalizePass(),
42
42
  fx_passes.BuildAtenCompositePass(),
43
43
  fx_passes.RemoveNonUserOutputsPass(),
44
- fx_passes.CastInputsBf16ToF32Pass(),
45
- fx_passes.CanonicalizePass(),
46
44
  ]
47
45
 
48
46
  # Debuginfo is not injected automatically by odml_torch. Only inject
@@ -48,6 +48,7 @@ def main(_):
48
48
  pixel_values_size=torch.Size(
49
49
  [1, config.channels, config.image_size, config.image_size]
50
50
  ),
51
+ pixel_seq_len=(config.image_size // config.patch_size) ** 2,
51
52
  quantize=flags.FLAGS.quantize,
52
53
  config=pytorch_model.config.decoder_config,
53
54
  export_config=ExportConfig(),
@@ -43,6 +43,9 @@ def main(_):
43
43
  )
44
44
 
45
45
  grid_thw = pytorch_model.image_encoder.get_grid_thw()
46
+ spatial_merge_size = (
47
+ pytorch_model.config.image_encoder_config.spatial_merge_size
48
+ )
46
49
  converter.convert_to_tflite(
47
50
  pytorch_model,
48
51
  output_path=flags.FLAGS.output_path,
@@ -51,6 +54,10 @@ def main(_):
51
54
  pixel_values_size=(
52
55
  pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
53
56
  ),
57
+ pixel_seq_len=(
58
+ (grid_thw[0][1] // spatial_merge_size)
59
+ * (grid_thw[0][2] // spatial_merge_size)
60
+ ),
54
61
  quantize=flags.FLAGS.quantize,
55
62
  config=pytorch_model.config.decoder_config,
56
63
  export_config=ExportConfig(),
@@ -23,8 +23,5 @@ def run_generative_passes(
23
23
  ) -> torch.export.ExportedProgram:
24
24
  return fx_infra.run_passes(
25
25
  exported_program,
26
- [
27
- RemoveSDPACompositeZeroMaskPass(),
28
- CanonicalizePass(),
29
- ],
26
+ [RemoveSDPACompositeZeroMaskPass()],
30
27
  )
@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
57
57
  )
58
58
  flags.DEFINE_string(
59
59
  'output_name_prefix',
60
- f'{model_name}',
60
+ model_name,
61
61
  'The prefix of the output tflite model name.',
62
62
  )
63
63
  flags.DEFINE_multi_integer(
@@ -91,6 +91,7 @@ def convert_to_tflite(
91
91
  output_name_prefix: str,
92
92
  prefill_seq_len: Union[int, list[int]],
93
93
  pixel_values_size: torch.Size = None,
94
+ pixel_seq_len: int = 0,
94
95
  quantize: bool = True,
95
96
  config: cfg.ModelConfig = None,
96
97
  lora_ranks: Optional[list[int]] = None,
@@ -133,12 +134,18 @@ def convert_to_tflite(
133
134
  use. If a list, the model will have multiple prefill signatures.
134
135
  pixel_values_size (torch.Size, optional): The size of pixel values to pass
135
136
  to the model. If None, the model is not expected to take pixel values.
137
+ pixel_seq_len (int, optional): The length of pixel tokens, or pixel
138
+ embeddings generated by the image encoder with pixel values. The actual
139
+ length of prefill_seq_len will be added by pixel_seq_len when pixel
140
+ values are passed.
136
141
  quantize (bool, optional): Whether the model should be quanized. Defaults
137
142
  to True.
138
143
  config (cfg.ModelConfig, optional): The model config used to configure KV
139
144
  cache. If None, it uses the config of the pytorch_model.
140
145
  lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
141
146
  no LoRA signatures will be added.
147
+ export_config (ExportConfig, optional): The export configuration. If None,
148
+ it uses the default export configuration.
142
149
  """
143
150
  # pylint: disable=protected-access
144
151
  torch._dynamo.config.cache_size_limit = 64
@@ -173,6 +180,7 @@ def convert_to_tflite(
173
180
  output_file,
174
181
  prefill_seq_lens,
175
182
  pixel_values_size,
183
+ pixel_seq_len,
176
184
  quantize,
177
185
  config,
178
186
  loras,
@@ -185,6 +193,7 @@ def _export_helper(
185
193
  output_file: str,
186
194
  prefill_seq_lens: list[int],
187
195
  pixel_values_size: torch.Size,
196
+ pixel_seq_len: int,
188
197
  quantize: bool,
189
198
  config: cfg.ModelConfig,
190
199
  loras: list[None | lora_utils.LoRA],
@@ -197,11 +206,18 @@ def _export_helper(
197
206
  prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
198
207
  prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
199
208
 
200
- prefill_pixel_values = (
201
- torch.full(pixel_values_size, 0, dtype=torch.float32)
202
- if pixel_values_size
203
- else None
204
- )
209
+ prefill_pixel_values = None
210
+ prefill_tokens_list_with_pixel = []
211
+ prefill_input_pos_list_with_pixel = []
212
+ if pixel_values_size is not None:
213
+ prefill_pixel_values = torch.full(pixel_values_size, 0, dtype=torch.float32)
214
+ for seq_len in prefill_seq_lens:
215
+ prefill_tokens_list_with_pixel.append(
216
+ torch.full((1, seq_len + pixel_seq_len), 0, dtype=torch.int)
217
+ )
218
+ prefill_input_pos_list_with_pixel.append(
219
+ torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
220
+ )
205
221
 
206
222
  if export_config.prefill_mask is None:
207
223
  prefill_masks = None
@@ -238,13 +254,11 @@ def _export_helper(
238
254
  for lora in loras:
239
255
  for i in range(len(prefill_seq_lens)):
240
256
  prefill_seq_len = prefill_seq_lens[i]
241
- prefill_tokens = prefill_tokens_list[i]
242
- prefill_input_pos = prefill_input_pos_list[i]
243
257
  prefill_signature_name = f'prefill_{prefill_seq_len}'
244
258
 
245
259
  sample_kwargs = {
246
- 'tokens': prefill_tokens,
247
- 'input_pos': prefill_input_pos,
260
+ 'tokens': prefill_tokens_list[i],
261
+ 'input_pos': prefill_input_pos_list[i],
248
262
  'kv_cache': prefill_kv,
249
263
  }
250
264
  if prefill_masks is not None:
@@ -261,13 +275,13 @@ def _export_helper(
261
275
  )
262
276
 
263
277
  if prefill_pixel_values is not None:
278
+ sample_kwargs['tokens'] = prefill_tokens_list_with_pixel[i]
279
+ sample_kwargs['input_pos'] = prefill_input_pos_list_with_pixel[i]
280
+ sample_kwargs['pixel_values'] = prefill_pixel_values
264
281
  converter.add_signature(
265
282
  prefill_signature_name + '_pixel',
266
283
  mod,
267
- sample_kwargs={
268
- **sample_kwargs,
269
- 'pixel_values': prefill_pixel_values,
270
- },
284
+ sample_kwargs=sample_kwargs,
271
285
  )
272
286
 
273
287
  sample_kwargs = {
@@ -264,6 +264,8 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
264
264
  exported_program: The exported program to apply the pass.
265
265
  """
266
266
 
267
+ is_modified = False
268
+
267
269
  def in_i32(x: int):
268
270
  return -2147483648 <= x <= 2147483647
269
271
 
@@ -271,6 +273,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
271
273
  return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
272
274
 
273
275
  def rewrite_arange(node: torch.fx.Node):
276
+ nonlocal is_modified
274
277
  tensor_meta = node.meta.get("tensor_meta", None)
275
278
  if not tensor_meta:
276
279
  return
@@ -282,12 +285,14 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
282
285
  return
283
286
  op = node.target
284
287
  node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
288
+ is_modified = True
285
289
 
286
290
  graph_module = exported_program.graph_module
287
291
  for node in graph_module.graph.nodes:
288
292
 
289
293
  if node.target == torch.ops.aten.arange.start_step:
290
294
  rewrite_arange(node)
295
+ return is_modified
291
296
 
292
297
 
293
298
  # TODO(b/331481564) Make this a ai_edge_torch FX pass.
@@ -351,9 +356,9 @@ def exported_program_to_mlir(
351
356
  exported_program,
352
357
  fx_infra.decomp.pre_lower_decomp(),
353
358
  )
354
- _convert_i64_to_i32(exported_program)
355
- # Run decompositions for retracing and cananicalization.
356
- exported_program = fx_infra.safe_run_decompositions(exported_program, {})
359
+ if _convert_i64_to_i32(exported_program):
360
+ # Run decompositions for retracing and cananicalization, if modified.
361
+ exported_program = fx_infra.safe_run_decompositions(exported_program, {})
357
362
 
358
363
  # Passes below mutate the exported program to a state not executable by torch.
359
364
  # Do not call run_decompositions after applying the passes.
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250417"
16
+ __version__ = "0.5.0.dev20250419"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250417
3
+ Version: 0.5.0.dev20250419
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=kwb6M7GEr85K6sLrsbI9sNCggXojl_5TX9GeVCyP9OI,706
5
+ ai_edge_torch/version.py,sha256=SG1Sn0KkGdZyTIYeY_Rw8sRC6xtCmFSkF15xymi-Eho,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
- ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
7
+ ai_edge_torch/_convert/conversion.py,sha256=0gpwEjlTue5RttDerzM5SVOUnY8g16444yL2YIFBx-E,5485
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
9
9
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
10
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
@@ -84,7 +84,7 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=Hgp31zIQdJ
84
84
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
85
85
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
86
86
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
87
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=2GLE4empjc8IANssR02ECFUqhdUNJV_OVHCf1UXKL8Y,1956
87
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
88
88
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=G1dwtWp_v77AI3uyIY-8g6qRP2tRH3CIKjJTeYNqFPU,5511
89
89
  ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-SKdb0dd8uWT1d-FRwFx5-tJEqpdrQwiIZnFRhOtVo,6060
90
90
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
@@ -108,7 +108,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87P
108
108
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
109
109
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
110
110
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
111
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=OcE2-8lqAukoK5hM1sqdgfXU37kxWQ84racweNAdjyk,1995
111
+ ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
112
112
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
113
113
  ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
114
114
  ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
@@ -147,7 +147,7 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6
147
147
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=LPxg7mAJ_aAUIx6eE5bxixPA8Ep9Vul0CWJoNcrD5oE,1565
148
148
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
149
149
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
150
- ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
150
+ ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
151
151
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
152
152
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
153
153
  ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
@@ -186,7 +186,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
186
186
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
187
187
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
188
188
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
189
- ai_edge_torch/generative/utilities/converter.py,sha256=swtz69oyMOxSaCEYST_Gzd5sjGZ1qOBAfd_0xl207Nk,9766
189
+ ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
190
190
  ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
191
191
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
192
192
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -212,7 +212,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
212
212
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
213
213
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
214
214
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
215
- ai_edge_torch/odml_torch/export.py,sha256=rxsyVagQgb-DDIVtwZwSTSVFINqwIZleOOfmPkBoPKg,14817
215
+ ai_edge_torch/odml_torch/export.py,sha256=lbLpdGa8MDE8oWNA7aSV3tOCQ9P9I2Ox95dSPEssn-g,14930
216
216
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
217
217
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
218
218
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -245,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
245
245
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
246
246
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
247
247
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
248
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/METADATA,sha256=ovMriaKRgveZtN2i-cTOM2_8BuNvgf-SYNITAte1wjs,2051
250
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
- ai_edge_torch_nightly-0.5.0.dev20250417.dist-info/RECORD,,
248
+ ai_edge_torch_nightly-0.5.0.dev20250419.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
+ ai_edge_torch_nightly-0.5.0.dev20250419.dist-info/METADATA,sha256=FqixHlt1f3QPZdgxBHMWfF_GAD2GwOXkFvyVMP8IjpI,2051
250
+ ai_edge_torch_nightly-0.5.0.dev20250419.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
+ ai_edge_torch_nightly-0.5.0.dev20250419.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
+ ai_edge_torch_nightly-0.5.0.dev20250419.dist-info/RECORD,,