ai-edge-torch-nightly 0.3.0.dev20250211__py3-none-any.whl → 0.3.0.dev20250212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,11 +56,6 @@ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
56
56
  1280,
57
57
  'The maximum size of KV cache buffer, including both prefill and decode.',
58
58
  )
59
- _PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
60
- 'pixel_values_size',
61
- [3, 224, 224],
62
- 'The size of prefill pixel values except the batch dimension.',
63
- )
64
59
  _QUANTIZE = flags.DEFINE_bool(
65
60
  'quantize',
66
61
  True,
@@ -75,12 +70,15 @@ def main(_):
75
70
  kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
76
71
  )
77
72
 
73
+ config = pytorch_model.image_encoder.config.image_embedding
78
74
  converter.convert_to_tflite(
79
75
  pytorch_model,
80
76
  output_path=_OUTPUT_PATH.value,
81
77
  output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
82
78
  prefill_seq_len=_PREFILL_SEQ_LEN.value,
83
- pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
79
+ pixel_values_size=torch.Size(
80
+ [1, config.channels, config.image_size, config.image_size]
81
+ ),
84
82
  quantize=_QUANTIZE.value,
85
83
  config=pytorch_model.config.decoder_config,
86
84
  export_config=ExportConfig(),
@@ -136,9 +136,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
136
136
  image_embedding=image_embedding_config,
137
137
  block_configs=block_config,
138
138
  final_norm_config=norm_config,
139
- # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
140
- # enable_hlfb can be set to True. See b/383865404#comment3 for details.
141
- # enable_hlfb=True,
139
+ enable_hlfb=True,
142
140
  )
143
141
  return config
144
142
 
@@ -145,7 +145,7 @@ def _export_helper(
145
145
  prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
146
146
 
147
147
  prefill_pixel_values = (
148
- torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
148
+ torch.full(pixel_values_size, 0, dtype=torch.float32)
149
149
  if pixel_values_size
150
150
  else None
151
151
  )
@@ -12,5 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
15
+ """Debug info generation for ODML Torch."""
16
+
17
+ from . import _build
16
18
  from ._op_polyfill import write_mlir_debuginfo_op
19
+
20
+ build_nodename_debuginfo = _build.build_nodename_debuginfo
21
+ build_mlir_file_debuginfo = _build.build_mlir_file_debuginfo
22
+ build_mlir_debuginfo = _build.build_mlir_debuginfo
@@ -40,10 +40,16 @@ def _get_canonical_filename(filename):
40
40
 
41
41
  This should be factored out so that pattern is a global option that a user
42
42
  can override.
43
+
44
+ Args:
45
+ filename: The filename to canonicalize.
46
+
47
+ Returns:
48
+ The canonicalized filename.
43
49
  """
44
50
 
45
- # TODO: We should add a config option to provide a regex to strip from the
46
- # debug info. Currently absolute path is used.
51
+ # TODO(yijieyang): We should add a config option to provide a regex to strip
52
+ # from the debug info. Currently absolute path is used.
47
53
  return filename
48
54
 
49
55
 
@@ -55,9 +61,23 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
55
61
 
56
62
  # Note: This uses internal APIs and may break in the future.
57
63
  pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
64
+ if pt_trace is None:
65
+ return None, None
58
66
  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
59
67
 
60
68
 
69
+ def build_nodename_debuginfo(node: torch.fx.Node):
70
+ """Build the fx node name for the given node's lowerings in MLIR."""
71
+ history = node.meta.get("from_node", [])
72
+ if not history:
73
+ return None
74
+ if len(history) > 1:
75
+ return history[1][0]
76
+ if hasattr(history[0], "name"): # torch 2.6.0+
77
+ return history[0].name
78
+ return None
79
+
80
+
61
81
  def build_mlir_debuginfo(node: torch.fx.Node):
62
82
  """Build the debuginfo string for the given node's lowerings in MLIR."""
63
83
 
@@ -88,6 +88,21 @@ class LoweringInterpreter(torch.fx.Interpreter):
88
88
  self.outputs = None
89
89
 
90
90
  def _build_loc(self, node: torch.fx.Node):
91
+ """Build MLIR location for the given node.
92
+
93
+ The location contains:
94
+ - layer info
95
+ - fx node name
96
+ - file and line info
97
+
98
+ Currently it's still under development and format is subject to change.
99
+
100
+ Args:
101
+ node: The torch.fx.Node to build the location for.
102
+
103
+ Returns:
104
+ The MLIR location for the given node.
105
+ """
91
106
 
92
107
  info = debuginfo.build_mlir_debuginfo(node)
93
108
  if info is None:
@@ -98,7 +113,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
98
113
  if file is not None:
99
114
  fileinfo = ir.Location.file(filename=file, line=line, col=0)
100
115
 
101
- return ir.Location.name(name=info, childLoc=fileinfo)
116
+ node_name = debuginfo.build_nodename_debuginfo(node)
117
+ nodeinfo = None
118
+ if node_name is not None:
119
+ nodeinfo = ir.Location.name(name=node_name, childLoc=fileinfo)
120
+
121
+ return ir.Location.name(name=info, childLoc=nodeinfo)
102
122
 
103
123
  def run_node(self, node: torch.fx.Node):
104
124
  loc = self._build_loc(node)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250211"
16
+ __version__ = "0.3.0.dev20250212"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250211
3
+ Version: 0.3.0.dev20250212
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=SqHTGIWZweve6ni4shLmLl85BwAAptNTCG646yd6Hk8,706
5
+ ai_edge_torch/version.py,sha256=gYCO29a18zxFgm6z4IHCArhNE29IvlL_sxhEL6O4ECw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,10 +73,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
73
73
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
74
74
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
75
75
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
76
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=APQymtr3n2k6-e8wvn3kVrli0qiElduYIkHeahcoSA0,2743
77
77
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
78
78
  ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
79
- ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
79
+ ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
80
80
  ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
81
81
  ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
82
82
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -172,7 +172,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
172
172
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
173
173
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
174
174
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
175
- ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
175
+ ai_edge_torch/generative/utilities/converter.py,sha256=K9taR0KY59dvfU_jO1yBe_p7w8lDns1Q3U6oJTTKZzM,8058
176
176
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
177
177
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
178
178
  ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -197,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
197
197
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
198
198
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
199
199
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
200
- ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
200
+ ai_edge_torch/odml_torch/export.py,sha256=7l8R0DEq_vfns8iWpruMlIyaIKZAFzoAy369-7iRrl0,14164
201
201
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
202
202
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
203
203
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
204
204
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
205
205
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
206
- ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
207
- ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
206
+ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
207
+ ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=JIMCn_fNh5-PgcV5qcklD7aFj0RhNKlvnZ-XQFCOszc,2706
208
208
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
209
209
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
210
210
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -229,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
229
229
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
230
230
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
231
231
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
232
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
233
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/METADATA,sha256=wT5v4PcnaE4IJmA-3d1W-0U91ETCBsWZyaT8VWiVI-c,1966
234
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
235
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
236
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/RECORD,,
232
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
233
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/METADATA,sha256=uCdhC0TvTD1j2cNaIsI1l_uB_FVHMml_jDr2vUzzKHw,1966
234
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
235
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
236
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/RECORD,,