ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.


Files changed (103)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
  88. ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
  97. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  98. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  99. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  101. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0

ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py

@@ -12,19 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Build interpolate composite pass."""

 import functools

-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
 from ai_edge_torch.hlfb import mark_pattern
+from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch

 # For torch nightly released after mid June 2024,
 # torch.nn.functional.interpolate no longer gets exported into decomposed graph
-# but single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
-# This behavior would our pattern matching based composite builder.
-# It requires the pattern and model graph to get decomposed first for backward compatibility.
+# but a single aten op:
+# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
+# This would interefere with our pattern matching based composite builder.
+# Here we register the now missing decompositions first.
 _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
     torch.ops.aten.upsample_bilinear2d.vec,
     torch.ops.aten.upsample_nearest2d.vec,
@@ -33,7 +36,7 @@ _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([

 @functools.cache
 def _get_upsample_bilinear2d_pattern():
-  pattern = mark_pattern.Pattern(
+  pattern = pattern_module.Pattern(
       "odml.upsample_bilinear2d",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="bilinear", align_corners=False
@@ -56,7 +59,7 @@ def _get_upsample_bilinear2d_pattern():

 @functools.cache
 def _get_upsample_bilinear2d_align_corners_pattern():
-  pattern = mark_pattern.Pattern(
+  pattern = pattern_module.Pattern(
       "odml.upsample_bilinear2d",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="bilinear", align_corners=True
@@ -79,7 +82,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():

 @functools.cache
 def _get_interpolate_nearest2d_pattern():
-  pattern = mark_pattern.Pattern(
+  pattern = pattern_module.Pattern(
       "tfl.resize_nearest_neighbor",
       lambda x: torch.nn.functional.interpolate(
           x, scale_factor=2, mode="nearest"
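
Context for the hunk above: on newer torch nightlies, export keeps interpolate as a single aten op, so the pass has to apply the decomposition table itself. A minimal standalone sketch of that mechanism (the Upsample module here is illustrative, not code from this package):

import torch

class Upsample(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")

ep = torch.export.export(Upsample(), (torch.rand(1, 3, 8, 8),))
# Without the table, the graph keeps torch.ops.aten.upsample_nearest2d.vec as
# a single node; applying it restores the decomposed form that pattern
# matching expects.
decomp_table = torch._decomp.get_decompositions([
    torch.ops.aten.upsample_nearest2d.vec,
])
ep = ep.run_decompositions(decomp_table)
print(ep.graph_module.code)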

ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py

@@ -13,8 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch._convert.fx_passes import _pass_base
 import torch
 from torch.export import ExportedProgram

@@ -29,8 +28,8 @@ _dummy_decomp_table = {
 }


-class CanonicalizePass(ExportedProgramPassBase):
+class CanonicalizePass(_pass_base.ExportedProgramPassBase):

   def call(self, exported_program: ExportedProgram):
     exported_program = exported_program.run_decompositions(_dummy_decomp_table)
-    return ExportedProgramPassResult(exported_program, True)
+    return _pass_base.ExportedProgramPassResult(exported_program, True)
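
For reference, the mechanism CanonicalizePass relies on: ExportedProgram.run_decompositions re-traces the program, which normalizes the graph even when the supplied table decomposes nothing of interest. A standalone sketch (the AddOne module is illustrative, not from the package):

import torch

class AddOne(torch.nn.Module):
  def forward(self, x):
    return x + 1

ep = torch.export.export(AddOne(), (torch.rand(2),))
# A trivial/dummy table still triggers the canonicalizing re-trace.
ep = ep.run_decompositions({})
print(ep.graph)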

ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py

@@ -13,11 +13,11 @@
 # limitations under the License.
 # ==============================================================================

+from ai_edge_torch import lowertools
 import torch
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
-import torch_xla.experimental.xla_mlir_debuginfo  # Import required to register torch.ops.xla.write_mlir_debuginfo


 def _get_mlir_debuginfo(node: torch.fx.Node):
@@ -54,7 +54,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
     outputs = target(*args, **kwargs)
     outputs = pytree.tree_map_only(
         torch.Tensor,
-        lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
+        lambda x: lowertools.write_mlir_debuginfo_op(x, debuginfo),
         outputs,
     )
     return outputs
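
The pass tags only the tensor leaves of a node's (possibly nested) outputs; pytree.tree_map_only is what makes that safe for mixed structures. A quick illustration (the values and the doubling stand-in are made up):

import torch
import torch.utils._pytree as pytree

outputs = {"logits": torch.rand(2, 4), "aux": [torch.rand(2), "metadata"]}
tagged = pytree.tree_map_only(
    torch.Tensor,
    lambda x: x * 2,  # stand-in for the debuginfo-writing op
    outputs,
)
# "metadata" passes through untouched; both tensors are transformed.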

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass  # NOQA

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout check for the optimized layout transposes pass."""
+
 import dataclasses
 import operator

-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry  # NOQA
+from ai_edge_torch import lowertools
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
 import torch
 from torch.fx import Node

@@ -205,7 +207,7 @@ def _aten_native_group_norm_checker(node):
 # ==== Ops must be NCHW


-@nhwcable_node_checkers.register(torch.ops.xla.mark_tensor)
+@nhwcable_node_checkers.register(lowertools.mark_tensor_op)
 @nhwcable_node_checkers.register(utils.tensor_to_nchw)
 @nhwcable_node_checkers.register(utils.tensor_to_nhwc)
 @nhwcable_node_checkers.register("output")

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout mark for the optimized layout transposes pass."""
+
 import torch

 # Tag which is added to a node's meta to indicate that is is part of the NHWC

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout partitioners."""

 from . import greedy
 from . import min_cut

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py

@@ -12,23 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Greedy partitioning algorithm."""

-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 import torch


 def partition(graph_module: torch.fx.GraphModule):
-  """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
-  nodes in the NHWC partitions.
+  """Partition the graph module into NHWC and non-NHWC subgraphs.
+
+  Partition the graph module into NHWC and non-NHWC subgraphs and mark nodes in
+  the NHWC partitions.

   Implements O(|V|) greedy partitioning algorithm.
-  See go/pytorch-layout-transpose-optimization for more details.
+
+  Args:
+    graph_module: The graph module to be partitioned.
+
+  Returns:
+    The partitioned graph module.
   """
   graph = graph_module.graph

   for node in list(graph.nodes):
-    if len(node.all_input_nodes) == 0:
+    if not node.all_input_nodes:
       # This node has no inputs so we don't need to change anything
       continue

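The docstring above is terse, so here is a toy rendering of what an O(|V|) greedy layout partition looks like. This is an assumption-level sketch, not the shipped algorithm; the node representation and the majority rule are invented for illustration:

def greedy_partition(nodes):
  """nodes: list of dicts with "can_be_nhwc" and "inputs" (indices of earlier nodes)."""
  nhwc = set()
  for i, node in enumerate(nodes):
    if not node["can_be_nhwc"]:
      continue  # node must stay NCHW
    inputs = node["inputs"]
    # Single forward sweep: join the NHWC partition when there are no inputs
    # or at least half of the inputs are already NHWC.
    if not inputs or 2 * sum(j in nhwc for j in inputs) >= len(inputs):
      nhwc.add(i)
  return nhwc

print(greedy_partition([
    {"can_be_nhwc": True, "inputs": []},
    {"can_be_nhwc": True, "inputs": [0]},
    {"can_be_nhwc": False, "inputs": [1]},
]))  # -> {0, 1}
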
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Min cut solver for partitioning the graph module into NHWC and non-NHWC subgraphs."""

 import collections
 import dataclasses
-import itertools

-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 import numpy as np
 import scipy
 import torch
@@ -26,13 +26,12 @@ import torch

 def can_partition(graph_module: torch.fx.GraphModule):
   """Returns true if the input graph_module can be partitioned by min cut solver
+
   in a reasonable time.

   The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
   take a long time to complete for large graph module. This function determines
   whether the graph module can be partitioned by the graph module size.
-
-  See go/pytorch-layout-transpose-optimization for more details.
   """
   graph = graph_module.graph
   n_nodes = len(graph.nodes)
@@ -137,10 +136,10 @@ class MultiUsersDummyNode:

 def partition(graph_module: torch.fx.GraphModule):
   """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
+
   nodes in the NHWC partitions.

   Implements O(|V|^2|E|) min-cut (optimal) partitioning algorithm.
-  See go/pytorch-layout-transpose-optimization for more details.
   """
   graph = graph_module.graph

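Since the solver builds on scipy (imported above), here is a self-contained reminder of the max-flow/min-cut primitive involved. The toy capacities are made up, and scipy.sparse.csgraph.maximum_flow is the general-purpose solver, not this package's own wrapper:

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_flow

# Toy network: source 0, sink 3. By max-flow/min-cut duality, the returned
# flow value equals the capacity of the minimum s-t cut separating the two
# partitions. The solver expects int32 capacities.
capacities = csr_matrix(np.array([
    [0, 3, 2, 0],
    [0, 0, 1, 2],
    [0, 0, 0, 3],
    [0, 0, 0, 0],
], dtype=np.int32))
result = maximum_flow(capacities, 0, 3)
print(result.flow_value)  # 5
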
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py

@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Layout rewrite for the optimized layout transposes pass."""
+
 import operator

-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
 import torch
-from torch.fx import Node
 import torch.utils._pytree as pytree

 aten = torch.ops.aten
@@ -26,7 +27,7 @@ aten = torch.ops.aten
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]


-class NHWCNodeRewritersRegistry(OpFuncRegistry):
+class NHWCNodeRewritersRegistry(op_func_registry.OpFuncRegistry):

   def __missing__(self, op):
     def _rewriter(node):
@@ -38,14 +39,14 @@ class NHWCNodeRewritersRegistry(OpFuncRegistry):
 rewriters = NHWCNodeRewritersRegistry()


-def rewrite_nhwc_node(node: Node):
+def rewrite_nhwc_node(node: torch.fx.Node):
   if not layout_mark.is_nhwc_node(node):
     return

   rewriters[node.target](node)


-def has_nhwc_rewriter(node: Node):
+def has_nhwc_rewriter(node: torch.fx.Node):
   return node.target in rewriters


@@ -54,13 +55,13 @@ def has_nhwc_rewriter(node: Node):

 @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_tensor)
 @rewriters.register(torch.ops.quantized_decomposed.quantize_per_tensor)
-def noop(node: Node):
+def noop(node: torch.fx.Node):
   pass


 @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_channel)
 @rewriters.register(torch.ops.quantized_decomposed.quantize_per_channel)
-def _qdq_per_channel_rewriter(node: Node):
+def _qdq_per_channel_rewriter(node: torch.fx.Node):
   new_args = []
   new_kwargs = {}

@@ -199,7 +200,7 @@ def _qdq_per_channel_rewriter(node: Node):
 @rewriters.register(aten._prelu_kernel)
 @rewriters.register(aten.softplus)
 @rewriters.register(aten.silu)
-def noop(node: Node):
+def noop(node: torch.fx.Node):
   pass


@@ -212,14 +213,16 @@ def noop(node: Node):
 @rewriters.register(aten.max_pool2d_with_indices)
 @rewriters.register(aten.avg_pool2d)
 @rewriters.register(aten._adaptive_avg_pool2d.default)
-def transpose_first_arg_rewriter(node: Node):
+def transpose_first_arg_rewriter(node: torch.fx.Node):
   op = node.target

   def nhwc_op(x, *args, **kwargs):
     nonlocal op
     x = utils.tensor_to_nchw(x)
     res = pytree.tree_map_only(
-        torch.Tensor, utils.tensor_to_nhwc, op(x, *args, **kwargs)
+        torch.Tensor,
+        utils.tensor_to_nhwc,
+        op(x, *args, **kwargs),
     )
     return res

@@ -227,7 +230,7 @@ def transpose_first_arg_rewriter(node: Node):


 @rewriters.register(aten.convolution)
-def _aten_convolution_rewriter(node: Node):
+def _aten_convolution_rewriter(node: torch.fx.Node):
   op = node.target

   def conv_nhwc(input, weight, bias, *args, **kwargs):
@@ -286,7 +289,7 @@ def _aten_convolution_rewriter(node: Node):
 @rewriters.register(aten.sort.default)
 @rewriters.register(aten.topk.default)
 @rewriters.register(aten.cat.default)
-def dim_attr_rewriter(node: Node):
+def dim_attr_rewriter(node: torch.fx.Node):
   op = node.target

   new_args = []

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
-import torch
+"""Op function registry for the optimized layout transposes pass."""
+
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils


 class OpFuncRegistry(dict):
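
OpFuncRegistry itself is not shown beyond its class line, so here is a minimal sketch of the dict-plus-register-decorator pattern that the rewriters and checkers above build on. This is a simplified assumption; the real class also resolves aten op overloads:

class OpFuncRegistry(dict):

  def register(self, op):
    def decorator(func):
      self[op] = func  # key the function by the op it handles
      return func

    return decorator


rewriters = OpFuncRegistry()


@rewriters.register("aten.relu")
def _relu_rewriter(node):
  return f"rewriting {node}"


print(rewriters["aten.relu"]("node_1"))  # -> rewriting node_1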

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py

@@ -12,23 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Optimize layout transposes pass."""
+
 import operator
 import os
-from typing import Optional, Tuple, Union
-
-from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
-from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+from typing import Union
+
+from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
 import torch
 import torch.ao.quantization.quantize_pt2e
-from torch.export import ExportedProgram
-from torch.fx import GraphModule
-from torch.fx import Node
-import torch.utils._pytree as pytree

 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]

@@ -51,8 +49,8 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
       transpose_func: TransposeFunc,
       transpose_node_meta: dict,
   ) -> list[torch.fx.Node]:
-    """
-    original:
+    """original:
+
     input_dq -> target
     insert the node as:
     input_dq -> (T q dq) -> target
@@ -86,8 +84,8 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
       transpose_func: TransposeFunc,
       transpose_node_meta: dict,
   ) -> list[torch.fx.Node]:
-    """
-    original:
+    """original:
+
     input_q -> target
     insert the node as:
     input_q -> (dq T q) -> target
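
The two docstrings above describe inserting a transpose wrapped in quantize/dequantize pairs between a producer and its consumers. A minimal torch.fx sketch of the underlying insertion primitive (illustrative, not the pass's code; a plain permute stands in for the transpose-plus-q/dq chain):

import torch
import torch.fx as fx


class M(torch.nn.Module):
  def forward(self, x):
    return torch.relu(x) + 1


gm = fx.symbolic_trace(M())
relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
with gm.graph.inserting_after(relu):
  # Stand-in for the transpose (+ q/dq) chain the pass would insert.
  perm = gm.graph.call_function(torch.permute, (relu, (0, 2, 3, 1)))
# Reroute relu's users through the new node, except the new node itself.
relu.replace_all_uses_with(perm, delete_user_cb=lambda user: user is not perm)
gm.recompile()
print(gm.code)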

ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Utils for the optimized layout transposes pass."""
+
 from typing import Callable

 import torch

ai_edge_torch/_convert/signature.py (new file)

@@ -0,0 +1,100 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils._pytree as pytree
+
+
+@dataclasses.dataclass
+class Signature:
+  name: str
+  module: torch.nn.Module
+  sample_args: tuple[torch.Tensor]
+  sample_kwargs: dict[str, torch.Tensor]
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+
+  @property
+  def _normalized_sample_args_kwargs(self):
+    args, kwargs = self.sample_args, self.sample_kwargs
+    if args is not None:
+      if not isinstance(args, tuple):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_args must be a tuple of torch tensors.")
+    if kwargs is not None:
+      if not isinstance(kwargs, dict) or not all(
+          isinstance(key, str) for key in kwargs.keys()
+      ):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_kwargs must be a dict of string to tensor.")
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else {}
+    return args, kwargs
+
+  @property
+  def flat_arg_names(self) -> list[str]:
+    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+    args_spec, kwargs_spec = spec.children_specs
+    names = []
+    for i in range(args_spec.num_leaves):
+      names.append(f"args_{i}")
+
+    kwargs_names = self._flat_kwarg_names(
+        kwargs_spec.children_specs, kwargs_spec.context
+    )
+    names.extend(kwargs_names)
+    return names
+
+  def _flat_kwarg_names(self, specs, context) -> List[str]:
+    flat_names = []
+    if context is None:
+      for i, spec in enumerate(specs):
+        if spec.children_specs:
+          flat_names.extend([
+              f"{i}_{name}"
+              for name in self._flat_kwarg_names(
+                  spec.children_specs, spec.context
+              )
+          ])
+        else:
+          flat_names.append(f"{i}")
+    else:
+      flat_ctx = self._flatten_list(context)
+      for prefix, spec in zip(flat_ctx, specs):
+        leaf_flat_names = self._flat_kwarg_names(
+            spec.children_specs, spec.context
+        )
+        if leaf_flat_names:
+          flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
+        else:
+          flat_names.append(prefix)
+
+    return flat_names
+
+  def _flatten_list(self, l: List) -> List:
+    flattened = []
+    for item in l:
+      if isinstance(item, list):
+        flattened.extend(self._flatten_list(item))
+      else:
+        flattened.append(item)
+    return flattened
+
+  @property
+  def flat_args(self) -> tuple[Any]:
+    args, kwargs = self._normalized_sample_args_kwargs
+    return tuple([*args, *kwargs.values()])
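
A short usage note for the new Signature helper, derived from the code above: positional inputs get "args_<i>" names and keyword inputs keep their keys. The import path assumes the file location listed in this diff:

import torch
from ai_edge_torch._convert.signature import Signature

sig = Signature(
    name="forward",
    module=torch.nn.Identity(),
    sample_args=(torch.rand(1, 4),),
    sample_kwargs={"mask": torch.ones(1, 4)},
)
print(sig.flat_arg_names)  # -> ['args_0', 'mask']
print(len(sig.flat_args))  # -> 2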