ai-edge-torch-nightly 0.2.0.dev20240721__py3-none-any.whl → 0.2.0.dev20240726__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. See the package registry's advisory page for more details.

@@ -113,6 +113,10 @@ def is_4d(node: Node):
113
113
  val = node.meta.get("val")
114
114
  if val is None:
115
115
  return False
116
+
117
+ if isinstance(val, (list, tuple)) and val:
118
+ val = val[0]
119
+
116
120
  if not hasattr(val, "shape"):
117
121
  return False
118
122
 
@@ -168,7 +172,6 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
168
172
 
169
173
 
170
174
  @nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
171
- @nhwcable_node_checkers.register(aten.native_group_norm)
172
175
  def _aten_norm_checker(node):
173
176
  val = node.meta.get("val")
174
177
  if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
@@ -176,6 +179,18 @@ def _aten_norm_checker(node):
176
179
  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
177
180
 
178
181
 
182
+ @nhwcable_node_checkers.register(aten.native_group_norm)
183
+ def _aten_native_group_norm_checker(node):
184
+ val = node.meta.get("val")
185
+ if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
186
+ return NHWCable(can_be=False, must_be=False)
187
+ if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
188
+ # Disable NHWC rewriter due to precision issue with weight and bias.
189
+ # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
190
+ return NHWCable(can_be=False, must_be=False)
191
+ return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
192
+
193
+
179
194
  # ==== Ops must be NCHW
180
195
 
181
196
 
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import operator
15
16
  import os
16
17
  from typing import Optional, Tuple, Union
17
18
 
@@ -274,6 +275,14 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
274
275
  graph_module = layout_partitioners.greedy.partition(graph_module)
275
276
 
276
277
  graph = graph_module.graph
278
+ for node in list(graph.nodes):
279
+ if node.target == operator.getitem:
280
+ # force the layout mark of a getitem node to follow its producer.
281
+ if layout_mark.is_nchw_node(node.args[0]):
282
+ layout_mark.mark_as_nchw_node(node)
283
+ else:
284
+ layout_mark.mark_as_nhwc_node(node)
285
+
277
286
  for node in list(graph.nodes):
278
287
  if layout_mark.is_nhwc_node(node):
279
288
  for input_node in layout_check.get_layout_sensitive_inputs(node):
@@ -354,7 +354,7 @@ def _search_model(
354
354
  max_granularity: Optional[int] = None,
355
355
  enable_fx_minifier_logging: bool = False,
356
356
  ) -> Generator[SearchResult, None, None]:
357
- """Finds subgraphs in the torch model that satify a certain predicate function provided by the users.
357
+ """Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
358
358
 
359
359
  Args:
360
360
  predicate_f: a predicate function the users specify.
