ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44

@@ -13,27 +13,22 @@
  # limitations under the License.
  # ==============================================================================
 
- import copy
- import functools
  from typing import Any, Callable
-
+ from ai_edge_torch import fx_pass_base
+ from ai_edge_torch import lowertools
  import torch
- from torch.fx import GraphModule
- from torch.fx import Node
- from torch.fx.passes.infra.pass_base import PassBase
- from torch.fx.passes.infra.pass_base import PassResult
  import torch.utils._pytree as pytree
 
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
- _composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}
+ _composite_builders: dict[
+ Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+ ] = {}
 
 
  def _register_composite_builder(op):
  def inner(func):
  if isinstance(op, torch._ops.OpOverloadPacket):
- for overload in v.overloads():
- _composite_builders[getattr(v, overload)] = func
+ for overload in op.overloads():
+ _composite_builders[getattr(op, overload)] = func
  else:
  _composite_builders[op] = func
  return func
@@ -41,7 +36,22 @@ def _register_composite_builder(op):
  return inner
 
 
- def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=True):
+ def _tree_map_to_composite_attr_values(
+ values, *, stringify_incompatible_values=True
+ ):
+ """Convert a tree of values to a tree of composite attribute values.
+
+ This is used for pre-processing op attributes before passing them to
+ the composite op as attributes.
+
+ Args:
+ values: A tree of values.
+ stringify_incompatible_values: If True, stringify values that are not
+ compatible with composite attributes.
+
+ Returns:
+ A tree of composite attribute values.
+ """
 
  def convert(value):
  nonlocal stringify_incompatible_values
@@ -58,6 +68,11 @@ def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=
 
 
  class TorchOpArgumentsMapper:
+ """A helper class to map op arguments to kwargs.
+
+ This is mainly used to extract the default values for op arguments and present
+ all arguments as kwargs.
+ """
 
  def __init__(self, op):
  if isinstance(op, torch._ops.OpOverloadPacket):
@@ -65,16 +80,26 @@ class TorchOpArgumentsMapper:
 
  assert hasattr(op, "_schema")
  self.op = op
- self.arg_specs = [(spec.name, spec.default_value) for spec in op._schema.arguments]
+ self.arg_specs = [
+ (spec.name, spec.default_value) for spec in op._schema.arguments
+ ]
 
  def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
- """Inspect the op's schema and extract all its args and kwargs
- into one single kwargs dict, with default values for those
- unspecified args and kwargs.
+ """Extracts all arguments of the op as kwargs.
+
+ Inspect the op's schema and extract all its args and kwargs into one single
+ kwargs dict, with default values for those unspecified args and kwargs.
+
+ Args:
+ args: The op's arguments.
+ kwargs: The op's kwargs.
+
+ Returns:
+ A kwargs dict with all args and kwargs.
  """
  full_kwargs = {**(kwargs or {})}
 
- for arg, (name, default_value) in zip(args, self.arg_specs):
+ for arg, (name, _) in zip(args, self.arg_specs):
  full_kwargs[name] = arg
 
  for name, default_value in self.arg_specs[len(args) :]:
@@ -85,12 +110,13 @@ class TorchOpArgumentsMapper:
 
 
  @_register_composite_builder(torch.ops.aten.hardswish.default)
- def _aten_hardswish(gm: GraphModule, node: Node):
+ def _aten_hardswish(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.hardswish.default."""
  op = node.target
 
  def hardswish(self: torch.Tensor):
  nonlocal op
- builder = StableHLOCompositeBuilder("aten.hardswish.default")
+ builder = lowertools.StableHLOCompositeBuilder("aten.hardswish.default")
  self = builder.mark_inputs(self)
  output = op(self)
  output = builder.mark_outputs(output)
@@ -100,7 +126,8 @@ def _aten_hardswish(gm: GraphModule, node: Node):
 
 
  @_register_composite_builder(torch.ops.aten.gelu.default)
- def _aten_gelu(gm: GraphModule, node: Node):
+ def _aten_gelu(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.gelu.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)
 
@@ -110,16 +137,17 @@ def _aten_gelu(gm: GraphModule, node: Node):
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
 
  # TFLite supports exact and tanh approximate.
- if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
+ if (
+ full_kwargs["approximate"] != "none"
+ and full_kwargs["approximate"] != "tanh"
+ ):
  return op(*args, **kwargs)
 
- builder = StableHLOCompositeBuilder(
+ builder = lowertools.StableHLOCompositeBuilder(
  "aten.gelu.default",
- attr=_tree_map_to_composite_attr_values(
- {
- "approximate": full_kwargs["approximate"],
- }
- ),
+ attr=_tree_map_to_composite_attr_values({
+ "approximate": full_kwargs["approximate"],
+ }),
  )
  full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
  output = op(full_kwargs["self"])
@@ -130,7 +158,8 @@ def _aten_gelu(gm: GraphModule, node: Node):
 
 
  @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
- def _aten_avg_pool2d(gm: GraphModule, node: Node):
+ def _aten_avg_pool2d(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.avg_pool2d.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)
 
@@ -150,7 +179,10 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
  ):
  dim_output_size = int((dim_input_size + dim_stride - 1) / dim_stride)
  padding_needed = max(
- 0, (dim_output_size - 1) * dim_stride + dim_kernel_size - dim_input_size
+ 0,
+ (dim_output_size - 1) * dim_stride
+ + dim_kernel_size
+ - dim_input_size,
  )
  if padding_needed % 2 != 0:
  return False
@@ -191,18 +223,16 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
  ):
  return op(*args, **kwargs)
 
- builder = StableHLOCompositeBuilder(
+ builder = lowertools.StableHLOCompositeBuilder(
  "aten.avg_pool2d.default",
- attr=_tree_map_to_composite_attr_values(
- {
- "kernel_size": full_kwargs["kernel_size"],
- "stride": full_kwargs["stride"],
- "padding": full_kwargs["padding"],
- "ceil_mode": full_kwargs["ceil_mode"],
- "count_include_pad": full_kwargs["count_include_pad"],
- "divisor_override": full_kwargs["divisor_override"],
- }
- ),
+ attr=_tree_map_to_composite_attr_values({
+ "kernel_size": full_kwargs["kernel_size"],
+ "stride": full_kwargs["stride"],
+ "padding": full_kwargs["padding"],
+ "ceil_mode": full_kwargs["ceil_mode"],
+ "count_include_pad": full_kwargs["count_include_pad"],
+ "divisor_override": full_kwargs["divisor_override"],
+ }),
  )
 
  full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
@@ -213,13 +243,46 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
  node.target = avg_pool2d
 
 
- class BuildAtenCompositePass(PassBase):
+ @_register_composite_builder(torch.ops.aten.embedding.default)
+ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
+ op = node.target
+ args_mapper = TorchOpArgumentsMapper(op)
+
+ def embedding(*args, **kwargs):
+ nonlocal op, args_mapper
+ full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+ _, embedding_dim = full_kwargs["weight"].size()
+ idx = full_kwargs["indices"]
+
+ # Explicitly cast to INT32. This places the CastOp outside of the HLFB.
+ idx = idx.type(torch.int)
+ original_idx_shape = idx.size()
+
+ # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
+ idx = torch.reshape(idx, (idx.numel(),))
+
+ builder = lowertools.StableHLOCompositeBuilder("odml.embedding_lookup")
+ full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
+ idx,
+ full_kwargs["weight"],
+ )
+ output = op(**full_kwargs)
+ output = builder.mark_outputs(output)
+
+ # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
+ output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
+ return output
+
+ node.target = embedding
+
+
+ class BuildAtenCompositePass(fx_pass_base.PassBase):
 
- def call(self, graph_module: GraphModule):
+ def call(self, graph_module: torch.fx.GraphModule):
  for node in graph_module.graph.nodes:
  if node.target in _composite_builders:
  _composite_builders[node.target](graph_module, node)
 
  graph_module.graph.lint()
  graph_module.recompile()
- return PassResult(graph_module, True)
+ return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20

@@ -12,31 +12,30 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Build interpolate composite pass."""
 
  import functools
 
- import torch
-
- from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
- from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch.hlfb import mark_pattern
+ from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
+ import torch
 
  # For torch nightly released after mid June 2024,
  # torch.nn.functional.interpolate no longer gets exported into decomposed graph
- # but single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
- # This behavior would our pattern matching based composite builder.
- # It requires the pattern and model graph to get decomposed first for backward compatibility.
- _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
- [
- torch.ops.aten.upsample_bilinear2d.vec,
- torch.ops.aten.upsample_nearest2d.vec,
- ]
- )
+ # but a single aten op:
+ # torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
+ # This would interefere with our pattern matching based composite builder.
+ # Here we register the now missing decompositions first.
+ _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
+ torch.ops.aten.upsample_bilinear2d.vec,
+ torch.ops.aten.upsample_nearest2d.vec,
+ ])
 
 
  @functools.cache
  def _get_upsample_bilinear2d_pattern():
- pattern = mark_pattern.Pattern(
+ pattern = pattern_module.Pattern(
  "odml.upsample_bilinear2d",
  lambda x: torch.nn.functional.interpolate(
  x, scale_factor=2, mode="bilinear", align_corners=False
@@ -59,7 +58,7 @@ def _get_upsample_bilinear2d_pattern():
 
  @functools.cache
  def _get_upsample_bilinear2d_align_corners_pattern():
- pattern = mark_pattern.Pattern(
+ pattern = pattern_module.Pattern(
  "odml.upsample_bilinear2d",
  lambda x: torch.nn.functional.interpolate(
  x, scale_factor=2, mode="bilinear", align_corners=True
@@ -82,9 +81,11 @@ def _get_upsample_bilinear2d_align_corners_pattern():
 
  @functools.cache
  def _get_interpolate_nearest2d_pattern():
- pattern = mark_pattern.Pattern(
+ pattern = pattern_module.Pattern(
  "tfl.resize_nearest_neighbor",
- lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
+ lambda x: torch.nn.functional.interpolate(
+ x, scale_factor=2, mode="nearest"
+ ),
  export_args=(torch.rand(1, 3, 100, 100),),
  decomp_table=_INTERPOLATE_DECOMPOSITIONS,
  )
@@ -101,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
  return pattern
 
 
- class BuildInterpolateCompositePass(ExportedProgramPassBase):
+ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
  def __init__(self):
  super().__init__()
@@ -112,7 +113,9 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
  ]
 
  def call(self, exported_program: torch.export.ExportedProgram):
- exported_program = exported_program.run_decompositions(_INTERPOLATE_DECOMPOSITIONS)
+ exported_program = exported_program.run_decompositions(
+ _INTERPOLATE_DECOMPOSITIONS
+ )
 
  graph_module = exported_program.graph_module
  for pattern in self._patterns:
@@ -120,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
  graph_module.graph.lint()
  graph_module.recompile()
- return ExportedProgramPassResult(exported_program, True)
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)
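The rewritten comment in this file explains the motivation: newer torch exports keep `upsample_nearest2d.vec` / `upsample_bilinear2d.vec` as single aten ops, which the pattern-matching composite builder does not expect, so the pass re-registers those decompositions and runs them first. A rough standalone sketch of that same `run_decompositions` step, outside the pass and using only public torch APIs (the model here is an assumption for illustration):

```python
import torch

class Upsample(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")

# Same decomposition table the pass builds as _INTERPOLATE_DECOMPOSITIONS.
decomps = torch._decomp.get_decompositions([
    torch.ops.aten.upsample_nearest2d.vec,
    torch.ops.aten.upsample_bilinear2d.vec,
])

ep = torch.export.export(Upsample(), (torch.rand(1, 3, 100, 100),))
# Recent torch keeps a single upsample_nearest2d.vec node here; decomposing it
# restores the smaller-op graph that the mark_pattern patterns were built against.
ep = ep.run_decompositions(decomps)
print(ep.graph_module.graph)
```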
ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6

@@ -13,11 +13,10 @@
  # limitations under the License.
  # ==============================================================================
 
+ from ai_edge_torch import fx_pass_base
+ from ai_edge_torch import lowertools
  import torch
- from torch.fx.passes.infra.pass_base import PassBase
- from torch.fx.passes.infra.pass_base import PassResult
  import torch.utils._pytree as pytree
- import torch_xla.experimental.xla_mlir_debuginfo # Import required to register torch.ops.xla.write_mlir_debuginfo
 
 
  def _get_mlir_debuginfo(node: torch.fx.Node):
@@ -54,7 +53,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
  outputs = target(*args, **kwargs)
  outputs = pytree.tree_map_only(
  torch.Tensor,
- lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
+ lambda x: lowertools.write_mlir_debuginfo_op(x, debuginfo),
  outputs,
  )
  return outputs
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
  node.target = debuginfo_writer
 
 
- class InjectMlirDebuginfoPass(PassBase):
+ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
  def call(self, graph_module: torch.fx.GraphModule):
  for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
  graph_module.graph.lint()
  graph_module.recompile()
- return PassResult(graph_module, True)
+ return fx_pass_base.PassResult(graph_module, True)
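The debuginfo writer only touches the tensor leaves of a node's outputs, via `pytree.tree_map_only`. A tiny illustration of that mechanism, with a doubling lambda standing in for `lowertools.write_mlir_debuginfo_op`:

```python
import torch
import torch.utils._pytree as pytree

outputs = (torch.ones(2), "label", [torch.zeros(1)])
# Only torch.Tensor leaves are mapped; other leaves pass through untouched.
tagged = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, outputs)
print(tagged)  # (tensor([2., 2.]), 'label', [tensor([0.])])
```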
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1

@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9

@@ -12,17 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Layout check for the optimized layout transposes pass."""
+
  import dataclasses
  import operator
 
+ from ai_edge_torch import lowertools
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
  import torch
  from torch.fx import Node
 
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
-
  aten = torch.ops.aten
 
  __all__ = [
@@ -113,6 +114,10 @@ def is_4d(node: Node):
  val = node.meta.get("val")
  if val is None:
  return False
+
+ if isinstance(val, (list, tuple)) and val:
+ val = val[0]
+
  if not hasattr(val, "shape"):
  return False
 
@@ -145,8 +150,11 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
  # ==== Ops must be NHWC if possible
 
 
+ @layout_sensitive_inputs_getters.register(aten.conv2d)
  @layout_sensitive_inputs_getters.register(aten.convolution)
- @layout_sensitive_inputs_getters.register(aten._native_batch_norm_legit_no_training)
+ @layout_sensitive_inputs_getters.register(
+ aten._native_batch_norm_legit_no_training
+ )
  @layout_sensitive_inputs_getters.register(aten.native_group_norm)
  def _first_arg_getter(node):
  return [node.args[0]]
@@ -161,6 +169,7 @@ def _first_arg_getter(node):
  @nhwcable_node_checkers.register(aten.upsample_bilinear2d)
  @nhwcable_node_checkers.register(aten.upsample_nearest2d)
  @nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
+ @nhwcable_node_checkers.register(aten.conv2d)
  @nhwcable_node_checkers.register(aten.convolution)
  def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
  can_be = all_layout_sensitive_inputs_are_4d(node)
@@ -168,10 +177,31 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
 
 
  @nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
- @nhwcable_node_checkers.register(aten.native_group_norm)
  def _aten_norm_checker(node):
  val = node.meta.get("val")
- if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
+ if (
+ not isinstance(val, (list, tuple))
+ or not val
+ or not hasattr(val[0], "shape")
+ ):
+ return NHWCable(can_be=False, must_be=False)
+ return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
+
+
+ @nhwcable_node_checkers.register(aten.native_group_norm)
+ def _aten_native_group_norm_checker(node):
+ val = node.meta.get("val")
+ if (
+ not isinstance(val, (list, tuple))
+ or not val
+ or not hasattr(val[0], "shape")
+ ):
+ return NHWCable(can_be=False, must_be=False)
+ if len(node.args) >= 3 and (
+ node.args[1] is not None or node.args[2] is not None
+ ):
+ # Disable NHWC rewriter due to precision issue with weight and bias.
+ # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
  return NHWCable(can_be=False, must_be=False)
  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
@@ -179,7 +209,7 @@ def _aten_norm_checker(node):
  # ==== Ops must be NCHW
 
 
- @nhwcable_node_checkers.register(torch.ops.xla.mark_tensor)
+ @nhwcable_node_checkers.register(lowertools.mark_tensor_op)
  @nhwcable_node_checkers.register(utils.tensor_to_nchw)
  @nhwcable_node_checkers.register(utils.tensor_to_nhwc)
  @nhwcable_node_checkers.register("output")
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0

@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Layout mark for the optimized layout transposes pass."""
+
  import torch
 
  # Tag which is added to a node's meta to indicate that is is part of the NHWC
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0

@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Layout partitioners."""
 
  from . import greedy
  from . import min_cut
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8

@@ -12,24 +12,31 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Greedy partitioning algorithm."""
 
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
  import torch
 
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
-
 
  def partition(graph_module: torch.fx.GraphModule):
- """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
- nodes in the NHWC partitions.
+ """Partition the graph module into NHWC and non-NHWC subgraphs.
+
+ Partition the graph module into NHWC and non-NHWC subgraphs and mark nodes in
+ the NHWC partitions.
 
  Implements O(|V|) greedy partitioning algorithm.
- See go/pytorch-layout-transpose-optimization for more details.
+
+ Args:
+ graph_module: The graph module to be partitioned.
+
+ Returns:
+ The partitioned graph module.
  """
  graph = graph_module.graph
 
  for node in list(graph.nodes):
- if len(node.all_input_nodes) == 0:
+ if not node.all_input_nodes:
  # This node has no inputs so we don't need to change anything
  continue
@@ -45,7 +52,9 @@ def partition(graph_module: torch.fx.GraphModule):
 
  layout_sensitive_inputs = layout_check.get_layout_sensitive_inputs(node)
 
- should_be_nhwc = any(map(layout_mark.is_nhwc_node, layout_sensitive_inputs))
+ should_be_nhwc = any(
+ map(layout_mark.is_nhwc_node, layout_sensitive_inputs)
+ )
  for input_node in layout_sensitive_inputs:
  if not layout_mark.is_nhwc_node(input_node) and not layout_check.is_4d(
  input_node
ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8

@@ -12,28 +12,26 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ """Min cut solver for partitioning the graph module into NHWC and non-NHWC subgraphs."""
 
  import collections
  import dataclasses
- import itertools
 
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
  import numpy as np
  import scipy
  import torch
 
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
-
 
  def can_partition(graph_module: torch.fx.GraphModule):
  """Returns true if the input graph_module can be partitioned by min cut solver
+
  in a reasonable time.
 
  The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
  take a long time to complete for large graph module. This function determines
  whether the graph module can be partitioned by the graph module size.
-
- See go/pytorch-layout-transpose-optimization for more details.
  """
  graph = graph_module.graph
  n_nodes = len(graph.nodes)
@@ -83,7 +81,10 @@ class MinCutSolver:
  def graph(self):
  edges = np.array(self.edges)
  return scipy.sparse.csr_matrix(
- (np.minimum(edges[:, 2], MinCutSolver.INF_COST), (edges[:, 0], edges[:, 1])),
+ (
+ np.minimum(edges[:, 2], MinCutSolver.INF_COST),
+ (edges[:, 0], edges[:, 1]),
+ ),
  shape=(self._nodes_cnt, self._nodes_cnt),
  dtype=np.int32,
  )
@@ -135,10 +136,10 @@ class MultiUsersDummyNode:
 
  def partition(graph_module: torch.fx.GraphModule):
  """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
+
  nodes in the NHWC partitions.
 
  Implements O(|V|^2|E|) min-cut (optimal) partitioning algorithm.
- See go/pytorch-layout-transpose-optimization for more details.
  """
  graph = graph_module.graph
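For context on the last two hunks: `MinCutSolver.graph` packs edge capacities into a SciPy CSR matrix, and the NHWC/non-NHWC partition then corresponds to a minimum cut between a source and a sink. The solver itself implements Dinic's algorithm, per the docstring above; the toy example below only illustrates the max-flow/min-cut formulation on such a capacity matrix, using SciPy's bundled solver rather than the package's own:

```python
import numpy as np
import scipy.sparse
from scipy.sparse.csgraph import maximum_flow

# Toy 4-node graph: node 0 is the source, node 3 the sink.
# Each row is (u, v, capacity), analogous to MinCutSolver.edges.
edges = np.array([[0, 1, 5], [0, 2, 3], [1, 3, 4], [2, 3, 2]])
capacity = scipy.sparse.csr_matrix(
    (edges[:, 2], (edges[:, 0], edges[:, 1])),
    shape=(4, 4),
    dtype=np.int32,
)
# Max-flow value equals the min-cut capacity (6 here: 4 via node 1, 2 via node 2).
print(maximum_flow(capacity, 0, 3).flow_value)
```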