ai-edge-torch-nightly 0.4.0.dev20250221__py3-none-any.whl → 0.4.0.dev20250222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ import ai_edge_torch
16
17
  import ai_edge_torch.generative.layers.builder as layers_builder
17
18
  import ai_edge_torch.generative.layers.model_config as layers_cfg
18
19
  from ai_edge_torch.generative.layers.unet import blocks_2d
@@ -281,7 +282,10 @@ def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
281
282
 
282
283
  # For now, only turns on StableHLO composite ops on GPU backend for better
283
284
  # performance. CPU should also switch to it once the support is done.
284
- enable_hlfb = True if device_type == "gpu" else False
285
+ enable_hlfb = False
286
+ if device_type == "gpu":
287
+ ai_edge_torch.config.enable_group_norm_composite = True
288
+ enable_hlfb = True
285
289
 
286
290
  norm_config = layers_cfg.NormalizationConfig(
287
291
  layers_cfg.NormalizationType.GROUP_NORM,
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ import ai_edge_torch
16
17
  import ai_edge_torch.generative.layers.builder as layers_builder
17
18
  import ai_edge_torch.generative.layers.model_config as layers_cfg
18
19
  from ai_edge_torch.generative.layers.unet import blocks_2d
@@ -601,7 +602,10 @@ def get_model_config(
601
602
 
602
603
  # For now, only turns on StableHLO composite ops on GPU backend for better
603
604
  # performance. CPU should also switch to it once the support is done.
604
- enable_hlfb = True if device_type == "gpu" else False
605
+ enable_hlfb = False
606
+ if device_type == "gpu":
607
+ ai_edge_torch.config.enable_group_norm_composite = True
608
+ enable_hlfb = True
605
609
 
606
610
  # Residual configs.
607
611
  residual_norm_config = layers_cfg.NormalizationConfig(
@@ -12,8 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import torch
16
15
  import re
16
+ import torch
17
17
 
18
18
 
19
19
  def _class_fullname(cls):
@@ -35,7 +35,7 @@ def _get_hierarchy(node: torch.fx.Node):
35
35
  return hierachy_str
36
36
 
37
37
 
38
- def _get_canonical_filename(filename):
38
+ def _get_canonical_filename(filename: str):
39
39
  """Remove unnecessary path prefix to make the filename more readable.
40
40
 
41
41
  This should be factored out so that pattern is a global option that a user
@@ -53,6 +53,29 @@ def _get_canonical_filename(filename):
53
53
  return filename
54
54
 
55
55
 
56
+ def _get_canoical_nodename(node: torch.fx.Node) -> str:
57
+ """Get the canonical node name from the node's history."""
58
+
59
+ history = node.meta.get("from_node", [])
60
+
61
+ if len(history) > 1: # Compatible with torch version under 2.6.0
62
+ return history[1][0]
63
+
64
+ if not hasattr(history[0], "name"):
65
+ return None
66
+ names = []
67
+ while history:
68
+ names.append(history[0].name)
69
+ history = history[0].from_node
70
+
71
+ # Based on the experiment, the third to last name in the history stack
72
+ # can be mapped to the original torch node name. The history stack is
73
+ # generated by tracing the node's transformation history during lowering.
74
+ if len(names) > 2:
75
+ return names[-3]
76
+ return None
77
+
78
+
56
79
  def build_mlir_file_debuginfo(node: torch.fx.Node):
57
80
  """Build the file and line info for the given node's lowerings in MLIR."""
58
81
 
@@ -66,16 +89,13 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
66
89
  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
67
90
 
68
91
 
69
- def build_nodename_debuginfo(node: torch.fx.Node):
92
+ def build_nodename_debuginfo(node: torch.fx.Node) -> str:
70
93
  """Build the fx node name for the given node's lowerings in MLIR."""
71
- history = node.meta.get("from_node", [])
72
- if not history:
94
+
95
+ if not hasattr(node, "meta") or "from_node" not in node.meta:
73
96
  return None
74
- if len(history) > 1:
75
- return history[1][0]
76
- if hasattr(history[0], "name"): # torch 2.6.0+
77
- return history[0].name
78
- return None
97
+
98
+ return _get_canoical_nodename(node)
79
99
 
80
100
 
81
101
  def build_mlir_debuginfo(node: torch.fx.Node):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250221"
16
+ __version__ = "0.4.0.dev20250222"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250221
3
+ Version: 0.4.0.dev20250222
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=6xZmxpJoHLCndJcctziwhewp4ss--Bi2fe_YyGL2Vag,706
5
+ ai_edge_torch/version.py,sha256=3V323rXqU4LWg1Um9kHwmrGlIkEIet9NlHpxI3IJNyw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -110,8 +110,8 @@ ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R
110
110
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
111
111
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
112
112
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=_yk6wVoZm1_FRMFJF5URaPZNNdmMR89fwmKz81BEyao,5601
113
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
114
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
113
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
114
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
115
115
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
116
116
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=GnY3vPZ-obrWuJifuE5bUooKLqAI7v6q71oaTuLKeBE,8778
117
117
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -205,7 +205,7 @@ ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xK
205
205
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
206
206
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
207
207
  ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
208
- ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=JIMCn_fNh5-PgcV5qcklD7aFj0RhNKlvnZ-XQFCOszc,2706
208
+ ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=k9Kas790kpMS5OrVcLzIr48ejAzcc2smrroKAHHM7TQ,3311
209
209
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
210
210
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
211
211
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
230
230
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
231
231
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
232
232
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
233
- ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
234
- ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/METADATA,sha256=0V4eezenEyB4Ig9f76FMBy6p1MF3p9Ky--dPmZplYyM,1966
235
- ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
236
- ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
237
- ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/RECORD,,
233
+ ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
234
+ ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/METADATA,sha256=nMXtCeRH62h6fhvV2GuzHlG0TDmvPAP_W9KQF6lVsc0,1966
235
+ ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
236
+ ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
237
+ ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/RECORD,,