@@ -382,26 +382,33 @@ def _search_model(
382
382
  fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
383
383
  fx_gm = _normalize_getitem_nodes(fx_gm)
384
384
 
385
- # HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
386
- # intermediate stablehlo files to storage.
385
+ # HACK: temporarily disable XLA_HLO_DEBUG and create_minified_hlo_graph so that
386
+ # fx_minifier won't dump intermediate stablehlo files to storage.
387
387
  # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
388
388
  @contextlib.contextmanager
389
- def disable_xla_hlo_debug():
389
+ def disable_minifier_xla_debug():
390
390
  xla_hlo_debug_value = None
391
391
  if "XLA_HLO_DEBUG" in os.environ:
392
392
  xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
393
393
  del os.environ["XLA_HLO_DEBUG"]
394
394
 
395
+ create_minified_hlo_graph = torch._functorch.fx_minifier.create_minified_hlo_graph
396
+ torch._functorch.fx_minifier.create_minified_hlo_graph = (
397
+ lambda *args, **kwargs: None
398
+ )
399
+
395
400
  try:
396
- yield None
401
+ yield
397
402
  finally:
398
403
  if xla_hlo_debug_value is not None:
399
404
  os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value
400
405
 
406
+ torch._functorch.fx_minifier.create_minified_hlo_graph = create_minified_hlo_graph
407
+
401
408
  found_culprits_num = 0
402
409
  while True:
403
410
  try:
404
- with disable_xla_hlo_debug(), open(os.devnull, "w") as devnull:
411
+ with disable_minifier_xla_debug(), open(os.devnull, "w") as devnull:
405
412
  with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
406
413
  stdout=devnull,
407
414
  stderr=devnull,
@@ -125,10 +125,9 @@ class AttentionBlock2D(nn.Module):
125
125
  """
126
126
  residual = input_tensor
127
127
  B, C, H, W = input_tensor.shape
128
- x = input_tensor
129
128
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
130
- x = self.norm(x)
131
- x = input_tensor.view(B, C, H * W)
129
+ x = self.norm(input_tensor)
130
+ x = x.view(B, C, H * W)
132
131
  x = x.transpose(-1, -2)
133
132
  else:
134
133
  x = input_tensor.view(B, C, H * W)
@@ -181,10 +180,9 @@ class CrossAttentionBlock2D(nn.Module):
181
180
  """
182
181
  residual = input_tensor
183
182
  B, C, H, W = input_tensor.shape
184
- x = input_tensor
185
183
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
186
- x = self.norm(x)
187
- x = input_tensor.view(B, C, H * W)
184
+ x = self.norm(input_tensor)
185
+ x = x.view(B, C, H * W)
188
186
  x = x.transpose(-1, -2)
189
187
  else:
190
188
  x = input_tensor.view(B, C, H * W)
@@ -222,10 +220,9 @@ class FeedForwardBlock2D(nn.Module):
222
220
  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
223
221
  residual = input_tensor
224
222
  B, C, H, W = input_tensor.shape
225
- x = input_tensor
226
223
  if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
227
- x = self.norm(x)
228
- x = input_tensor.view(B, C, H * W)
224
+ x = self.norm(input_tensor)
225
+ x = x.view(B, C, H * W)
229
226
  x = x.transpose(-1, -2)
230
227
  else:
231
228
  x = input_tensor.view(B, C, H * W)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240721
3
+ Version: 0.2.0.dev20240726
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -27,7 +27,10 @@ Requires-Dist: numpy
27
27
  Requires-Dist: scipy
28
28
  Requires-Dist: safetensors
29
29
  Requires-Dist: tabulate
30
- Requires-Dist: torch >2.3
30
+ Requires-Dist: torch >=2.4.0
31
+ Requires-Dist: torch-xla >=2.4.0
32
+ Requires-Dist: tf-nightly >=2.18.0.dev20240722
33
+ Requires-Dist: ai-edge-quantizer-nightly ==0.0.1.dev20240718
31
34
 
32
35
  Library that supports converting PyTorch models into a .tflite format, which can
33
36
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
@@ -12,11 +12,11 @@ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vc
12
12
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
13
13
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
14
14
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
15
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=pG-zLvO5vGs3gjNXa3RxGNwvC-_Azei2anxe2VdKsnY,6870
15
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=kV4Md7YBQIS-A8Dp4O-8SugNHAfDjVIBHvnldVPpHV0,7483
16
16
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
17
17
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
18
18
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
19
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
19
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=7SuNFBfQWt42tEcUUxpQVnWhy-ByMuLHi2VtH8Kc1g4,10708
20
20
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
21
21
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
@@ -27,7 +27,7 @@ ai_edge_torch/convert/test/test_convert_composites.py,sha256=8UkdPtGkjgSVLCzB_rp
27
27
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
28
28
  ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
29
29
  ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
30
- ai_edge_torch/debug/culprit.py,sha256=urtCKPXORPvn6oyDxDSCSjgvngUnjjcsUMwAOeIl15E,14236
30
+ ai_edge_torch/debug/culprit.py,sha256=lN2N_J3EJTPcEH42xK6kvs4yOqZtyvvecSc4aidlog4,14556
31
31
  ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
32
32
  ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
33
33
  ai_edge_torch/debug/test/test_culprit.py,sha256=9An_n9p_RWTAYdHYTCO-__EJlbnjclCDo8tDhOzMlwk,3731
@@ -90,7 +90,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0l
90
90
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
91
91
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
92
92
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
93
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=T70veX57CC9uNidwzoVGzOu-CwzcYMBr1Zk_0bq5UlM,26538
93
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4et0TLBtAyVlYxMSJi3-oQoO5npFkOzcCYA927dvm_8,26475
94
94
  ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
95
95
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
96
96
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
125
125
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
126
126
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
127
127
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
128
- ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
- ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/METADATA,sha256=TJYFNAxXQkRwt9I_0OqpUOS3opWBU5i-ioMwsicD7cY,1745
130
- ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
- ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
- ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/RECORD,,
128
+ ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
+ ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/METADATA,sha256=QstgV90JdvfDL5-OAb_Z4dzWItNf6qT-gNfDtzdoTok,1889
130
+ ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
+ ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
+ ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/RECORD,,