ai-edge-torch-nightly 0.3.0.dev20250211__py3-none-any.whl → 0.3.0.dev20250212__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -56,11 +56,6 @@ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
56
56
  1280,
57
57
  'The maximum size of KV cache buffer, including both prefill and decode.',
58
58
  )
59
- _PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
60
- 'pixel_values_size',
61
- [3, 224, 224],
62
- 'The size of prefill pixel values except the batch dimension.',
63
- )
64
59
  _QUANTIZE = flags.DEFINE_bool(
65
60
  'quantize',
66
61
  True,
@@ -75,12 +70,15 @@ def main(_):
75
70
  kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
76
71
  )
77
72
 
73
+ config = pytorch_model.image_encoder.config.image_embedding
78
74
  converter.convert_to_tflite(
79
75
  pytorch_model,
80
76
  output_path=_OUTPUT_PATH.value,
81
77
  output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
82
78
  prefill_seq_len=_PREFILL_SEQ_LEN.value,
83
- pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
79
+ pixel_values_size=torch.Size(
80
+ [1, config.channels, config.image_size, config.image_size]
81
+ ),
84
82
  quantize=_QUANTIZE.value,
85
83
  config=pytorch_model.config.decoder_config,
86
84
  export_config=ExportConfig(),
@@ -136,9 +136,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
136
136
  image_embedding=image_embedding_config,
137
137
  block_configs=block_config,
138
138
  final_norm_config=norm_config,
139
- # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
140
- # enable_hlfb can be set to True. See b/383865404#comment3 for details.
141
- # enable_hlfb=True,
139
+ enable_hlfb=True,
142
140
  )
143
141
  return config
144
142
 
@@ -145,7 +145,7 @@ def _export_helper(
145
145
  prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
146
146
 
147
147
  prefill_pixel_values = (
148
- torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
148
+ torch.full(pixel_values_size, 0, dtype=torch.float32)
149
149
  if pixel_values_size
150
150
  else None
151
151
  )
@@ -12,5 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
15
+ """Debug info generation for ODML Torch."""
16
+
17
+ from . import _build
16
18
  from ._op_polyfill import write_mlir_debuginfo_op
19
+
20
+ build_nodename_debuginfo = _build.build_nodename_debuginfo
21
+ build_mlir_file_debuginfo = _build.build_mlir_file_debuginfo
22
+ build_mlir_debuginfo = _build.build_mlir_debuginfo
@@ -40,10 +40,16 @@ def _get_canonical_filename(filename):
40
40
 
41
41
  This should be factored out so that pattern is a global option that a user
42
42
  can override.
43
+
44
+ Args:
45
+ filename: The filename to canonicalize.
46
+
47
+ Returns:
48
+ The canonicalized filename.
43
49
  """
44
50
 
45
- # TODO: We should add a config option to provide a regex to strip from the
46
- # debug info. Currently absolute path is used.
51
+ # TODO(yijieyang): We should add a config option to provide a regex to strip
52
+ # from the debug info. Currently absolute path is used.
47
53
  return filename
48
54
 
49
55
 
@@ -55,9 +61,23 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
55
61
 
56
62
  # Note: This uses internal APIs and may break in the future.
57
63
  pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
64
+ if pt_trace is None:
65
+ return None, None
58
66
  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
59
67
 
60
68
 
69
+ def build_nodename_debuginfo(node: torch.fx.Node):
70
+ """Build the fx node name for the given node's lowerings in MLIR."""
71
+ history = node.meta.get("from_node", [])
72
+ if not history:
73
+ return None
74
+ if len(history) > 1:
75
+ return history[1][0]
76
+ if hasattr(history[0], "name"): # torch 2.6.0+
77
+ return history[0].name
78
+ return None
79
+
80
+
61
81
  def build_mlir_debuginfo(node: torch.fx.Node):
62
82
  """Build the debuginfo string for the given node's lowerings in MLIR."""
63
83
 
@@ -88,6 +88,21 @@ class LoweringInterpreter(torch.fx.Interpreter):
88
88
  self.outputs = None
89
89
 
90
90
  def _build_loc(self, node: torch.fx.Node):
91
+ """Build MLIR location for the given node.
92
+
93
+ The location contains:
94
+ - layer info
95
+ - fx node name
96
+ - file and line info
97
+
98
+ Currently it's still under development and format is subject to change.
99
+
100
+ Args:
101
+ node: The torch.fx.Node to build the location for.
102
+
103
+ Returns:
104
+ The MLIR location for the given node.
105
+ """
91
106
 
92
107
  info = debuginfo.build_mlir_debuginfo(node)
93
108
  if info is None:
@@ -98,7 +113,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
98
113
  if file is not None:
99
114
  fileinfo = ir.Location.file(filename=file, line=line, col=0)
100
115
 
101
- return ir.Location.name(name=info, childLoc=fileinfo)
116
+ node_name = debuginfo.build_nodename_debuginfo(node)
117
+ nodeinfo = None
118
+ if node_name is not None:
119
+ nodeinfo = ir.Location.name(name=node_name, childLoc=fileinfo)
120
+
121
+ return ir.Location.name(name=info, childLoc=nodeinfo)
102
122
 
103
123
  def run_node(self, node: torch.fx.Node):
104
124
  loc = self._build_loc(node)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250211"
16
+ __version__ = "0.3.0.dev20250212"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250211
3
+ Version: 0.3.0.dev20250212
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=SqHTGIWZweve6ni4shLmLl85BwAAptNTCG646yd6Hk8,706
5
+ ai_edge_torch/version.py,sha256=gYCO29a18zxFgm6z4IHCArhNE29IvlL_sxhEL6O4ECw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,10 +73,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
73
73
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
74
74
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
75
75
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
76
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=APQymtr3n2k6-e8wvn3kVrli0qiElduYIkHeahcoSA0,2743
77
77
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
78
78
  ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
79
- ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
79
+ ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
80
80
  ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
81
81
  ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
82
82
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -172,7 +172,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
172
172
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
173
173
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
174
174
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
175
- ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
175
+ ai_edge_torch/generative/utilities/converter.py,sha256=K9taR0KY59dvfU_jO1yBe_p7w8lDns1Q3U6oJTTKZzM,8058
176
176
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
177
177
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
178
178
  ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -197,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
197
197
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
198
198
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
199
199
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
200
- ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
200
+ ai_edge_torch/odml_torch/export.py,sha256=7l8R0DEq_vfns8iWpruMlIyaIKZAFzoAy369-7iRrl0,14164
201
201
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
202
202
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
203
203
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
204
204
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
205
205
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
206
- ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
207
- ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
206
+ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
207
+ ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=JIMCn_fNh5-PgcV5qcklD7aFj0RhNKlvnZ-XQFCOszc,2706
208
208
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
209
209
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
210
210
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -229,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
229
229
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
230
230
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
231
231
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
232
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
233
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/METADATA,sha256=wT5v4PcnaE4IJmA-3d1W-0U91ETCBsWZyaT8VWiVI-c,1966
234
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
235
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
236
- ai_edge_torch_nightly-0.3.0.dev20250211.dist-info/RECORD,,
232
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
233
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/METADATA,sha256=uCdhC0TvTD1j2cNaIsI1l_uB_FVHMml_jDr2vUzzKHw,1966
234
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
235
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
236
+ ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/RECORD,,