ai-edge-torch-nightly 0.2.0.dev20240720__py3-none-any.whl → 0.2.0.dev20240725__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

@@ -113,6 +113,10 @@ def is_4d(node: Node):
113
113
  val = node.meta.get("val")
114
114
  if val is None:
115
115
  return False
116
+
117
+ if isinstance(val, (list, tuple)) and val:
118
+ val = val[0]
119
+
116
120
  if not hasattr(val, "shape"):
117
121
  return False
118
122
 
@@ -168,7 +172,6 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
168
172
 
169
173
 
170
174
  @nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
171
- @nhwcable_node_checkers.register(aten.native_group_norm)
172
175
  def _aten_norm_checker(node):
173
176
  val = node.meta.get("val")
174
177
  if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
@@ -176,6 +179,18 @@ def _aten_norm_checker(node):
176
179
  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
177
180
 
178
181
 
182
+ @nhwcable_node_checkers.register(aten.native_group_norm)
183
+ def _aten_native_group_norm_checker(node):
184
+ val = node.meta.get("val")
185
+ if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
186
+ return NHWCable(can_be=False, must_be=False)
187
+ if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
188
+ # Disable NHWC rewriter due to precision issue with weight and bias.
189
+ # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
190
+ return NHWCable(can_be=False, must_be=False)
191
+ return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
192
+
193
+
179
194
  # ==== Ops must be NCHW
180
195
 
181
196
 
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import operator
15
16
  import os
16
17
  from typing import Optional, Tuple, Union
17
18
 
@@ -274,6 +275,14 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
274
275
  graph_module = layout_partitioners.greedy.partition(graph_module)
275
276
 
276
277
  graph = graph_module.graph
278
+ for node in list(graph.nodes):
279
+ if node.target == operator.getitem:
280
+ # force the layout mark of a getitem node to follow its producer.
281
+ if layout_mark.is_nchw_node(node.args[0]):
282
+ layout_mark.mark_as_nchw_node(node)
283
+ else:
284
+ layout_mark.mark_as_nhwc_node(node)
285
+
277
286
  for node in list(graph.nodes):
278
287
  if layout_mark.is_nhwc_node(node):
279
288
  for input_node in layout_check.get_layout_sensitive_inputs(node):
@@ -125,10 +125,9 @@ class AttentionBlock2D(nn.Module):
125
125
  """
126
126
  residual = input_tensor
127
127
  B, C, H, W = input_tensor.shape
128
- x = input_tensor
129
128
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
130
- x = self.norm(x)
131
- x = input_tensor.view(B, C, H * W)
129
+ x = self.norm(input_tensor)
130
+ x = x.view(B, C, H * W)
132
131
  x = x.transpose(-1, -2)
133
132
  else:
134
133
  x = input_tensor.view(B, C, H * W)
@@ -181,10 +180,9 @@ class CrossAttentionBlock2D(nn.Module):
181
180
  """
182
181
  residual = input_tensor
183
182
  B, C, H, W = input_tensor.shape
184
- x = input_tensor
185
183
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
186
- x = self.norm(x)
187
- x = input_tensor.view(B, C, H * W)
184
+ x = self.norm(input_tensor)
185
+ x = x.view(B, C, H * W)
188
186
  x = x.transpose(-1, -2)
189
187
  else:
190
188
  x = input_tensor.view(B, C, H * W)
@@ -222,10 +220,9 @@ class FeedForwardBlock2D(nn.Module):
222
220
  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
223
221
  residual = input_tensor
224
222
  B, C, H, W = input_tensor.shape
225
- x = input_tensor
226
223
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
227
- x = self.norm(x)
228
- x = input_tensor.view(B, C, H * W)
224
+ x = self.norm(input_tensor)
225
+ x = x.view(B, C, H * W)
229
226
  x = x.transpose(-1, -2)
230
227
  else:
231
228
  x = input_tensor.view(B, C, H * W)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240720
3
+ Version: 0.2.0.dev20240725
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -12,11 +12,11 @@ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vc
12
12
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
13
13
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
14
14
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
15
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=pG-zLvO5vGs3gjNXa3RxGNwvC-_Azei2anxe2VdKsnY,6870
15
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=kV4Md7YBQIS-A8Dp4O-8SugNHAfDjVIBHvnldVPpHV0,7483
16
16
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
17
17
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
18
18
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
19
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
19
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=7SuNFBfQWt42tEcUUxpQVnWhy-ByMuLHi2VtH8Kc1g4,10708
20
20
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
21
21
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
@@ -90,7 +90,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0l
90
90
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
91
91
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
92
92
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
93
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=T70veX57CC9uNidwzoVGzOu-CwzcYMBr1Zk_0bq5UlM,26538
93
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4et0TLBtAyVlYxMSJi3-oQoO5npFkOzcCYA927dvm_8,26475
94
94
  ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
95
95
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
96
96
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
125
125
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
126
126
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
127
127
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
128
- ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
- ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/METADATA,sha256=xkXzcnmvTzJRRNOJ2c8JnWS1ZCofdlZiKsW5sa5sDyM,1745
130
- ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
- ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
- ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/RECORD,,
128
+ ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
+ ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/METADATA,sha256=fwnI0u8Q4MJNeahUsCYe34o4YaeoRBJrUieJUk5ogY8,1745
130
+ ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
+ ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
+ ai_edge_torch_nightly-0.2.0.dev20240725.dist-info/RECORD,,