ai-edge-torch-nightly 0.4.0.dev20250221__py3-none-any.whl → 0.4.0.dev20250222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +5 -1
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +5 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +30 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
import ai_edge_torch
|
16
17
|
import ai_edge_torch.generative.layers.builder as layers_builder
|
17
18
|
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
18
19
|
from ai_edge_torch.generative.layers.unet import blocks_2d
|
@@ -281,7 +282,10 @@ def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
|
|
281
282
|
|
282
283
|
# For now, only turns on StableHLO composite ops on GPU backend for better
|
283
284
|
# performance. CPU should also switch to it once the support is done.
|
284
|
-
enable_hlfb =
|
285
|
+
enable_hlfb = False
|
286
|
+
if device_type == "gpu":
|
287
|
+
ai_edge_torch.config.enable_group_norm_composite = True
|
288
|
+
enable_hlfb = True
|
285
289
|
|
286
290
|
norm_config = layers_cfg.NormalizationConfig(
|
287
291
|
layers_cfg.NormalizationType.GROUP_NORM,
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
import ai_edge_torch
|
16
17
|
import ai_edge_torch.generative.layers.builder as layers_builder
|
17
18
|
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
18
19
|
from ai_edge_torch.generative.layers.unet import blocks_2d
|
@@ -601,7 +602,10 @@ def get_model_config(
|
|
601
602
|
|
602
603
|
# For now, only turns on StableHLO composite ops on GPU backend for better
|
603
604
|
# performance. CPU should also switch to it once the support is done.
|
604
|
-
enable_hlfb =
|
605
|
+
enable_hlfb = False
|
606
|
+
if device_type == "gpu":
|
607
|
+
ai_edge_torch.config.enable_group_norm_composite = True
|
608
|
+
enable_hlfb = True
|
605
609
|
|
606
610
|
# Residual configs.
|
607
611
|
residual_norm_config = layers_cfg.NormalizationConfig(
|
@@ -12,8 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import torch
|
16
15
|
import re
|
16
|
+
import torch
|
17
17
|
|
18
18
|
|
19
19
|
def _class_fullname(cls):
|
@@ -35,7 +35,7 @@ def _get_hierarchy(node: torch.fx.Node):
|
|
35
35
|
return hierachy_str
|
36
36
|
|
37
37
|
|
38
|
-
def _get_canonical_filename(filename):
|
38
|
+
def _get_canonical_filename(filename: str):
|
39
39
|
"""Remove unnecessary path prefix to make the filename more readable.
|
40
40
|
|
41
41
|
This should be factored out so that pattern is a global option that a user
|
@@ -53,6 +53,29 @@ def _get_canonical_filename(filename):
|
|
53
53
|
return filename
|
54
54
|
|
55
55
|
|
56
|
+
def _get_canoical_nodename(node: torch.fx.Node) -> str:
|
57
|
+
"""Get the canonical node name from the node's history."""
|
58
|
+
|
59
|
+
history = node.meta.get("from_node", [])
|
60
|
+
|
61
|
+
if len(history) > 1: # Compatible with torch version under 2.6.0
|
62
|
+
return history[1][0]
|
63
|
+
|
64
|
+
if not hasattr(history[0], "name"):
|
65
|
+
return None
|
66
|
+
names = []
|
67
|
+
while history:
|
68
|
+
names.append(history[0].name)
|
69
|
+
history = history[0].from_node
|
70
|
+
|
71
|
+
# Based on the experiment, the third to last name in the history stack
|
72
|
+
# can be mapped to the original torch node name. The history stack is
|
73
|
+
# generated by tracing the node's transformation history during lowering.
|
74
|
+
if len(names) > 2:
|
75
|
+
return names[-3]
|
76
|
+
return None
|
77
|
+
|
78
|
+
|
56
79
|
def build_mlir_file_debuginfo(node: torch.fx.Node):
|
57
80
|
"""Build the file and line info for the given node's lowerings in MLIR."""
|
58
81
|
|
@@ -66,16 +89,13 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
|
|
66
89
|
return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
|
67
90
|
|
68
91
|
|
69
|
-
def build_nodename_debuginfo(node: torch.fx.Node):
|
92
|
+
def build_nodename_debuginfo(node: torch.fx.Node) -> str:
|
70
93
|
"""Build the fx node name for the given node's lowerings in MLIR."""
|
71
|
-
|
72
|
-
if not
|
94
|
+
|
95
|
+
if not hasattr(node, "meta") or "from_node" not in node.meta:
|
73
96
|
return None
|
74
|
-
|
75
|
-
|
76
|
-
if hasattr(history[0], "name"): # torch 2.6.0+
|
77
|
-
return history[0].name
|
78
|
-
return None
|
97
|
+
|
98
|
+
return _get_canoical_nodename(node)
|
79
99
|
|
80
100
|
|
81
101
|
def build_mlir_debuginfo(node: torch.fx.Node):
|
{ai_edge_torch_nightly-0.4.0.dev20250221.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.dev20250221
|
3
|
+
Version: 0.4.0.dev20250222
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=3V323rXqU4LWg1Um9kHwmrGlIkEIet9NlHpxI3IJNyw,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -110,8 +110,8 @@ ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R
|
|
110
110
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
111
111
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
|
112
112
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=_yk6wVoZm1_FRMFJF5URaPZNNdmMR89fwmKz81BEyao,5601
|
113
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
114
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
113
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
|
114
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
|
115
115
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
116
116
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=GnY3vPZ-obrWuJifuE5bUooKLqAI7v6q71oaTuLKeBE,8778
|
117
117
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -205,7 +205,7 @@ ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xK
|
|
205
205
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
206
206
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
|
207
207
|
ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
|
208
|
-
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
|
208
|
+
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=k9Kas790kpMS5OrVcLzIr48ejAzcc2smrroKAHHM7TQ,3311
|
209
209
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
210
210
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
211
211
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
230
230
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
231
231
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
232
232
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
233
|
-
ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
234
|
-
ai_edge_torch_nightly-0.4.0.
|
235
|
-
ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
236
|
-
ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
237
|
-
ai_edge_torch_nightly-0.4.0.dev20250221.dist-info/RECORD,,
|
233
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
234
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/METADATA,sha256=nMXtCeRH62h6fhvV2GuzHlG0TDmvPAP_W9KQF6lVsc0,1966
|
235
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
236
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
237
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/RECORD,,
|
File without changes
|
File without changes
|