ai-edge-torch-nightly 0.2.0.dev20240723__py3-none-any.whl → 0.2.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +16 -1
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -0
- {ai_edge_torch_nightly-0.2.0.dev20240723.dist-info → ai_edge_torch_nightly-0.2.0.dev20240724.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240723.dist-info → ai_edge_torch_nightly-0.2.0.dev20240724.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.2.0.dev20240723.dist-info → ai_edge_torch_nightly-0.2.0.dev20240724.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240723.dist-info → ai_edge_torch_nightly-0.2.0.dev20240724.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240723.dist-info → ai_edge_torch_nightly-0.2.0.dev20240724.dist-info}/top_level.txt +0 -0
|
@@ -113,6 +113,10 @@ def is_4d(node: Node):
|
|
|
113
113
|
val = node.meta.get("val")
|
|
114
114
|
if val is None:
|
|
115
115
|
return False
|
|
116
|
+
|
|
117
|
+
if isinstance(val, (list, tuple)) and val:
|
|
118
|
+
val = val[0]
|
|
119
|
+
|
|
116
120
|
if not hasattr(val, "shape"):
|
|
117
121
|
return False
|
|
118
122
|
|
|
@@ -168,7 +172,6 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
|
|
168
172
|
|
|
169
173
|
|
|
170
174
|
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
|
|
171
|
-
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
172
175
|
def _aten_norm_checker(node):
|
|
173
176
|
val = node.meta.get("val")
|
|
174
177
|
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
@@ -176,6 +179,18 @@ def _aten_norm_checker(node):
|
|
|
176
179
|
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
177
180
|
|
|
178
181
|
|
|
182
|
+
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
183
|
+
def _aten_native_group_norm_checker(node):
|
|
184
|
+
val = node.meta.get("val")
|
|
185
|
+
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
186
|
+
return NHWCable(can_be=False, must_be=False)
|
|
187
|
+
if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
|
|
188
|
+
# Disable NHWC rewriter due to precision issue with weight and bias.
|
|
189
|
+
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
|
|
190
|
+
return NHWCable(can_be=False, must_be=False)
|
|
191
|
+
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
192
|
+
|
|
193
|
+
|
|
179
194
|
# ==== Ops must be NCHW
|
|
180
195
|
|
|
181
196
|
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import operator
|
|
15
16
|
import os
|
|
16
17
|
from typing import Optional, Tuple, Union
|
|
17
18
|
|
|
@@ -274,6 +275,14 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
274
275
|
graph_module = layout_partitioners.greedy.partition(graph_module)
|
|
275
276
|
|
|
276
277
|
graph = graph_module.graph
|
|
278
|
+
for node in list(graph.nodes):
|
|
279
|
+
if node.target == operator.getitem:
|
|
280
|
+
# force the layout mark of a getitem node to follow its producer.
|
|
281
|
+
if layout_mark.is_nchw_node(node.args[0]):
|
|
282
|
+
layout_mark.mark_as_nchw_node(node)
|
|
283
|
+
else:
|
|
284
|
+
layout_mark.mark_as_nhwc_node(node)
|
|
285
|
+
|
|
277
286
|
for node in list(graph.nodes):
|
|
278
287
|
if layout_mark.is_nhwc_node(node):
|
|
279
288
|
for input_node in layout_check.get_layout_sensitive_inputs(node):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.
|
|
3
|
+
Version: 0.2.0.dev20240724
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -12,11 +12,11 @@ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vc
|
|
|
12
12
|
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
|
|
13
13
|
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
14
14
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
|
|
15
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
|
15
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=kV4Md7YBQIS-A8Dp4O-8SugNHAfDjVIBHvnldVPpHV0,7483
|
|
16
16
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
|
|
17
17
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
|
|
18
18
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
|
|
19
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
|
19
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=7SuNFBfQWt42tEcUUxpQVnWhy-ByMuLHi2VtH8Kc1g4,10708
|
|
20
20
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
|
|
21
21
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
|
|
|
125
125
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
126
126
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
127
127
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
|
|
128
|
-
ai_edge_torch_nightly-0.2.0.
|
|
129
|
-
ai_edge_torch_nightly-0.2.0.
|
|
130
|
-
ai_edge_torch_nightly-0.2.0.
|
|
131
|
-
ai_edge_torch_nightly-0.2.0.
|
|
132
|
-
ai_edge_torch_nightly-0.2.0.
|
|
128
|
+
ai_edge_torch_nightly-0.2.0.dev20240724.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
129
|
+
ai_edge_torch_nightly-0.2.0.dev20240724.dist-info/METADATA,sha256=pdYXzmVJ02-GQnxWNMHtYvAq4LaTOpfFREjbRcpNSmI,1745
|
|
130
|
+
ai_edge_torch_nightly-0.2.0.dev20240724.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
131
|
+
ai_edge_torch_nightly-0.2.0.dev20240724.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
132
|
+
ai_edge_torch_nightly-0.2.0.dev20240724.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|