ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
@@ -0,0 +1,400 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import operator
+
+ import torch
+ from torch.fx import Node
+ import torch.utils._pytree as pytree
+
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry  # NOQA
+
+ aten = torch.ops.aten
+
+ __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
+
+
+ class NHWCNodeRewritersRegistry(OpFuncRegistry):
+
+   def __missing__(self, op):
+     def _rewriter(node):
+       raise RuntimeError(f"NHWC node rewriter not found: {str(node)}")
+
+     return _rewriter
+
+
+ rewriters = NHWCNodeRewritersRegistry()
+
+
+ def rewrite_nhwc_node(node: Node):
+   if not layout_mark.is_nhwc_node(node):
+     return
+
+   rewriters[node.target](node)
+
+
+ def has_nhwc_rewriter(node: Node):
+   return node.target in rewriters
+
+
+ # ======= Quantize ops
+
+
+ @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_tensor)
+ @rewriters.register(torch.ops.quantized_decomposed.quantize_per_tensor)
+ def noop(node: Node):
+   pass
+
+
+ @rewriters.register(torch.ops.quantized_decomposed.dequantize_per_channel)
+ @rewriters.register(torch.ops.quantized_decomposed.quantize_per_channel)
+ def _qdq_per_channel_rewriter(node: Node):
+   op = node.target
+   new_args = []
+   new_kwargs = {}
+
+   def axis_nchw_to_nhwc(axis: int):
+     axis = axis if axis >= 0 else 4 + axis
+     return {3: 2, 2: 1, 1: 3}.get(axis, axis)
+
+   for arg, spec in zip(node.args, op._schema.arguments):
+     if spec.name == "axis":
+       new_args.append(axis_nchw_to_nhwc(arg))
+     else:
+       new_args.append(arg)
+
+   for spec in op._schema.arguments[len(node.args) :]:
+     if spec.name not in node.kwargs:
+       continue
+
+     if spec.name == "axis":
+       new_kwargs[spec.name] = axis_nchw_to_nhwc(node.kwargs[spec.name])
+     else:
+       new_kwargs[spec.name] = node.kwargs[spec.name]
+
+   node.args = tuple(new_args)
+   node.kwargs = new_kwargs
+
+
+ # ======= Noop ops (layout insensitive ops)
+
+
+ @rewriters.register(utils.tensor_to_nhwc)
+ @rewriters.register(utils.tensor_to_nchw)
+ @rewriters.register(operator.getitem)
+ @rewriters.register("output")
+ @rewriters.register(aten.add.Tensor)
+ @rewriters.register(aten.add.Scalar)
+ @rewriters.register(aten.atan2.default)
+ @rewriters.register(aten.atan2.out)
+ @rewriters.register(aten.bitwise_and.Tensor)
+ @rewriters.register(aten.bitwise_and.Scalar)
+ @rewriters.register(aten.bitwise_or.Tensor)
+ @rewriters.register(aten.bitwise_or.Scalar)
+ @rewriters.register(aten.bitwise_xor.Tensor)
+ @rewriters.register(aten.bitwise_xor.Scalar)
+ @rewriters.register(aten.div.Tensor)
+ @rewriters.register(aten.div.Scalar)
+ @rewriters.register(aten.div.Tensor_mode)
+ @rewriters.register(aten.div.Scalar_mode)
+ @rewriters.register(aten.fmod.Tensor)
+ @rewriters.register(aten.fmod.Scalar)
+ @rewriters.register(aten.mul.Tensor)
+ @rewriters.register(aten.mul.Scalar)
+ @rewriters.register(aten.remainder.Tensor)
+ @rewriters.register(aten.remainder.Scalar)
+ @rewriters.register(aten.sub.Tensor)
+ @rewriters.register(aten.sub.Scalar)
+ @rewriters.register(aten.eq.Tensor)
+ @rewriters.register(aten.eq.Scalar)
+ @rewriters.register(aten.ne.Tensor)
+ @rewriters.register(aten.ne.Scalar)
+ @rewriters.register(aten.le.Tensor)
+ @rewriters.register(aten.le.Scalar)
+ @rewriters.register(aten.ge.Tensor)
+ @rewriters.register(aten.ge.Scalar)
+ @rewriters.register(aten.gt.Tensor)
+ @rewriters.register(aten.gt.Scalar)
+ @rewriters.register(aten.lt.Tensor)
+ @rewriters.register(aten.lt.Scalar)
+ @rewriters.register(aten.maximum.default)
+ @rewriters.register(aten.minimum.default)
+ @rewriters.register(aten.mean.default)
+ @rewriters.register(aten.prod.default)
+ @rewriters.register(aten.abs.default)
+ @rewriters.register(aten.acos.default)
+ @rewriters.register(aten.acosh.default)
+ @rewriters.register(aten.asin.default)
+ @rewriters.register(aten.asinh.default)
+ @rewriters.register(aten.atan.default)
+ @rewriters.register(aten.atanh.default)
+ @rewriters.register(aten.bitwise_not.default)
+ @rewriters.register(aten.ceil.default)
+ @rewriters.register(aten.clamp.default)
+ @rewriters.register(aten.clamp.Tensor)
+ @rewriters.register(aten.cos.default)
+ @rewriters.register(aten.cosh.default)
+ @rewriters.register(aten.erf.default)
+ @rewriters.register(aten.exp.default)
+ @rewriters.register(aten.expm1.default)
+ @rewriters.register(aten.floor.default)
+ @rewriters.register(aten.log.default)
+ @rewriters.register(aten.log10.default)
+ @rewriters.register(aten.log1p.default)
+ @rewriters.register(aten.log2.default)
+ @rewriters.register(aten.isnan.default)
+ @rewriters.register(aten.neg.default)
+ @rewriters.register(aten.pow.Tensor_Tensor)
+ @rewriters.register(aten.pow.Tensor_Scalar)
+ @rewriters.register(aten.pow.Scalar)
+ @rewriters.register(aten.reciprocal.default)
+ @rewriters.register(aten.round.default)
+ @rewriters.register(aten.rsqrt.default)
+ @rewriters.register(aten.sigmoid.default)
+ @rewriters.register(aten.sign.default)
+ @rewriters.register(aten.sin.default)
+ @rewriters.register(aten.sinh.default)
+ @rewriters.register(aten.sqrt.default)
+ @rewriters.register(aten.tan.default)
+ @rewriters.register(aten.tanh.default)
+ @rewriters.register(aten.trunc.default)
+ @rewriters.register(aten.nonzero.default)
+ @rewriters.register(aten.copy.default)
+ @rewriters.register(aten.mm.default)
+ @rewriters.register(aten.fill.Scalar)
+ @rewriters.register(aten.col2im.default)
+ @rewriters.register(aten.addmm.default)
+ @rewriters.register(aten.gelu.default)
+ @rewriters.register(aten.hardtanh.default)
+ @rewriters.register(aten.leaky_relu.default)
+ @rewriters.register(aten.relu.default)
+ @rewriters.register(aten.arange.start_step)
+ @rewriters.register(aten.isinf.default)
+ @rewriters.register(aten.logical_and.default)
+ @rewriters.register(aten.logical_not.default)
+ @rewriters.register(aten.logical_or.default)
+ @rewriters.register(aten.logical_xor.default)
+ @rewriters.register(aten.where.self)
+ @rewriters.register(aten.clone.default)
+ @rewriters.register(aten.any.default)
+ @rewriters.register(aten.repeat.default)
+ @rewriters.register(aten.alias.default)
+ @rewriters.register(aten._pdist_forward.default)
+ @rewriters.register(aten._cdist_forward.default)
+ @rewriters.register(aten.bmm.default)
+ @rewriters.register(aten.hardswish)
+ @rewriters.register(aten.hardsigmoid)
+ @rewriters.register(aten._to_copy)
+ @rewriters.register(aten._prelu_kernel)
+ @rewriters.register(aten.softplus)
+ @rewriters.register(aten.silu)
+ def noop(node: Node):
+   pass
+
+
+ # ======= Add transposes before and after NCHW-only ops (T-aten-T)
+
+
+ @rewriters.register(aten.upsample_bilinear2d)
+ @rewriters.register(aten.upsample_nearest2d)
+ @rewriters.register(aten.max_pool2d)
+ @rewriters.register(aten.max_pool2d_with_indices)
+ @rewriters.register(aten.avg_pool2d)
+ @rewriters.register(aten._adaptive_avg_pool2d.default)
+ def transpose_first_arg_rewriter(node: Node):
+   op = node.target
+
+   def nhwc_op(x, *args, **kwargs):
+     nonlocal op
+     x = utils.tensor_to_nchw(x)
+     res = pytree.tree_map_only(
+         torch.Tensor, utils.tensor_to_nhwc, op(x, *args, **kwargs)
+     )
+     return res
+
+   node.target = nhwc_op
+
+
+ @rewriters.register(aten.convolution)
+ def _aten_convolution_rewriter(node: Node):
+   op = node.target
+
+   def conv_nhwc(input, weight, bias, *args, **kwargs):
+     nonlocal op
+     nhwc_bias = None
+     if bias is not None and len(bias.shape) == 1:
+       nhwc_bias = bias
+       bias = None
+
+     input = utils.tensor_to_nchw(input)
+     res = pytree.tree_map_only(
+         torch.Tensor,
+         utils.tensor_to_nhwc,
+         op(input, weight, bias, *args, **kwargs),
+     )
+
+     if nhwc_bias is not None:
+       res += nhwc_bias
+     return res
+
+   node.target = conv_nhwc
+
+
+ # ======= Rewrite dim attribute(s)
+
+
+ @rewriters.register(aten._softmax.default)
+ @rewriters.register(aten.select.int)
+ @rewriters.register(aten.slice.Tensor)
+ @rewriters.register(aten.sum.dim_IntList)
+ @rewriters.register(aten.mean.dim)
+ @rewriters.register(aten.prod.dim_int)
+ @rewriters.register(aten.var.dim)
+ @rewriters.register(aten.var.correction)
+ @rewriters.register(aten.slice_scatter.default)
+ @rewriters.register(aten.diagonal.default)
+ @rewriters.register(aten.select_scatter.default)
+ @rewriters.register(aten.sym_size.int)
+ @rewriters.register(aten.sym_stride.int)
+ @rewriters.register(aten._log_softmax.default)
+ @rewriters.register(aten.split_with_sizes.default)
+ @rewriters.register(aten.squeeze.dim)
+ @rewriters.register(aten.squeeze.dims)
+ @rewriters.register(aten.scatter.value)
+ @rewriters.register(aten.scatter.src)
+ @rewriters.register(aten.scatter_add.default)
+ @rewriters.register(aten.scatter_reduce.two)
+ @rewriters.register(aten.any.dim)
+ @rewriters.register(aten.any.dims)
+ @rewriters.register(aten.flip.default)
+ @rewriters.register(aten.index_select.default)
+ @rewriters.register(aten.cumsum.default)
+ @rewriters.register(aten.max.dim)
+ @rewriters.register(aten.min.dim)
+ @rewriters.register(aten.gather.default)
+ @rewriters.register(aten.sort.default)
+ @rewriters.register(aten.topk.default)
+ @rewriters.register(aten.cat.default)
+ def dim_attr_rewriter(node: Node):
+   op = node.target
+
+   new_args = []
+   new_kwargs = {}
+
+   def dims_nchw_to_nhwc(dims: list[int]):
+     def convert(dim: int):
+       dim = dim if dim >= 0 else 4 + dim
+       return {3: 2, 2: 1, 1: 3}.get(dim, dim)
+
+     dims = pytree.tree_map_only(int, convert, dims)
+     dims = pytree.tree_map_only(torch.SymInt, convert, dims)
+     return dims
+
+   for arg, spec in zip(node.args, op._schema.arguments):
+     if spec.name.startswith("dim"):
+       new_args.append(dims_nchw_to_nhwc(arg))
+     else:
+       new_args.append(arg)
+
+   for spec in op._schema.arguments[len(node.args) :]:
+     if spec.name not in node.kwargs:
+       continue
+
+     if spec.name.startswith("dim"):
+       new_kwargs[spec.name] = dims_nchw_to_nhwc(node.kwargs[spec.name])
+     else:
+       new_kwargs[spec.name] = node.kwargs[spec.name]
+
+   node.args = tuple(new_args)
+   node.kwargs = new_kwargs
+
+
+ # ======= Others
+
+
+ @rewriters.register(aten._native_batch_norm_legit_no_training.default)
+ def _aten__native_batch_norm_legit_no_training(node):
+   def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
+     a = input - running_mean
+     b = torch.sqrt(running_var + eps)
+     return a / b * weight + bias, None, None
+
+   node.target = batch_norm
+
+
+ @rewriters.register(aten.native_group_norm.default)
+ def _aten_native_group_norm(node):
+
+   def native_group_norm(
+       input,
+       weight,
+       bias,
+       batch_size: int,
+       num_channels: int,
+       flattened_inner_size: int,
+       num_groups: int,
+       eps: float,
+   ):
+     input_reshaped = torch.reshape(
+         input,
+         [batch_size, flattened_inner_size, num_groups, num_channels // num_groups],
+     )
+     reduction_dims = [1, 3]
+
+     biased_var, mean = torch.var_mean(
+         input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+     )
+     rstd = torch.rsqrt(biased_var + eps)
+
+     out = (input_reshaped - mean) * rstd
+     out = torch.reshape(out, input.shape)
+
+     if weight is not None:
+       out = out * weight
+     if bias is not None:
+       out = out + bias
+
+     mean = torch.squeeze(mean, reduction_dims)
+     rstd = torch.squeeze(rstd, reduction_dims)
+
+     return out, mean, rstd
+
+   node.target = native_group_norm
+
+
+ @rewriters.register(aten.index)
+ @rewriters.register(aten._unsafe_index)
+ def _aten_index(node):
+   op = node.target
+
+   def index_nhwc(x, indices=[], *args, **kwargs):
+     nonlocal op
+     indices = list(indices)
+     if len(indices) < 4:
+       indices += [None] * (4 - len(indices))
+
+     indices[1:4] = indices[2], indices[3], indices[1]
+     return op(x, indices, *args, **kwargs)
+
+   node.target = index_nhwc
+
+
+ @rewriters.register(aten.reflection_pad2d.default)
+ def _aten_reflection_pad2d(node):
+   def reflection_pad2d_nhwc(x, padding):
+     padding = [0, 0] + padding
+     return torch.nn.functional.pad(x, padding, mode="reflect")
+
+   node.target = reflection_pad2d_nhwc
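
The rewriters above repeatedly use the map {3: 2, 2: 1, 1: 3} to translate an NCHW dimension index into its NHWC position: channels move from dim 1 to dim 3, the spatial dims shift down by one, and the batch dim stays at 0. Below is a minimal standalone sketch (not part of the package) that checks this remapping against a plain permute; the helper name dim_nchw_to_nhwc is illustrative only.

import torch

def dim_nchw_to_nhwc(dim: int) -> int:
  dim = dim if dim >= 0 else 4 + dim       # normalize negative dims for 4D tensors
  return {3: 2, 2: 1, 1: 3}.get(dim, dim)  # batch dim 0 is unchanged

x_nchw = torch.randn(2, 8, 4, 4)           # (N, C, H, W)
x_nhwc = x_nchw.permute(0, 2, 3, 1)        # (N, H, W, C)

# Reducing over channels: dim 1 in NCHW corresponds to dim 3 in NHWC,
# and both reductions yield the same (N, H, W) tensor.
assert torch.allclose(x_nchw.sum(dim=1), x_nhwc.sum(dim=dim_nchw_to_nhwc(1)))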
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py
@@ -0,0 +1,30 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+
+
+ class OpFuncRegistry(dict):
+
+   def register(self, op):
+     ops = utils.flatten_torch_op_overloads(op)
+
+     def inner(func):
+       for op in ops:
+         self[op] = func
+       return func
+
+     return inner
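
op_func_registry.py keys rewriter functions by torch op and, combined with the __missing__ hook used by NHWCNodeRewritersRegistry in the previous file, surfaces a clear error for unregistered ops. A toy standalone sketch of the same decorator pattern follows; utils.flatten_torch_op_overloads is not shown in this diff, so this version registers plain string keys instead of op overloads.

class ToyRegistry(dict):

  def register(self, key):
    def inner(func):
      self[key] = func
      return func

    return inner

  def __missing__(self, key):
    def _raise(*args, **kwargs):
      raise RuntimeError(f"no handler registered for {key}")

    return _raise


handlers = ToyRegistry()


@handlers.register("relu")
def _relu_handler(node):
  print(f"rewriting {node}")


handlers["relu"]("node_a")  # prints: rewriting node_a
try:
  handlers["conv2d"]("node_b")
except RuntimeError as err:
  print(err)  # no handler registered for conv2d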
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py
@@ -0,0 +1,286 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import os
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.ao.quantization.quantize_pt2e
+ from torch.export import ExportedProgram
+ from torch.fx import GraphModule
+ from torch.fx import Node
+ import torch.utils._pytree as pytree
+
+ from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
+ from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+
+ TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
+
+
+ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+
+   def get_source_meta(self, node: torch.fx.Node):
+     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
+     meta = {}
+     for key in keys:
+       if key in node.meta:
+         meta[key] = node.meta[key]
+     return meta
+
+   def insert_t_q_dq(
+       self,
+       graph: torch.fx.Graph,
+       input_dq: torch.fx.Node,
+       target: torch.fx.Node,
+       transpose_func: TransposeFunc,
+       transpose_node_meta: dict,
+   ) -> list[torch.fx.Node]:
+     """
+     original:
+       input_dq -> target
+     insert the node as:
+       input_dq -> (T q dq) -> target
+     """
+     assert utils.is_dq_node(input_dq)
+
+     q_args = input_dq.args[1:]
+     q_kwargs = input_dq.kwargs
+     q_op, dq_op = utils.get_paired_q_dq_ops(input_dq.target)
+     with graph.inserting_before(target):
+       t = graph.call_function(transpose_func, (input_dq,))
+       # Q and DQ inserted here may require updating the `axis` arg when they
+       # are per_channel ops. However, instead of updating here, the nodes are
+       # marked as NHWC/NCHW and the rewriters are applied after partitioning.
+       q = graph.call_function(q_op, (t,) + q_args, q_kwargs)
+       dq = graph.call_function(dq_op, (q,) + q_args, q_kwargs)
+
+     input_dq.meta = transpose_node_meta
+     t.meta = transpose_node_meta
+     q.meta = transpose_node_meta
+     dq.meta = self.get_source_meta(target)
+
+     target.replace_input_with(input_dq, dq)
+     return [t, q, dq]
+
+   def insert_dq_t_q(
+       self,
+       graph: torch.fx.Graph,
+       input_q: torch.fx.Node,
+       target: torch.fx.Node,
+       transpose_func: TransposeFunc,
+       transpose_node_meta: dict,
+   ) -> list[torch.fx.Node]:
+     """
+     original:
+       input_q -> target
+     insert the node as:
+       input_q -> (dq T q) -> target
+     """
+     assert utils.is_q_node(input_q)
+
+     q_args = input_q.args[1:]
+     q_kwargs = input_q.kwargs
+     q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
+     with graph.inserting_before(target):
+       # Q and DQ inserted here may require updating the `axis` arg when they
+       # are per_channel ops. However, instead of updating here, the nodes are
+       # marked as NHWC/NCHW and the rewriters are applied after partitioning.
+       dq = graph.call_function(dq_op, (input_q,) + q_args, q_kwargs)
+       t = graph.call_function(transpose_func, (dq,))
+       q = graph.call_function(q_op, (t,) + q_args, q_kwargs)
+
+     dq.meta = transpose_node_meta
+     t.meta = transpose_node_meta
+     q.meta = transpose_node_meta
+
+     target.replace_input_with(input_q, q)
+     return [dq, t, q]
+
+   def insert_layout_transpose(
+       self,
+       graph: torch.fx.Graph,
+       input_node: torch.fx.Node,
+       target_node: torch.fx.Node,
+       transpose_func: TransposeFunc,
+       transpose_node_meta: dict,
+   ) -> None:
+     assert transpose_func in (utils.tensor_to_nchw, utils.tensor_to_nhwc)
+
+     # new_nodes only contains Q/DQ/Transpose nodes, which are all SISO.
+     # Insertion order input nodes -> output nodes
+     new_nodes = []
+
+     # Constraint Q2: the NHWC partition's entry and exit must not be output
+     # edges of Q/DQ ops that are connected to a constant/weight tensor.
+     while layout_mark.is_const_node(input_node) and (
+         utils.is_dq_node(input_node) or utils.is_q_node(input_node)
+     ):
+       with graph.inserting_before(target_node):
+         new_input_node = graph.node_copy(input_node)
+
+       target_node.replace_input_with(input_node, new_input_node)
+
+       new_nodes = [new_input_node] + new_nodes
+       input_node, target_node = new_input_node.args[0], new_input_node
+
+     if utils.is_q_node(input_node):
+       # Constraint Q3: when the entry or exit is right after a q op (i.e. after a
+       # (dq-op-q) triplet), the transpose must be inserted as a quantized transpose (dq-T-q):
+       # input_q -> (dq T q) -> target
+       new_nodes = (
+           self.insert_dq_t_q(
+               graph,
+               input_node,
+               target_node,
+               transpose_func,
+               transpose_node_meta,
+           )
+           + new_nodes
+       )
+     elif utils.is_dq_node(input_node):
+       # Constraint Q1: the NHWC partition's entry and exit cannot be edges
+       # within (dq-op-q) triplet.
+       # input_dq -> (T q dq) -> target
+       new_nodes = (
+           self.insert_t_q_dq(
+               graph,
+               input_node,
+               target_node,
+               transpose_func,
+               transpose_node_meta,
+           )
+           + new_nodes
+       )
+     else:
+       # input -> target
+       with graph.inserting_before(target_node):
+         t = graph.call_function(transpose_func, (input_node,))
+         t.meta = transpose_node_meta
+       target_node.replace_input_with(input_node, t)
+       new_nodes = [t] + new_nodes
+
+     # Mark new nodes as NCHW or NHWC
+     # For all nodes before the transpose, mark it as input_marker
+     # For all nodes after the transpose (incl. transpose), mark it as output_marker
+     if transpose_func == utils.tensor_to_nchw:
+       input_marker, target_marker = (
+           layout_mark.mark_as_nhwc_node,
+           layout_mark.mark_as_nchw_node,
+       )
+     else:
+       input_marker, target_marker = (
+           layout_mark.mark_as_nchw_node,
+           layout_mark.mark_as_nhwc_node,
+       )
+
+     marker = input_marker
+     for node in new_nodes:
+       if node.target == transpose_func:
+         marker = target_marker
+       marker(node)
+     assert marker == target_marker
+
+   def input_to_nhwc(
+       self,
+       graph: torch.fx.Graph,
+       input_node: torch.fx.Node,
+       target_node: torch.fx.Node,
+   ) -> None:
+     if layout_mark.is_nhwc_node(input_node):
+       return
+
+     if not layout_check.is_4d(input_node):
+       raise AssertionError(
+           f"Attempting to convert non-NHWC compatible node to NHWC: {input_node}"
+       )
+
+     # Assign target node's source meta to the to_NHWC node, because the transpose
+     # is added for the existence of target node.
+     self.insert_layout_transpose(
+         graph,
+         input_node,
+         target_node,
+         utils.tensor_to_nhwc,
+         self.get_source_meta(target_node),
+     )
+
+   def input_to_nchw(
+       self,
+       graph: torch.fx.Graph,
+       input_node: torch.fx.Node,
+       target_node: torch.fx.Node,
+   ) -> None:
+     if layout_mark.is_nchw_node(input_node):
+       return
+
+     self.insert_layout_transpose(
+         graph,
+         input_node,
+         target_node,
+         utils.tensor_to_nchw,
+         self.get_source_meta(input_node),
+     )
+
+   def mark_const_nodes(self, exported_program: torch.export.ExportedProgram):
+     graph_module = exported_program.graph_module
+     graph = graph_module.graph
+
+     input_specs = exported_program.graph_signature.input_specs
+     non_user_input_names = set()
+     for spec in input_specs:
+       if spec.kind != torch.export.graph_signature.InputKind.USER_INPUT:
+         non_user_input_names.add(spec.arg.name)
+
+     for node in graph.nodes:
+       has_input_nodes = len(node.all_input_nodes) > 0
+       all_inputs_are_const = all(map(layout_mark.is_const_node, node.all_input_nodes))
+       if (
+           node.name in non_user_input_names
+           or (has_input_nodes and all_inputs_are_const)
+           or (node.op != "placeholder" and not has_input_nodes)
+       ):
+         layout_mark.mark_as_const_node(node)
+
+   def call(self, exported_program: torch.export.ExportedProgram):
+     self.mark_const_nodes(exported_program)
+
+     graph_module = exported_program.graph_module
+     if os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_USE_MINCUT_PARTITIONER"):
+       graph_module = layout_partitioners.min_cut.partition(graph_module)
+     else:
+       graph_module = layout_partitioners.greedy.partition(graph_module)
+
+     graph = graph_module.graph
+     for node in list(graph.nodes):
+       if layout_mark.is_nhwc_node(node):
+         for input_node in layout_check.get_layout_sensitive_inputs(node):
+           self.input_to_nhwc(graph, input_node, node)
+         layout_rewrite.rewrite_nhwc_node(node)
+       else:
+         for input_node in layout_check.get_layout_sensitive_inputs(node):
+           # Note: for non-4D tensors input_to_nchw is always noop.
+           self.input_to_nchw(graph, input_node, node)
+
+     graph_module.graph.eliminate_dead_code()
+     graph_module.recompile()
+     graph_module.graph.lint()
+     # Mark const node again for debugging
+     self.mark_const_nodes(exported_program)
+
+     return ExportedProgramPassResult(exported_program, True)
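
For context on how this pass is driven: call() selects the min-cut layout partitioner only when the AIEDGETORCH_LAYOUT_OPTIMIZE_USE_MINCUT_PARTITIONER environment variable is set to a non-empty value, and otherwise falls back to the greedy partitioner. A hedged usage sketch, assuming the top-level ai_edge_torch.convert() entry point and the converted model's export() method shipped elsewhere in this wheel (converter.py and model.py in the file list); the model itself is illustrative.

import os

# Any non-empty value opts in to the min-cut partitioner; unset means greedy.
os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_USE_MINCUT_PARTITIONER"] = "1"

import torch
import ai_edge_torch

# A small NCHW model whose convolution the pass would otherwise bracket with transposes.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval()
sample_inputs = (torch.randn(1, 3, 16, 16),)

edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export("conv_relu.tflite")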