ai-edge-torch-nightly 0.2.0.dev20240720__py3-none-any.whl → 0.2.0.dev20240725__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +16 -1
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +6 -9
- {ai_edge_torch_nightly-0.2.0.dev20240720.dist-info → ai_edge_torch_nightly-0.2.0.dev20240725.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240720.dist-info → ai_edge_torch_nightly-0.2.0.dev20240725.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.2.0.dev20240720.dist-info → ai_edge_torch_nightly-0.2.0.dev20240725.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240720.dist-info → ai_edge_torch_nightly-0.2.0.dev20240725.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240720.dist-info → ai_edge_torch_nightly-0.2.0.dev20240725.dist-info}/top_level.txt +0 -0
|
@@ -113,6 +113,10 @@ def is_4d(node: Node):
|
|
|
113
113
|
val = node.meta.get("val")
|
|
114
114
|
if val is None:
|
|
115
115
|
return False
|
|
116
|
+
|
|
117
|
+
if isinstance(val, (list, tuple)) and val:
|
|
118
|
+
val = val[0]
|
|
119
|
+
|
|
116
120
|
if not hasattr(val, "shape"):
|
|
117
121
|
return False
|
|
118
122
|
|
|
@@ -168,7 +172,6 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
|
|
168
172
|
|
|
169
173
|
|
|
170
174
|
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
|
|
171
|
-
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
172
175
|
def _aten_norm_checker(node):
|
|
173
176
|
val = node.meta.get("val")
|
|
174
177
|
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
@@ -176,6 +179,18 @@ def _aten_norm_checker(node):
|
|
|
176
179
|
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
177
180
|
|
|
178
181
|
|
|
182
|
+
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
183
|
+
def _aten_native_group_norm_checker(node):
|
|
184
|
+
val = node.meta.get("val")
|
|
185
|
+
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
186
|
+
return NHWCable(can_be=False, must_be=False)
|
|
187
|
+
if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
|
|
188
|
+
# Disable NHWC rewriter due to precision issue with weight and bias.
|
|
189
|
+
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
|
|
190
|
+
return NHWCable(can_be=False, must_be=False)
|
|
191
|
+
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
192
|
+
|
|
193
|
+
|
|
179
194
|
# ==== Ops must be NCHW
|
|
180
195
|
|
|
181
196
|
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import operator
|
|
15
16
|
import os
|
|
16
17
|
from typing import Optional, Tuple, Union
|
|
17
18
|
|
|
@@ -274,6 +275,14 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
274
275
|
graph_module = layout_partitioners.greedy.partition(graph_module)
|
|
275
276
|
|
|
276
277
|
graph = graph_module.graph
|
|
278
|
+
for node in list(graph.nodes):
|
|
279
|
+
if node.target == operator.getitem:
|
|
280
|
+
# force the layout mark of a getitem node to follow its producer.
|
|
281
|
+
if layout_mark.is_nchw_node(node.args[0]):
|
|
282
|
+
layout_mark.mark_as_nchw_node(node)
|
|
283
|
+
else:
|
|
284
|
+
layout_mark.mark_as_nhwc_node(node)
|
|
285
|
+
|
|
277
286
|
for node in list(graph.nodes):
|
|
278
287
|
if layout_mark.is_nhwc_node(node):
|
|
279
288
|
for input_node in layout_check.get_layout_sensitive_inputs(node):
|
|
@@ -125,10 +125,9 @@ class AttentionBlock2D(nn.Module):
|
|
|
125
125
|
"""
|
|
126
126
|
residual = input_tensor
|
|
127
127
|
B, C, H, W = input_tensor.shape
|
|
128
|
-
x = input_tensor
|
|
129
128
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
130
|
-
x = self.norm(
|
|
131
|
-
x =
|
|
129
|
+
x = self.norm(input_tensor)
|
|
130
|
+
x = x.view(B, C, H * W)
|
|
132
131
|
x = x.transpose(-1, -2)
|
|
133
132
|
else:
|
|
134
133
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -181,10 +180,9 @@ class CrossAttentionBlock2D(nn.Module):
|
|
|
181
180
|
"""
|
|
182
181
|
residual = input_tensor
|
|
183
182
|
B, C, H, W = input_tensor.shape
|
|
184
|
-
x = input_tensor
|
|
185
183
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
186
|
-
x = self.norm(
|
|
187
|
-
x =
|
|
184
|
+
x = self.norm(input_tensor)
|
|
185
|
+
x = x.view(B, C, H * W)
|
|
188
186
|
x = x.transpose(-1, -2)
|
|
189
187
|
else:
|
|
190
188
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -222,10 +220,9 @@ class FeedForwardBlock2D(nn.Module):
|
|
|
222
220
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
223
221
|
residual = input_tensor
|
|
224
222
|
B, C, H, W = input_tensor.shape
|
|
225
|
-
x = input_tensor
|
|
226
223
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
227
|
-
x = self.norm(
|
|
228
|
-
x =
|
|
224
|
+
x = self.norm(input_tensor)
|
|
225
|
+
x = x.view(B, C, H * W)
|
|
229
226
|
x = x.transpose(-1, -2)
|
|
230
227
|
else:
|
|
231
228
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.
|
|
3
|
+
Version: 0.2.0.dev20240725
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -12,11 +12,11 @@ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vc
|
|
|
12
12
|
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
|
|
13
13
|
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
14
14
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
|
|
15
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
|
15
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=kV4Md7YBQIS-A8Dp4O-8SugNHAfDjVIBHvnldVPpHV0,7483
|
|
16
16
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
|
|
17
17
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
|
|
18
18
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
|
|
19
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
|
19
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=7SuNFBfQWt42tEcUUxpQVnWhy-ByMuLHi2VtH8Kc1g4,10708
|
|
20
20
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
|
|
21
21
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
@@ -90,7 +90,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0l
|
|
|
90
90
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
|
|
91
91
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
|
|
92
92
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
|
93
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4et0TLBtAyVlYxMSJi3-oQoO5npFkOzcCYA927dvm_8,26475
|
|
94
94
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
|
|
95
95
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
|
|
96
96
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
|
|
|
125
125
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
126
126
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
127
127
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
|
|
128
|
-
ai_edge_torch_nightly-0.2.0.
|
|
129
|
-
ai_edge_torch_nightly-0.2.0.
|
|
130
|
-
ai_edge_torch_nightly-0.2.0.
|
|
131
|
-
ai_edge_torch_nightly-0.2.0.
|
|
132
|
-
ai_edge_torch_nightly-0.2.0.
|
|
128
|
+
ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
129
|
+
ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/METADATA,sha256=fwnI0u8Q4MJNeahUsCYe34o4YaeoRBJrUieJUk5ogY8,1745
|
|
130
|
+
ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
131
|
+
ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
132
|
+
ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|