ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic.

Files changed (46)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
  9. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  10. ai_edge_torch/generative/test/test_loader.py +1 -1
  11. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  12. ai_edge_torch/generative/test/test_quantize.py +1 -1
  13. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  14. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  15. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  16. ai_edge_torch/lowertools/test_utils.py +1 -1
  17. ai_edge_torch/odml_torch/__init__.py +20 -0
  18. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  19. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  20. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  21. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  22. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  23. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  24. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  25. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  26. ai_edge_torch/odml_torch/export.py +320 -0
  27. ai_edge_torch/odml_torch/export_utils.py +168 -0
  28. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  29. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  30. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  31. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  32. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  33. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  34. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  35. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  36. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  37. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  38. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  39. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  40. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  41. ai_edge_torch/version.py +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
  44. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/lowerings/_batch_norm.py
@@ -0,0 +1,65 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Provides lowering for core aten to MLIR StableHLO op: batch norm."""
+
+ from typing import Optional
+
+ from ai_edge_torch.odml_torch.lowerings import utils
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+
+ from .registry import lower
+
+
+ # _native_batch_norm_legit_no_training(
+ #     Tensor input,
+ #     Tensor? weight,
+ #     Tensor? bias,
+ #     Tensor running_mean,
+ #     Tensor running_var,
+ #     float momentum,
+ #     float eps) -> (Tensor, Tensor, Tensor)
+ @lower(torch.ops.aten._native_batch_norm_legit_no_training)
+ def _native_batch_norm_legit_no_training(
+     lctx,
+     input_tensor: ir.Value,
+     weight: Optional[ir.Value],
+     bias: Optional[ir.Value],
+     running_mean: ir.Value,
+     running_var: ir.Value,
+     momentum: float,
+     eps: float,
+ ):
+   if weight is None:
+     weight = utils.splat(
+         1, running_mean.type.element_type, running_mean.type.shape
+     )
+   if bias is None:
+     bias = utils.splat(
+         0, running_mean.type.element_type, running_mean.type.shape
+     )
+
+   return [
+       stablehlo.batch_norm_inference(
+           input_tensor, weight, bias, running_mean, running_var, eps, 1
+       ),
+       utils.splat(
+           0, input_tensor.type.element_type
+       ),  # TODO: return empty array instead
+       utils.splat(
+           0, input_tensor.type.element_type
+       ),  # TODO: return empty array instead
+   ]
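The hunk above registers a lowering that maps torch.ops.aten._native_batch_norm_legit_no_training onto stablehlo.batch_norm_inference. As a point of reference, here is a minimal sketch (not part of the package, assuming a recent PyTorch with torch.export available) of how that aten op typically shows up in an exported eval-mode graph:

# Hedged sketch, not part of this release: producing a graph that contains
# aten._native_batch_norm_legit_no_training, the op the new lowering targets.
import torch


class TinyModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.bn = torch.nn.BatchNorm2d(4)

  def forward(self, x):
    return self.bn(x)


model = TinyModel().eval()
exported = torch.export.export(model, (torch.randn(1, 4, 8, 8),))
# In eval mode the batch norm typically appears as
# torch.ops.aten._native_batch_norm_legit_no_training in the exported graph,
# which odml_torch now lowers to stablehlo.batch_norm_inference.
print(exported.graph_module.code)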
ai_edge_torch/odml_torch/lowerings/_convolution.py
@@ -0,0 +1,119 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Provides lowering for core aten to MLIR StableHLO op: convolution."""
+
+ import math
+ from typing import Optional
+
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+
+ from .registry import lower
+
+
+ # convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
+ #     SymInt[] padding, SymInt[] dilation, bool transposed,
+ #     SymInt[] output_padding, SymInt groups) -> Tensor
+ # @lower(torch.ops.aten.convolution)
+ def _aten_convolution(
+     lctx,
+     lhs: ir.Value,
+     rhs: ir.Value,
+     bias: Optional[ir.Value],
+     stride: list[int],
+     padding: list[int],
+     dilation: list[int],
+     transposed: bool,
+     output_padding: list[int],
+     groups: int,
+ ):
+   if transposed:
+     raise NotImplementedError("Transposed convolution is not implemented.")
+
+   if bias is not None:
+     raise NotImplementedError("Bias on convolution is not implemented.")
+
+   # Stablehlo allows start and end padding for each dimension while aten only
+   # allows symmetric padding and so only has one number per dimension.
+   def make_padding(padding):
+     return tuple((p, p) for p in padding)
+
+   def create_conv_dimension_numbers():
+     num_spatial_dims = len(lhs.type.shape) - 2
+     spatial_dimensions = []
+     for i in range(0, num_spatial_dims):
+       spatial_dimensions.append(i + 2)
+
+     dimension_numbers = stablehlo.ConvDimensionNumbers.get(
+         input_batch_dimension=0,
+         input_feature_dimension=1,
+         input_spatial_dimensions=spatial_dimensions,
+         kernel_input_feature_dimension=1,
+         kernel_output_feature_dimension=0,
+         kernel_spatial_dimensions=spatial_dimensions,
+         output_batch_dimension=0,
+         output_feature_dimension=1,
+         output_spatial_dimensions=spatial_dimensions,
+     )
+     return dimension_numbers
+
+   def infer_output_shape():
+     lhs_type: ir.RankedTensorType = lhs.type
+     lhs_shape: list[int] = lhs_type.shape
+     rhs_shape: list[int] = rhs.type.shape
+
+     # Input layout is: (N)CHW and Kernel layout is: (O)IHW
+     output_shape = [lhs_shape[0], rhs_shape[0]]
+     num_spatial_dims = len(lhs.type.shape) - 2
+
+     # looping over the spatial dims (skipping the first 2 dims which are
+     # batch and features)
+     for spatial_dim in range(0, num_spatial_dims):
+       dim_size = lhs_shape[spatial_dim + 2]
+       kernel_dim_size = rhs_shape[spatial_dim + 2]
+
+       # for example, a dilation of 2 doubles the dimension size
+       dim_size *= dilation[spatial_dim]
+
+       # padding added to both sides
+       dim_size += 2 * padding[spatial_dim]
+
+       output_dim_size = math.ceil(
+           (dim_size - kernel_dim_size + 1) / stride[spatial_dim]
+       )
+
+       output_shape.append(output_dim_size)
+
+     return output_shape
+
+   lhs_type: ir.RankedTensorType = lhs.type
+
+   op = stablehlo.ConvolutionOp(
+       result=ir.RankedTensorType.get(
+           infer_output_shape(), lhs_type.element_type
+       ),
+       lhs=lhs,
+       rhs=rhs,
+       dimension_numbers=create_conv_dimension_numbers(),
+       feature_group_count=groups,
+       batch_group_count=1,
+       window_strides=stride,
+       padding=make_padding(padding),
+       lhs_dilation=(1,) * len(stride),
+       rhs_dilation=dilation,
+   )
+
+   return op.result
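In the lowering above, infer_output_shape derives each spatial output size as ceil((in * dilation + 2 * padding - kernel + 1) / stride). A minimal standalone sketch (not part of the package) checking that this arithmetic matches PyTorch's conv2d output shape for dilation 1, the case the formula reproduces exactly:

# Hedged sketch, not part of this release: a standalone check of the spatial
# output-size arithmetic used by infer_output_shape above, for dilation == 1.
import math

import torch


def infer_spatial_dim(in_size, kernel_size, stride, padding, dilation):
  size = in_size * dilation + 2 * padding
  return math.ceil((size - kernel_size + 1) / stride)


x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 3, 3)
out = torch.nn.functional.conv2d(x, w, stride=2, padding=1)
# Both evaluate to 16 for this configuration.
assert out.shape[2] == infer_spatial_dim(32, 3, 2, 1, 1)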
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -0,0 +1,255 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import functools
+ import logging
+
+ from ai_edge_torch.odml_torch import jax_bridge
+ import torch
+ import torch_xla2.ops.jaten  # Import to load torch_xla2 ops
+ import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
+
+ from . import registry
+
+
+ @functools.cache
+ def _log_usage(op):
+   logging.warning("Use jax lowering: %s", str(op))
+
+
+ def lower_by_jax(op, ir_input_names=None):
+   def inner(lowering):
+     bridged = jax_bridge.wrap(lowering, ir_input_names)
+
+     @registry.lower(op)
+     def _jax_lowering(lctx, *args, **kwargs):
+       _log_usage(op)
+       return bridged(lctx, *args, **kwargs)
+
+     return lowering
+
+   return inner
+
+
+ _TORCH_XLA2_IMPLS = {
+     key: val.func
+     for key, val in torch_xla2.ops.ops_registry.all_aten_ops.items()
+     if val.is_jax_function
+ }
+
+
+ def lower_by_torch_xla2(op):
+   return lower_by_jax(op)(_TORCH_XLA2_IMPLS[op])
+
+
+ lower_by_torch_xla2(torch.ops.aten._adaptive_avg_pool2d)
+ lower_by_torch_xla2(torch.ops.aten._adaptive_avg_pool3d)
+ lower_by_torch_xla2(torch.ops.aten._cdist_forward)
+ lower_by_torch_xla2(torch.ops.aten._local_scalar_dense)
+ lower_by_torch_xla2(torch.ops.aten._local_scalar_dense)
+ lower_by_torch_xla2(torch.ops.aten._log_softmax)
+ lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
+ lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
+ lower_by_torch_xla2(torch.ops.aten._pdist_forward)
+ lower_by_torch_xla2(torch.ops.aten._softmax)
+ lower_by_torch_xla2(torch.ops.aten._to_copy)
+ lower_by_torch_xla2(torch.ops.aten._unsafe_index)
+ lower_by_torch_xla2(torch.ops.aten._unsafe_view)
+ lower_by_torch_xla2(torch.ops.aten.abs)
+ lower_by_torch_xla2(torch.ops.aten.acos)
+ lower_by_torch_xla2(torch.ops.aten.acosh)
+ lower_by_torch_xla2(torch.ops.aten.add.Scalar)
+ lower_by_torch_xla2(torch.ops.aten.add.Tensor)
+ lower_by_torch_xla2(torch.ops.aten.addbmm.default)
+ lower_by_torch_xla2(torch.ops.aten.addmm)
+ lower_by_torch_xla2(torch.ops.aten.addmv)
+ lower_by_torch_xla2(torch.ops.aten.alias)
+ lower_by_torch_xla2(torch.ops.aten.allclose)
+ lower_by_torch_xla2(torch.ops.aten.amax)
+ lower_by_torch_xla2(torch.ops.aten.amin)
+ lower_by_torch_xla2(torch.ops.aten.any)
+ lower_by_torch_xla2(torch.ops.aten.arange.default)
+ lower_by_torch_xla2(torch.ops.aten.arange.start)
+ lower_by_torch_xla2(torch.ops.aten.arange.start_step)
+ lower_by_torch_xla2(torch.ops.aten.argmax)
+ lower_by_torch_xla2(torch.ops.aten.argmin)
+ lower_by_torch_xla2(torch.ops.aten.as_strided)
+ lower_by_torch_xla2(torch.ops.aten.as_strided_copy)
+ lower_by_torch_xla2(torch.ops.aten.asin)
+ lower_by_torch_xla2(torch.ops.aten.asinh)
+ lower_by_torch_xla2(torch.ops.aten.atan)
+ lower_by_torch_xla2(torch.ops.aten.atan2)
+ lower_by_torch_xla2(torch.ops.aten.atanh)
+ lower_by_torch_xla2(torch.ops.aten.avg_pool2d)
+ lower_by_torch_xla2(torch.ops.aten.avg_pool3d)
+ lower_by_torch_xla2(torch.ops.aten.bitwise_and)
+ lower_by_torch_xla2(torch.ops.aten.bitwise_not)
+ lower_by_torch_xla2(torch.ops.aten.bitwise_or)
+ lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
+ lower_by_torch_xla2(torch.ops.aten.bmm)
+ lower_by_torch_xla2(torch.ops.aten.cat)
+ lower_by_torch_xla2(torch.ops.aten.ceil)
+ lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
+ lower_by_torch_xla2(torch.ops.aten.clamp.default)
+ lower_by_torch_xla2(torch.ops.aten.clone)
+ lower_by_torch_xla2(torch.ops.aten.clone.default)
+ lower_by_torch_xla2(torch.ops.aten.constant_pad_nd)
+ lower_by_torch_xla2(torch.ops.aten.convolution)
+ lower_by_torch_xla2(torch.ops.aten.cos)
+ lower_by_torch_xla2(torch.ops.aten.cosh)
+ lower_by_torch_xla2(torch.ops.aten.cumsum)
+ lower_by_torch_xla2(torch.ops.aten.detach)
+ lower_by_torch_xla2(torch.ops.aten.diagonal)
+ lower_by_torch_xla2(torch.ops.aten.div)
+ lower_by_torch_xla2(torch.ops.aten.dot)
+ lower_by_torch_xla2(torch.ops.aten.embedding)
+ lower_by_torch_xla2(torch.ops.aten.empty)
+ lower_by_torch_xla2(torch.ops.aten.eq)
+ lower_by_torch_xla2(torch.ops.aten.erf)
+ lower_by_torch_xla2(torch.ops.aten.exp)
+ lower_by_torch_xla2(torch.ops.aten.expand)
+ lower_by_torch_xla2(torch.ops.aten.expand_copy)
+ lower_by_torch_xla2(torch.ops.aten.expm1)
+ lower_by_torch_xla2(torch.ops.aten.fill)
+ lower_by_torch_xla2(torch.ops.aten.flip)
+ lower_by_torch_xla2(torch.ops.aten.floor)
+ lower_by_torch_xla2(torch.ops.aten.fmod)
+ lower_by_torch_xla2(torch.ops.aten.full)
+ lower_by_torch_xla2(torch.ops.aten.full_like)
+ lower_by_torch_xla2(torch.ops.aten.gather)
+ lower_by_torch_xla2(torch.ops.aten.ge)
+ lower_by_torch_xla2(torch.ops.aten.gelu)
+ lower_by_torch_xla2(torch.ops.aten.glu)
+ lower_by_torch_xla2(torch.ops.aten.glu.default)
+ lower_by_torch_xla2(torch.ops.aten.gt)
+ lower_by_torch_xla2(torch.ops.aten.hardtanh)
+ lower_by_torch_xla2(torch.ops.aten.index)
+ lower_by_torch_xla2(torch.ops.aten.index.Tensor)
+ lower_by_torch_xla2(torch.ops.aten.index_copy)
+ lower_by_torch_xla2(torch.ops.aten.index_put)
+ lower_by_torch_xla2(torch.ops.aten.index_select)
+ lower_by_torch_xla2(torch.ops.aten.isinf)
+ lower_by_torch_xla2(torch.ops.aten.isnan)
+ lower_by_torch_xla2(torch.ops.aten.le)
+ lower_by_torch_xla2(torch.ops.aten.leaky_relu)
+ lower_by_torch_xla2(torch.ops.aten.lift_fresh_copy)
+ lower_by_torch_xla2(torch.ops.aten.linalg_vector_norm)
+ lower_by_torch_xla2(torch.ops.aten.log)
+ lower_by_torch_xla2(torch.ops.aten.log10)
+ lower_by_torch_xla2(torch.ops.aten.log1p)
+ lower_by_torch_xla2(torch.ops.aten.log2)
+ lower_by_torch_xla2(torch.ops.aten.logical_and)
+ lower_by_torch_xla2(torch.ops.aten.logical_not)
+ lower_by_torch_xla2(torch.ops.aten.logical_or)
+ lower_by_torch_xla2(torch.ops.aten.logical_xor)
+ lower_by_torch_xla2(torch.ops.aten.lt)
+ lower_by_torch_xla2(torch.ops.aten.max)
+ lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
+ lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
+ lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
+ lower_by_torch_xla2(torch.ops.aten.max_pool3d_with_indices)
+ lower_by_torch_xla2(torch.ops.aten.maximum)
+ lower_by_torch_xla2(torch.ops.aten.mean)
+ lower_by_torch_xla2(torch.ops.aten.min)
+ lower_by_torch_xla2(torch.ops.aten.minimum)
+ lower_by_torch_xla2(torch.ops.aten.mm)
+ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
+ lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
+ lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
+ lower_by_torch_xla2(torch.ops.aten.native_group_norm)
+ lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
+ lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
+ lower_by_torch_xla2(torch.ops.aten.ne)
+ lower_by_torch_xla2(torch.ops.aten.neg)
+ lower_by_torch_xla2(torch.ops.aten.nonzero)
+ lower_by_torch_xla2(torch.ops.aten.outer)
+ lower_by_torch_xla2(torch.ops.aten.permute)
+ lower_by_torch_xla2(torch.ops.aten.permute_copy)
+ lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
+ lower_by_torch_xla2(torch.ops.aten.pow)
+ lower_by_torch_xla2(torch.ops.aten.prod)
+ lower_by_torch_xla2(torch.ops.aten.rand)
+ lower_by_torch_xla2(torch.ops.aten.randn)
+ lower_by_torch_xla2(torch.ops.aten.reciprocal)
+ lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
+ lower_by_torch_xla2(torch.ops.aten.relu)
+ lower_by_torch_xla2(torch.ops.aten.remainder)
+ lower_by_torch_xla2(torch.ops.aten.repeat)
+ lower_by_torch_xla2(torch.ops.aten.reshape)
+ lower_by_torch_xla2(torch.ops.aten.roll)
+ lower_by_torch_xla2(torch.ops.aten.round)
+ lower_by_torch_xla2(torch.ops.aten.rsqrt)
+ lower_by_torch_xla2(torch.ops.aten.scalar_tensor)
+ lower_by_torch_xla2(torch.ops.aten.scatter.src)
+ lower_by_torch_xla2(torch.ops.aten.scatter.value)
+ lower_by_torch_xla2(torch.ops.aten.scatter_add)
+ lower_by_torch_xla2(torch.ops.aten.scatter_reduce)
+ lower_by_torch_xla2(torch.ops.aten.select)
+ lower_by_torch_xla2(torch.ops.aten.select_copy)
+ lower_by_torch_xla2(torch.ops.aten.select_scatter)
+ lower_by_torch_xla2(torch.ops.aten.sigmoid)
+ lower_by_torch_xla2(torch.ops.aten.sign)
+ lower_by_torch_xla2(torch.ops.aten.silu)
+ lower_by_torch_xla2(torch.ops.aten.sin)
+ lower_by_torch_xla2(torch.ops.aten.sinh)
+ lower_by_torch_xla2(torch.ops.aten.slice)
+ lower_by_torch_xla2(torch.ops.aten.slice_copy)
+ lower_by_torch_xla2(torch.ops.aten.slice_scatter)
+ lower_by_torch_xla2(torch.ops.aten.sort)
+ lower_by_torch_xla2(torch.ops.aten.split)
+ lower_by_torch_xla2(torch.ops.aten.split_copy)
+ lower_by_torch_xla2(torch.ops.aten.split_with_sizes)
+ lower_by_torch_xla2(torch.ops.aten.sqrt)
+ lower_by_torch_xla2(torch.ops.aten.squeeze)
+ lower_by_torch_xla2(torch.ops.aten.squeeze_copy)
+ lower_by_torch_xla2(torch.ops.aten.stack)
+ lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
+ lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
+ lower_by_torch_xla2(torch.ops.aten.sum)
+ lower_by_torch_xla2(torch.ops.aten.sym_size)
+ lower_by_torch_xla2(torch.ops.aten.t)
+ lower_by_torch_xla2(torch.ops.aten.tan)
+ lower_by_torch_xla2(torch.ops.aten.tanh)
+ lower_by_torch_xla2(torch.ops.aten.tensor_split.sections)
+ lower_by_torch_xla2(torch.ops.aten.tensor_split.sections)
+ lower_by_torch_xla2(torch.ops.aten.to.device)
+ lower_by_torch_xla2(torch.ops.aten.to.device)
+ lower_by_torch_xla2(torch.ops.aten.to.dtype)
+ lower_by_torch_xla2(torch.ops.aten.topk)
+ lower_by_torch_xla2(torch.ops.aten.transpose)
+ lower_by_torch_xla2(torch.ops.aten.transpose_copy)
+ lower_by_torch_xla2(torch.ops.aten.triu)
+ lower_by_torch_xla2(torch.ops.aten.true_divide)
+ lower_by_torch_xla2(torch.ops.aten.trunc)
+ lower_by_torch_xla2(torch.ops.aten.unbind)
+ lower_by_torch_xla2(torch.ops.aten.unbind_copy)
+ lower_by_torch_xla2(torch.ops.aten.unsqueeze)
+ lower_by_torch_xla2(torch.ops.aten.unsqueeze.default)
+ lower_by_torch_xla2(torch.ops.aten.unsqueeze_copy)
+ lower_by_torch_xla2(torch.ops.aten.var.correction)
+ lower_by_torch_xla2(torch.ops.aten.var_mean.correction)
+ lower_by_torch_xla2(torch.ops.aten.view)
+ lower_by_torch_xla2(torch.ops.aten.view_as_complex)
+ lower_by_torch_xla2(torch.ops.aten.view_as_real)
+ lower_by_torch_xla2(torch.ops.aten.view_copy)
+ lower_by_torch_xla2(torch.ops.aten.where.ScalarOther)
+ lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf)
+ lower_by_torch_xla2(torch.ops.aten.where.self)
+ lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
+ lower_by_torch_xla2(torch.ops.prims.var)
+
+
+ @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
+ def _aten_copy(self, src, **kwargs):
+   return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
ai_edge_torch/odml_torch/lowerings/context.py
@@ -0,0 +1,42 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Define context object for export and MLIR lowerings."""
+
+ import dataclasses
+ from jax._src.lib.mlir import ir
+ import torch
+
+
+ @dataclasses.dataclass
+ class LoweringContext:
+   """The context object used in export interpreter and MLIR lowerings."""
+
+   ir_context: ir.Context
+   ir_module: ir.Module
+   ir_location: ir.Location = None
+   node: torch.fx.Node = None
+
+   @property
+   def ctx(self):
+     """Shortcut for ir_context."""
+     return self.ir_context
+
+   @property
+   def loc(self):
+     """Shortcut for ir_location."""
+     return self.ir_location
+
+   def replace(self, **kwargs):
+     return dataclasses.replace(self, **kwargs)
ai_edge_torch/odml_torch/lowerings/registry.py
@@ -0,0 +1,87 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Torch op decompositions and MLIR lowerings registry."""
+
+ from typing import Any, Callable
+
+ import torch
+
+ from . import context
+
+
+ class LoweringRegistry:
+   """Registry object for torch op decompositions and to-MLIR lowerings."""
+
+   def __init__(self):
+     self.registered_ops = {}
+     self.decompositions = {}
+
+   def lookup(self, op_or_name):
+     candidate = self._get_lowering(op_or_name)
+     if candidate is None:
+       if isinstance(op_or_name, torch._ops.OpOverloadPacket):
+         candidate = self._get_lowering(op_or_name.default)
+       if isinstance(op_or_name, torch._ops.OpOverload):
+         candidate = self._get_lowering(op_or_name.overloadpacket)
+     return candidate
+
+   def _get_lowering(self, op):
+     candidate = self.registered_ops.get(op)
+     return candidate
+
+   def register(self, op, lowering):
+     if isinstance(op, torch._ops.OpOverloadPacket):
+       ops = [getattr(op, overload) for overload in op.overloads()]
+     else:
+       ops = [op]
+
+     for op in ops:
+       self.registered_ops[op] = lowering
+
+
+ global_registry = LoweringRegistry()
+ global_registry.decompositions.update(
+     torch._decomp.get_decompositions([
+         torch.ops.aten.upsample_nearest2d,
+         torch.ops.aten._native_batch_norm_legit.no_stats,
+         torch.ops.aten._adaptive_avg_pool2d,
+         torch.ops.aten._adaptive_avg_pool3d,
+         torch.ops.aten.grid_sampler_2d,
+         torch.ops.aten.native_dropout,
+         torch.ops.aten.reflection_pad1d,
+         torch.ops.aten.reflection_pad2d,
+         torch.ops.aten.reflection_pad3d,
+         torch.ops.aten.replication_pad1d,
+         torch.ops.aten.replication_pad2d,
+         torch.ops.aten.replication_pad3d,
+         torch.ops.aten.addmm,
+     ])
+ )
+
+
+ def lookup(op):
+   return global_registry.lookup(op)
+
+
+ def lower(op):
+   def inner(lowering: Callable[[context.LoweringContext, ...], Any]):
+     global_registry.register(op, lowering)
+     return lowering
+
+   return inner
+
+
+ def decompositions():
+   return global_registry.decompositions
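The registry above exposes lower as a decorator that registers a lowering for an op (expanding an OpOverloadPacket into all of its overloads) and lookup to retrieve it later. A minimal usage sketch (not part of the package; the op choice and lowering body are illustrative only):

# Hedged sketch, not part of this release: registering and looking up a
# lowering through the new registry module.
import torch

from ai_edge_torch.odml_torch.lowerings import registry


@registry.lower(torch.ops.aten.hardswish)
def _hardswish_lowering(lctx, x):
  ...  # a real lowering would emit StableHLO ops via the MLIR builders


# Registering an OpOverloadPacket registers every overload, so lookups by a
# concrete overload resolve to the same lowering.
assert registry.lookup(torch.ops.aten.hardswish.default) is _hardswish_lowering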