ai-edge-torch-nightly 0.2.0.dev20240721__py3-none-any.whl → 0.2.0.dev20240726__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +16 -1
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -0
- ai_edge_torch/debug/culprit.py +13 -6
- ai_edge_torch/generative/layers/unet/blocks_2d.py +6 -9
- {ai_edge_torch_nightly-0.2.0.dev20240721.dist-info → ai_edge_torch_nightly-0.2.0.dev20240726.dist-info}/METADATA +5 -2
- {ai_edge_torch_nightly-0.2.0.dev20240721.dist-info → ai_edge_torch_nightly-0.2.0.dev20240726.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.2.0.dev20240721.dist-info → ai_edge_torch_nightly-0.2.0.dev20240726.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240721.dist-info → ai_edge_torch_nightly-0.2.0.dev20240726.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240721.dist-info → ai_edge_torch_nightly-0.2.0.dev20240726.dist-info}/top_level.txt +0 -0
|
@@ -113,6 +113,10 @@ def is_4d(node: Node):
|
|
|
113
113
|
val = node.meta.get("val")
|
|
114
114
|
if val is None:
|
|
115
115
|
return False
|
|
116
|
+
|
|
117
|
+
if isinstance(val, (list, tuple)) and val:
|
|
118
|
+
val = val[0]
|
|
119
|
+
|
|
116
120
|
if not hasattr(val, "shape"):
|
|
117
121
|
return False
|
|
118
122
|
|
|
@@ -168,7 +172,6 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
|
|
168
172
|
|
|
169
173
|
|
|
170
174
|
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
|
|
171
|
-
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
172
175
|
def _aten_norm_checker(node):
|
|
173
176
|
val = node.meta.get("val")
|
|
174
177
|
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
@@ -176,6 +179,18 @@ def _aten_norm_checker(node):
|
|
|
176
179
|
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
177
180
|
|
|
178
181
|
|
|
182
|
+
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
183
|
+
def _aten_native_group_norm_checker(node):
|
|
184
|
+
val = node.meta.get("val")
|
|
185
|
+
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
|
|
186
|
+
return NHWCable(can_be=False, must_be=False)
|
|
187
|
+
if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
|
|
188
|
+
# Disable NHWC rewriter due to precision issue with weight and bias.
|
|
189
|
+
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
|
|
190
|
+
return NHWCable(can_be=False, must_be=False)
|
|
191
|
+
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
192
|
+
|
|
193
|
+
|
|
179
194
|
# ==== Ops must be NCHW
|
|
180
195
|
|
|
181
196
|
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import operator
|
|
15
16
|
import os
|
|
16
17
|
from typing import Optional, Tuple, Union
|
|
17
18
|
|
|
@@ -274,6 +275,14 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
274
275
|
graph_module = layout_partitioners.greedy.partition(graph_module)
|
|
275
276
|
|
|
276
277
|
graph = graph_module.graph
|
|
278
|
+
for node in list(graph.nodes):
|
|
279
|
+
if node.target == operator.getitem:
|
|
280
|
+
# force the layout mark of a getitem node to follow its producer.
|
|
281
|
+
if layout_mark.is_nchw_node(node.args[0]):
|
|
282
|
+
layout_mark.mark_as_nchw_node(node)
|
|
283
|
+
else:
|
|
284
|
+
layout_mark.mark_as_nhwc_node(node)
|
|
285
|
+
|
|
277
286
|
for node in list(graph.nodes):
|
|
278
287
|
if layout_mark.is_nhwc_node(node):
|
|
279
288
|
for input_node in layout_check.get_layout_sensitive_inputs(node):
|
ai_edge_torch/debug/culprit.py
CHANGED
|
@@ -354,7 +354,7 @@ def _search_model(
|
|
|
354
354
|
max_granularity: Optional[int] = None,
|
|
355
355
|
enable_fx_minifier_logging: bool = False,
|
|
356
356
|
) -> Generator[SearchResult, None, None]:
|
|
357
|
-
"""Finds subgraphs in the torch model that
|
|
357
|
+
"""Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
|
|
358
358
|
|
|
359
359
|
Args:
|
|
360
360
|
predicate_f: a predicate function the users specify.
|
|
@@ -382,26 +382,33 @@ def _search_model(
|
|
|
382
382
|
fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
|
|
383
383
|
fx_gm = _normalize_getitem_nodes(fx_gm)
|
|
384
384
|
|
|
385
|
-
# HACK: temporarily disable XLA_HLO_DEBUG so that
|
|
386
|
-
# intermediate stablehlo files to storage.
|
|
385
|
+
# HACK: temporarily disable XLA_HLO_DEBUG and create_minified_hlo_graph so that
|
|
386
|
+
# fx_minifier won't dump intermediate stablehlo files to storage.
|
|
387
387
|
# https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
|
|
388
388
|
@contextlib.contextmanager
|
|
389
|
-
def
|
|
389
|
+
def disable_minifier_xla_debug():
|
|
390
390
|
xla_hlo_debug_value = None
|
|
391
391
|
if "XLA_HLO_DEBUG" in os.environ:
|
|
392
392
|
xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
|
|
393
393
|
del os.environ["XLA_HLO_DEBUG"]
|
|
394
394
|
|
|
395
|
+
create_minified_hlo_graph = torch._functorch.fx_minifier.create_minified_hlo_graph
|
|
396
|
+
torch._functorch.fx_minifier.create_minified_hlo_graph = (
|
|
397
|
+
lambda *args, **kwargs: None
|
|
398
|
+
)
|
|
399
|
+
|
|
395
400
|
try:
|
|
396
|
-
yield
|
|
401
|
+
yield
|
|
397
402
|
finally:
|
|
398
403
|
if xla_hlo_debug_value is not None:
|
|
399
404
|
os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value
|
|
400
405
|
|
|
406
|
+
torch._functorch.fx_minifier.create_minified_hlo_graph = create_minified_hlo_graph
|
|
407
|
+
|
|
401
408
|
found_culprits_num = 0
|
|
402
409
|
while True:
|
|
403
410
|
try:
|
|
404
|
-
with
|
|
411
|
+
with disable_minifier_xla_debug(), open(os.devnull, "w") as devnull:
|
|
405
412
|
with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
|
|
406
413
|
stdout=devnull,
|
|
407
414
|
stderr=devnull,
|
|
@@ -125,10 +125,9 @@ class AttentionBlock2D(nn.Module):
|
|
|
125
125
|
"""
|
|
126
126
|
residual = input_tensor
|
|
127
127
|
B, C, H, W = input_tensor.shape
|
|
128
|
-
x = input_tensor
|
|
129
128
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
130
|
-
x = self.norm(
|
|
131
|
-
x =
|
|
129
|
+
x = self.norm(input_tensor)
|
|
130
|
+
x = x.view(B, C, H * W)
|
|
132
131
|
x = x.transpose(-1, -2)
|
|
133
132
|
else:
|
|
134
133
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -181,10 +180,9 @@ class CrossAttentionBlock2D(nn.Module):
|
|
|
181
180
|
"""
|
|
182
181
|
residual = input_tensor
|
|
183
182
|
B, C, H, W = input_tensor.shape
|
|
184
|
-
x = input_tensor
|
|
185
183
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
186
|
-
x = self.norm(
|
|
187
|
-
x =
|
|
184
|
+
x = self.norm(input_tensor)
|
|
185
|
+
x = x.view(B, C, H * W)
|
|
188
186
|
x = x.transpose(-1, -2)
|
|
189
187
|
else:
|
|
190
188
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -222,10 +220,9 @@ class FeedForwardBlock2D(nn.Module):
|
|
|
222
220
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
223
221
|
residual = input_tensor
|
|
224
222
|
B, C, H, W = input_tensor.shape
|
|
225
|
-
x = input_tensor
|
|
226
223
|
if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
|
|
227
|
-
x = self.norm(
|
|
228
|
-
x =
|
|
224
|
+
x = self.norm(input_tensor)
|
|
225
|
+
x = x.view(B, C, H * W)
|
|
229
226
|
x = x.transpose(-1, -2)
|
|
230
227
|
else:
|
|
231
228
|
x = input_tensor.view(B, C, H * W)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.dev20240721
|
|
3
|
+
Version: 0.2.0.dev20240726
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -27,7 +27,10 @@ Requires-Dist: numpy
|
|
|
27
27
|
Requires-Dist: scipy
|
|
28
28
|
Requires-Dist: safetensors
|
|
29
29
|
Requires-Dist: tabulate
|
|
30
|
-
Requires-Dist: torch
|
|
30
|
+
Requires-Dist: torch >=2.4.0
|
|
31
|
+
Requires-Dist: torch-xla >=2.4.0
|
|
32
|
+
Requires-Dist: tf-nightly >=2.18.0.dev20240722
|
|
33
|
+
Requires-Dist: ai-edge-quantizer-nightly ==0.0.1.dev20240718
|
|
31
34
|
|
|
32
35
|
Library that supports converting PyTorch models into a .tflite format, which can
|
|
33
36
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
|
@@ -12,11 +12,11 @@ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vc
|
|
|
12
12
|
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
|
|
13
13
|
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
14
14
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
|
|
15
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
|
15
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=kV4Md7YBQIS-A8Dp4O-8SugNHAfDjVIBHvnldVPpHV0,7483
|
|
16
16
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
|
|
17
17
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
|
|
18
18
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
|
|
19
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
|
19
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=7SuNFBfQWt42tEcUUxpQVnWhy-ByMuLHi2VtH8Kc1g4,10708
|
|
20
20
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
|
|
21
21
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
@@ -27,7 +27,7 @@ ai_edge_torch/convert/test/test_convert_composites.py,sha256=8UkdPtGkjgSVLCzB_rp
|
|
|
27
27
|
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
|
|
28
28
|
ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
|
|
29
29
|
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
|
30
|
-
ai_edge_torch/debug/culprit.py,sha256=
|
|
30
|
+
ai_edge_torch/debug/culprit.py,sha256=lN2N_J3EJTPcEH42xK6kvs4yOqZtyvvecSc4aidlog4,14556
|
|
31
31
|
ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
|
|
32
32
|
ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
33
33
|
ai_edge_torch/debug/test/test_culprit.py,sha256=9An_n9p_RWTAYdHYTCO-__EJlbnjclCDo8tDhOzMlwk,3731
|
|
@@ -90,7 +90,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0l
|
|
|
90
90
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
|
|
91
91
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
|
|
92
92
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
|
93
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4et0TLBtAyVlYxMSJi3-oQoO5npFkOzcCYA927dvm_8,26475
|
|
94
94
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
|
|
95
95
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
|
|
96
96
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
|
|
|
125
125
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
126
126
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
127
127
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
|
|
128
|
-
ai_edge_torch_nightly-0.2.0.
|
|
129
|
-
ai_edge_torch_nightly-0.2.0.
|
|
130
|
-
ai_edge_torch_nightly-0.2.0.
|
|
131
|
-
ai_edge_torch_nightly-0.2.0.
|
|
132
|
-
ai_edge_torch_nightly-0.2.0.
|
|
128
|
+
ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
129
|
+
ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/METADATA,sha256=QstgV90JdvfDL5-OAb_Z4dzWItNf6qT-gNfDtzdoTok,1889
|
|
130
|
+
ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
131
|
+
ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
132
|
+
ai_edge_torch_nightly-0.2.0.dev20240726.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|