ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
- ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
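The dominant change in this release is mechanical: the private `convert` package becomes `_convert`, and a new `lowertools` layer abstracts over the torch_xla backend. As a rough, hypothetical sketch of what this means for code that reached into the old private paths, assuming only the path segment changed:

    # 0.2.0.dev20240806 (old private path):
    #   from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
    # 0.2.0.dev20240807 (new private path):
    from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass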
ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py
RENAMED

@@ -12,19 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Build interpolate composite pass."""
 
 import functools
 
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
 from ai_edge_torch.hlfb import mark_pattern
+from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
 # For torch nightly released after mid June 2024,
 # torch.nn.functional.interpolate no longer gets exported into decomposed graph
-# but single aten op
-#
-#
+# but a single aten op:
+# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
+# This would interefere with our pattern matching based composite builder.
+# Here we register the now missing decompositions first.
 _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
     torch.ops.aten.upsample_bilinear2d.vec,
     torch.ops.aten.upsample_nearest2d.vec,

@@ -33,7 +36,7 @@ _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
 
 @functools.cache
 def _get_upsample_bilinear2d_pattern():
-  pattern =
+  pattern = pattern_module.Pattern(
       "odml.upsample_bilinear2d",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="bilinear", align_corners=False

@@ -56,7 +59,7 @@ def _get_upsample_bilinear2d_pattern():
 
 @functools.cache
 def _get_upsample_bilinear2d_align_corners_pattern():
-  pattern =
+  pattern = pattern_module.Pattern(
       "odml.upsample_bilinear2d",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="bilinear", align_corners=True

@@ -79,7 +82,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
 
 @functools.cache
 def _get_interpolate_nearest2d_pattern():
-  pattern =
+  pattern = pattern_module.Pattern(
       "tfl.resize_nearest_neighbor",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="nearest"
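The registered decomposition table above is what lets the pattern tracer keep seeing the decomposed form of `interpolate`. A minimal sketch of that mechanism, assuming a toy upsampling module (the module and shapes are illustrative, not from the package):

    import torch

    class Upsample(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        )

    ep = torch.export.export(Upsample(), (torch.randn(1, 3, 8, 8),))
    # On newer torch nightlies the exported graph keeps a single
    # aten.upsample_bilinear2d.vec node; re-running decompositions with the
    # registered table recovers the decomposed graph for pattern matching.
    decomps = torch._decomp.get_decompositions(
        [torch.ops.aten.upsample_bilinear2d.vec]
    )
    ep = ep.run_decompositions(decomps)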
ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py
RENAMED

@@ -13,8 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch._convert.fx_passes import _pass_base
 import torch
 from torch.export import ExportedProgram
 

@@ -29,8 +28,8 @@ _dummy_decomp_table = {
 }
 
 
-class CanonicalizePass(ExportedProgramPassBase):
+class CanonicalizePass(_pass_base.ExportedProgramPassBase):
 
   def call(self, exported_program: ExportedProgram):
     exported_program = exported_program.run_decompositions(_dummy_decomp_table)
-    return ExportedProgramPassResult(exported_program, True)
+    return _pass_base.ExportedProgramPassResult(exported_program, True)
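As a usage sketch for the pass above (hypothetical toy module; this assumes `ExportedProgramPassResult`, like torch.fx's `PassResult`, carries the transformed program on a field):

    import torch
    from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass

    class AddOne(torch.nn.Module):
      def forward(self, x):
        return x + 1

    ep = torch.export.export(AddOne(), (torch.randn(4),))
    # The pass simply re-runs decompositions with a dummy table, which
    # canonicalizes the exported program's graph.
    result = CanonicalizePass().call(ep)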
ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py
RENAMED

@@ -13,11 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import lowertools
 import torch
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
-import torch_xla.experimental.xla_mlir_debuginfo  # Import required to register torch.ops.xla.write_mlir_debuginfo
 
 
 def _get_mlir_debuginfo(node: torch.fx.Node):

@@ -54,7 +54,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
     outputs = target(*args, **kwargs)
     outputs = pytree.tree_map_only(
         torch.Tensor,
-        lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
+        lambda x: lowertools.write_mlir_debuginfo_op(x, debuginfo),
         outputs,
     )
     return outputs
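The wrapper above hinges on `pytree.tree_map_only`, which applies a function to leaves of one type and leaves everything else untouched; a small self-contained illustration with made-up values:

    import torch
    import torch.utils._pytree as pytree

    outputs = {"logits": torch.ones(2), "aux": [torch.zeros(3), "not-a-tensor"]}
    # Only tensor leaves are mapped; the string leaf passes through unchanged.
    tagged = pytree.tree_map_only(torch.Tensor, lambda x: x * 2, outputs)
    assert tagged["aux"][1] == "not-a-tensor"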
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py
ADDED

@@ -0,0 +1,16 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass  # NOQA
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py
RENAMED
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout check for the optimized layout transposes pass."""
+
 import dataclasses
 import operator
 
-from ai_edge_torch
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
+from ai_edge_torch import lowertools
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
 import torch
 from torch.fx import Node
 

@@ -205,7 +207,7 @@ def _aten_native_group_norm_checker(node):
 # ==== Ops must be NCHW
 
 
-@nhwcable_node_checkers.register(
+@nhwcable_node_checkers.register(lowertools.mark_tensor_op)
 @nhwcable_node_checkers.register(utils.tensor_to_nchw)
 @nhwcable_node_checkers.register(utils.tensor_to_nhwc)
 @nhwcable_node_checkers.register("output")
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py
RENAMED

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout mark for the optimized layout transposes pass."""
+
 import torch
 
 # Tag which is added to a node's meta to indicate that is is part of the NHWC
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py
RENAMED

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout partitioners."""
 
 from . import greedy
 from . import min_cut
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py
RENAMED

@@ -12,23 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Greedy partitioning algorithm."""
 
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 import torch
 
 
 def partition(graph_module: torch.fx.GraphModule):
-  """Partition the graph module into NHWC and non-NHWC subgraphs
-
+  """Partition the graph module into NHWC and non-NHWC subgraphs.
+
+  Partition the graph module into NHWC and non-NHWC subgraphs and mark nodes in
+  the NHWC partitions.
 
   Implements O(|V|) greedy partitioning algorithm.
-
+
+  Args:
+    graph_module: The graph module to be partitioned.
+
+  Returns:
+    The partitioned graph module.
   """
   graph = graph_module.graph
 
   for node in list(graph.nodes):
-    if
+    if not node.all_input_nodes:
       # This node has no inputs so we don't need to change anything
       continue
 
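A sketch of driving the greedy partitioner directly (illustrative model; in the real pass, pass_body.py runs layout checks before partitioning, so treat this as a rough outline rather than a working recipe):

    import torch
    from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
    from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.layout_partitioners import greedy

    ep = torch.export.export(
        torch.nn.Conv2d(3, 8, 3), (torch.randn(1, 3, 16, 16),)
    )
    gm = ep.graph_module
    greedy.partition(gm)  # marks NHWC-partition nodes in node.meta
    nhwc_nodes = [n for n in gm.graph.nodes if layout_mark.is_nhwc_node(n)]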
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
RENAMED

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Min cut solver for partitioning the graph module into NHWC and non-NHWC subgraphs."""
 
 import collections
 import dataclasses
-import itertools
 
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 import numpy as np
 import scipy
 import torch

@@ -26,13 +26,12 @@ import torch
 
 def can_partition(graph_module: torch.fx.GraphModule):
   """Returns true if the input graph_module can be partitioned by min cut solver
+
   in a reasonable time.
 
   The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
   take a long time to complete for large graph module. This function determines
   whether the graph module can be partitioned by the graph module size.
-
-  See go/pytorch-layout-transpose-optimization for more details.
   """
   graph = graph_module.graph
   n_nodes = len(graph.nodes)

@@ -137,10 +136,10 @@ class MultiUsersDummyNode:
 
 def partition(graph_module: torch.fx.GraphModule):
   """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
+
   nodes in the NHWC partitions.
 
   Implements O(|V|^2|E|) min-cut (optimal) partitioning algorithm.
-  See go/pytorch-layout-transpose-optimization for more details.
   """
   graph = graph_module.graph
 
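The `can_partition` guard suggests how a caller chooses between the two partitioners; a hedged sketch of that dispatch (the actual selection logic lives in pass_body.py and may differ):

    from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners

    def pick_partitioner(graph_module):
      # Dinic's algorithm is O(|V|^2 |E|); fall back to the O(|V|) greedy
      # partitioner when the graph is too large for the min-cut solver.
      if layout_partitioners.min_cut.can_partition(graph_module):
        return layout_partitioners.min_cut.partition
      return layout_partitioners.greedy.partition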
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
RENAMED
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout rewrite for the optimized layout transposes pass."""
+
 import operator
 
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
 import torch
-from torch.fx import Node
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten

@@ -26,7 +27,7 @@ aten = torch.ops.aten
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
 
-class NHWCNodeRewritersRegistry(OpFuncRegistry):
+class NHWCNodeRewritersRegistry(op_func_registry.OpFuncRegistry):
 
   def __missing__(self, op):
     def _rewriter(node):

@@ -38,14 +39,14 @@ class NHWCNodeRewritersRegistry(OpFuncRegistry):
 rewriters = NHWCNodeRewritersRegistry()
 
 
-def rewrite_nhwc_node(node: Node):
+def rewrite_nhwc_node(node: torch.fx.Node):
   if not layout_mark.is_nhwc_node(node):
     return
 
   rewriters[node.target](node)
 
 
-def has_nhwc_rewriter(node: Node):
+def has_nhwc_rewriter(node: torch.fx.Node):
   return node.target in rewriters
 
 

@@ -54,13 +55,13 @@ def has_nhwc_rewriter(node: Node):
 
 @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_tensor)
 @rewriters.register(torch.ops.quantized_decomposed.quantize_per_tensor)
-def noop(node: Node):
+def noop(node: torch.fx.Node):
   pass
 
 
 @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_channel)
 @rewriters.register(torch.ops.quantized_decomposed.quantize_per_channel)
-def _qdq_per_channel_rewriter(node: Node):
+def _qdq_per_channel_rewriter(node: torch.fx.Node):
   new_args = []
   new_kwargs = {}
 

@@ -199,7 +200,7 @@ def _qdq_per_channel_rewriter(node: Node):
 @rewriters.register(aten._prelu_kernel)
 @rewriters.register(aten.softplus)
 @rewriters.register(aten.silu)
-def noop(node: Node):
+def noop(node: torch.fx.Node):
   pass
 
 

@@ -212,14 +213,16 @@ def noop(node: Node):
 @rewriters.register(aten.max_pool2d_with_indices)
 @rewriters.register(aten.avg_pool2d)
 @rewriters.register(aten._adaptive_avg_pool2d.default)
-def transpose_first_arg_rewriter(node: Node):
+def transpose_first_arg_rewriter(node: torch.fx.Node):
   op = node.target
 
   def nhwc_op(x, *args, **kwargs):
     nonlocal op
     x = utils.tensor_to_nchw(x)
     res = pytree.tree_map_only(
-        torch.Tensor,
+        torch.Tensor,
+        utils.tensor_to_nhwc,
+        op(x, *args, **kwargs),
     )
     return res
 

@@ -227,7 +230,7 @@ def transpose_first_arg_rewriter(node: Node):
 
 
 @rewriters.register(aten.convolution)
-def _aten_convolution_rewriter(node: Node):
+def _aten_convolution_rewriter(node: torch.fx.Node):
   op = node.target
 
   def conv_nhwc(input, weight, bias, *args, **kwargs):

@@ -286,7 +289,7 @@ def _aten_convolution_rewriter(node: Node):
 @rewriters.register(aten.sort.default)
 @rewriters.register(aten.topk.default)
 @rewriters.register(aten.cat.default)
-def dim_attr_rewriter(node: Node):
+def dim_attr_rewriter(node: torch.fx.Node):
   op = node.target
 
   new_args = []
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py
RENAMED
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+"""Op function registry for the optimized layout transposes pass."""
+
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
 
 
 class OpFuncRegistry(dict):
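For reference, the registry pattern behind `nhwcable_node_checkers` and `rewriters` is a dict keyed by op with a decorator-style `register`; a simplified stand-alone sketch (the real `OpFuncRegistry` also imports `utils`, presumably to normalize torch op overload keys):

    class MiniOpFuncRegistry(dict):
      """Simplified model of OpFuncRegistry: maps ops to handler functions."""

      def register(self, op):
        def decorator(func):
          self[op] = func
          return func

        return decorator

    registry = MiniOpFuncRegistry()

    @registry.register("output")
    def check_output(node):
      return f"checked {node}"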
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py
RENAMED

@@ -12,23 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Optimize layout transposes pass."""
+
 import operator
 import os
-from typing import
-
-from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+from typing import Union
+
+from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
 import torch
 import torch.ao.quantization.quantize_pt2e
-from torch.export import ExportedProgram
-from torch.fx import GraphModule
-from torch.fx import Node
-import torch.utils._pytree as pytree
 
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 

@@ -51,8 +49,8 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
       transpose_func: TransposeFunc,
       transpose_node_meta: dict,
   ) -> list[torch.fx.Node]:
-    """
-
+    """original:
+
     input_dq -> target
     insert the node as:
     input_dq -> (T q dq) -> target

@@ -86,8 +84,8 @@
       transpose_func: TransposeFunc,
       transpose_node_meta: dict,
   ) -> list[torch.fx.Node]:
-    """
-
+    """original:
+
     input_q -> target
     insert the node as:
     input_q -> (dq T q) -> target
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py
RENAMED

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Utils for the optimized layout transposes pass."""
+
 from typing import Callable
 
 import torch
ai_edge_torch/_convert/signature.py
ADDED

@@ -0,0 +1,100 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils._pytree as pytree
+
+
+@dataclasses.dataclass
+class Signature:
+  name: str
+  module: torch.nn.Module
+  sample_args: tuple[torch.Tensor]
+  sample_kwargs: dict[str, torch.Tensor]
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+
+  @property
+  def _normalized_sample_args_kwargs(self):
+    args, kwargs = self.sample_args, self.sample_kwargs
+    if args is not None:
+      if not isinstance(args, tuple):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_args must be a tuple of torch tensors.")
+    if kwargs is not None:
+      if not isinstance(kwargs, dict) or not all(
+          isinstance(key, str) for key in kwargs.keys()
+      ):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_kwargs must be a dict of string to tensor.")
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else {}
+    return args, kwargs
+
+  @property
+  def flat_arg_names(self) -> list[str]:
+    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+    args_spec, kwargs_spec = spec.children_specs
+    names = []
+    for i in range(args_spec.num_leaves):
+      names.append(f"args_{i}")
+
+    kwargs_names = self._flat_kwarg_names(
+        kwargs_spec.children_specs, kwargs_spec.context
+    )
+    names.extend(kwargs_names)
+    return names
+
+  def _flat_kwarg_names(self, specs, context) -> List[str]:
+    flat_names = []
+    if context is None:
+      for i, spec in enumerate(specs):
+        if spec.children_specs:
+          flat_names.extend([
+              f"{i}_{name}"
+              for name in self._flat_kwarg_names(
+                  spec.children_specs, spec.context
+              )
+          ])
+        else:
+          flat_names.append(f"{i}")
+    else:
+      flat_ctx = self._flatten_list(context)
+      for prefix, spec in zip(flat_ctx, specs):
+        leaf_flat_names = self._flat_kwarg_names(
+            spec.children_specs, spec.context
+        )
+        if leaf_flat_names:
+          flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
+        else:
+          flat_names.append(prefix)
+
+    return flat_names
+
+  def _flatten_list(self, l: List) -> List:
+    flattened = []
+    for item in l:
+      if isinstance(item, list):
+        flattened.extend(self._flatten_list(item))
+      else:
+        flattened.append(item)
+    return flattened
+
+  @property
+  def flat_args(self) -> tuple[Any]:
+    args, kwargs = self._normalized_sample_args_kwargs
+    return tuple([*args, *kwargs.values()])