ai-edge-torch-nightly 0.3.0.dev20240828__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.


Files changed (45)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +1 -1
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  9. ai_edge_torch/generative/test/test_loader.py +1 -1
  10. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  11. ai_edge_torch/generative/test/test_quantize.py +1 -1
  12. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  13. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  14. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  15. ai_edge_torch/lowertools/test_utils.py +1 -1
  16. ai_edge_torch/odml_torch/__init__.py +20 -0
  17. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  18. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  19. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  20. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  21. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  22. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  23. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  24. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  25. ai_edge_torch/odml_torch/export.py +320 -0
  26. ai_edge_torch/odml_torch/export_utils.py +168 -0
  27. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  28. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  29. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  31. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  32. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  33. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  34. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  35. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  36. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  37. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  38. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  39. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  40. ai_edge_torch/version.py +1 -1
  41. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +45 -21
  43. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  44. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/jax_bridge/utils.py
@@ -0,0 +1,75 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Utilities for Jax bridge."""
+
+ from ai_edge_torch import odml_torch
+ import jax
+ import jax.numpy as jnp
+ from jax._src.lib.mlir import ir
+ import torch
+
+
+ def t2j_dtype(dtype):
+   return {
+       torch.bfloat16: jnp.bfloat16,
+       torch.half: jnp.float16,
+       torch.float32: jnp.float32,
+       torch.double: jnp.double,
+       torch.long: jnp.int64,
+       torch.int64: jnp.int64,
+       torch.int32: jnp.int32,
+       torch.int16: jnp.int16,
+       torch.int8: jnp.int8,
+       torch.uint8: jnp.uint8,
+       torch.bool: jnp.bool_,
+       torch.complex64: jnp.complex64,
+       torch.complex128: jnp.complex128,
+   }.get(dtype)
+
+
+ def is_ir_variable(value):
+   if isinstance(value, ir.Value):
+     return True
+   if isinstance(value, (list, tuple)):
+     return any(is_ir_variable(x) for x in value)
+   return False
+
+
+ def ir_variable_to_jax(value):
+   if isinstance(value, (list, tuple)):
+     return tuple([ir_variable_to_jax(x) for x in value])
+   elif not isinstance(value, ir.Value):
+     return value
+   elif not isinstance(value.type, ir.RankedTensorType):
+     raise ValueError(
+         f"ir.Value to JAX must be in ir.RankedTensorType, got {value}"
+     )
+
+   return jax.ShapeDtypeStruct(
+       value.type.shape,
+       t2j_dtype(
+           odml_torch.export_utils.ir_element_type_to_torch_dtype(
+               value.type.element_type
+           )
+       ),
+   )
+
+
+ def tree_map_list_to_tuple(value):
+   if isinstance(value, dict):
+     return {k: tree_map_list_to_tuple(v) for k, v in value.items()}
+   if isinstance(value, (list, tuple)):
+     return tuple([tree_map_list_to_tuple(v) for v in value])
+   return value
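
The two pure helpers above are small enough to sanity-check in isolation. Below is a minimal, self-contained sketch (not part of the package) that re-implements them so they run without an MLIR context; the dtype table is abridged and the assertions are illustrative only.

import jax.numpy as jnp
import torch

def t2j_dtype(dtype):
  # Abridged copy of the mapping above; .get() returns None for
  # unmapped dtypes rather than raising.
  return {torch.float32: jnp.float32, torch.int64: jnp.int64}.get(dtype)

def tree_map_list_to_tuple(value):
  # Recursively freeze lists into tuples so nested containers compare
  # and hash consistently.
  if isinstance(value, dict):
    return {k: tree_map_list_to_tuple(v) for k, v in value.items()}
  if isinstance(value, (list, tuple)):
    return tuple(tree_map_list_to_tuple(v) for v in value)
  return value

assert t2j_dtype(torch.float32) is jnp.float32
assert t2j_dtype(torch.float64) is None  # not in the abridged table
assert tree_map_list_to_tuple({"a": [1, [2, 3]]}) == {"a": (1, (2, 3))}
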
ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -0,0 +1,24 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from . import _basic
+ from . import _batch_norm
+ from . import _convolution
+ from . import _jax_lowerings
+ from . import context
+ from . import registry
+ from . import utils
+ from .registry import decompositions
+ from .registry import lookup
+ from .registry import lower
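
registry.py itself is not included in this section of the diff, so the following is a hypothetical sketch of the decorator-registry pattern that `lower` and `lookup` imply (and that the `@lower(torch.ops.aten...)` decorators in the files below use): `lower(op)` returns a decorator that records a lowering function under the op's key, and `lookup(op)` retrieves it. All names and details here are illustrative only.

_lowerings = {}

def lower(op):
  # Decorator factory: register fn as the lowering for `op`.
  def decorator(fn):
    _lowerings[str(op)] = fn
    return fn
  return decorator

def lookup(op):
  return _lowerings.get(str(op))

@lower("aten.relu")
def _relu_lowering(lctx, x):
  return max(x, 0)  # stand-in for emitting a StableHLO op

assert lookup("aten.relu") is _relu_lowering
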
ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -0,0 +1,204 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import math
+ from typing import Optional, Union
+
+ from ai_edge_torch.odml_torch.lowerings import utils
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import numpy as np
+ import torch
+
+ from .registry import lower
+
+
+ # add(Tensor self, Tensor other) -> Tensor
+ # @lower(torch.ops.aten.add)
+ def _aten_add(lctx, x: ir.Value, y: ir.Value, alpha=1):
+   x, y = utils.upcast_to_same_type(x, y)
+   x, y = utils.broadcast_args_if_needed(x, y)
+   if alpha == 1:
+     return stablehlo.add(x, y)
+
+   alpha_splat = utils.splat(alpha, y.type.element_type, y.type.shape)
+   return stablehlo.add(x, stablehlo.multiply(y, alpha_splat))
+
+
+ # mul.Tensor(Tensor self, Tensor other) -> Tensor
+ # @lower(torch.ops.aten.mul.Tensor)
+ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
+   self, other = utils.upcast_to_same_type(self, other)
+   self, other = utils.broadcast_args_if_needed(self, other)
+
+   return stablehlo.multiply(self, other)
+
+
+ # cat(Tensor[] tensors, int dim=0) -> Tensor
+ # @lower(torch.ops.aten.cat)
+ def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 0):
+   return stablehlo.ConcatenateOp(tensors, dim).result
+
+
+ # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+ # @lower(torch.ops.aten.view)
+ def _aten_view(lctx, self: ir.Value, size: list[int]):
+   return stablehlo.ReshapeOp(
+       ir.RankedTensorType.get(size, self.type.element_type), self
+   ).result
+
+
+ # hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+ @lower(torch.ops.aten.hardtanh)
+ def _aten_hardtanh(
+     lctx,
+     self: ir.Value,
+     min_val: Union[int, float] = -1.0,
+     max_val: Union[int, float] = 1.0,
+ ):
+   elty = self.type.element_type
+   min_val = utils.splat(min_val, elty)
+   max_val = utils.splat(max_val, elty)
+
+   return stablehlo.clamp(min_val, self, max_val)
+
+
+ # mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+ # mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *,
+ #     ScalarType? dtype=None) -> Tensor
+ @lower(torch.ops.aten.mean)
+ @lower(torch.ops.aten.mean.dim)
+ def _aten_mean_dim(
+     lctx,
+     self: ir.Value,
+     dim: Optional[list[int]] = None,
+     keepdim: bool = False,
+     *,
+     dtype=None,
+ ):
+   self_shape = self.type.shape
+   self_elty = self.type.element_type
+   if dim is None:
+     dim = list(range(len(self_shape)))
+   dim = [len(self_shape) + d if d < 0 else d for d in dim]
+   dim_ = ir.DenseI64ArrayAttr.get(np.asarray(dim, np.int64))
+   dim_to_keep = [d for d in range(len(self_shape)) if d not in dim]
+   dim_to_keep_ = ir.DenseI64ArrayAttr.get(np.asarray(dim_to_keep, np.int64))
+
+   zero_ = utils.splat(0.0, self_elty)
+
+   reduce_result_shape = [
+       s for d, s in enumerate(self_shape) if d in dim_to_keep
+   ]
+   reduce_result_ty = ir.RankedTensorType.get(reduce_result_shape, self_elty)
+   reduce_op = stablehlo.ReduceOp([reduce_result_ty], [self], [zero_], dim_)
+
+   reducer_arg_ty = ir.RankedTensorType.get(tuple(), self_elty)
+   reducer = reduce_op.regions[0].blocks.append(reducer_arg_ty, reducer_arg_ty)
+   with ir.InsertionPoint(reducer):
+     stablehlo.return_(
+         [stablehlo.add(reducer.arguments[0], reducer.arguments[1])]
+     )
+
+   sum_ = reduce_op.result
+   if keepdim:
+     sum_ = stablehlo.broadcast_in_dim(
+         ir.RankedTensorType.get(
+             [s if d in dim_to_keep else 1 for d, s in enumerate(self_shape)],
+             self_elty,
+         ),
+         sum_,
+         dim_to_keep_,
+     )
+
+   dim_els = math.prod([s for d, s in enumerate(self_shape) if d in dim])
+   dim_els_ = utils.splat(dim_els, self_elty)
+   div_ = stablehlo.broadcast_in_dim(
+       sum_.type, dim_els_, ir.DenseI64ArrayAttr.get([])
+   )
+   mean_ = stablehlo.divide(sum_, div_)
+
+   return mean_
+
+
+ # https://pytorch.org/docs/stable/generated/torch.clone.html
+ # https://github.com/pytorch/pytorch/blob/a95ceb51a23ae33c00b3a99224143c609b1b3eb3/aten/src/ATen/native/TensorFactories.cpp#L1730
+ @lower(torch.ops.aten.clone)
+ def _aten_clone(lctx, x: ir.Value, *, memory_format=None):
+   return x
+
+
+ # https://pytorch.org/docs/stable/generated/torch.permute.html
+ # https://github.com/pytorch/pytorch/blob/519151a062a9bd4f0d32a9c7c7eae47d7ed847b2/aten/src/ATen/native/TensorShape.cpp#L1448
+ # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
+ @lower(torch.ops.aten.permute)
+ def _aten_permute(lctx, x: ir.Value, dims: list[int]):
+   dim = len(x.type.shape)
+   return stablehlo.transpose(x, ir.DenseI64ArrayAttr.get(dims))
+
+
+ # https://pytorch.org/docs/stable/generated/torch.mm.html
+ # https://github.com/pytorch/pytorch/blob/ffabb25c489df1dc631a577c12a0c843c8b202f3/aten/src/ATen/native/LinearAlgebra.cpp#L193
+ # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general
+ @lower(torch.ops.aten.mm)
+ def _aten_mm(mod, mat1: ir.Value, mat2: ir.Value) -> ir.Value:
+   mat1_shape = mat1.type.shape
+   mat2_shape = mat2.type.shape
+   mat1_dims = len(mat1_shape)
+   mat2_dims = len(mat2_shape)
+
+   if mat1_dims != 2 or mat2_dims != 2:
+     raise ValueError(
+         "Both arguments must be 2D matrices, received dimensions %d and %d"
+         % (mat1_dims, mat2_dims)
+     )
+
+   if mat1_shape[1] != mat2_shape[0]:
+     raise ValueError(
+         "mat1 and mat2 shapes cannot be multiplied, received shapes %s and %s"
+         % (mat1_shape, mat2_shape)
+     )
+
+   dot_dnums = stablehlo.DotDimensionNumbers.get(
+       lhs_batching_dimensions=[],
+       rhs_batching_dimensions=[],
+       lhs_contracting_dimensions=(1,),
+       rhs_contracting_dimensions=(0,),
+   )
+   return stablehlo.dot_general(
+       ir.RankedTensorType.get(
+           (mat1.type.shape[0], mat2.type.shape[1]), mat1.type.element_type
+       ),
+       mat1,
+       mat2,
+       dot_dnums,
+   )
+
+
+ # https://pytorch.org/docs/stable/generated/torch.div.html
+ # https://openxla.org/stablehlo/spec#divide
+ # TODO: support rounding mode and type promotion (see torch.div spec).
+ # @lower(torch.ops.aten.div)
+ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
+   # By default, PyTorch performs a "true" division like Python 3. This requires
+   # casting integer input types to float to achieve the same semantics using
+   # stablehlo.divide.
+   if isinstance(x.type.element_type, ir.IntegerType):
+     x = utils.convert_int_to_float(x)
+   if isinstance(y.type.element_type, ir.IntegerType):
+     y = utils.convert_int_to_float(y)
+
+   x, y = utils.broadcast_args_if_needed(x, y)
+
+   return stablehlo.divide(x, y)
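
Two pieces of the arithmetic above can be checked with plain PyTorch. A small sketch (assuming nothing beyond stock PyTorch) showing the negative-dim normalization from `_aten_mean_dim`, the sum-divided-by-count mean that its ReduceOp plus divide emit, and the true-division semantics that `_aten_div` reproduces by casting integer operands to float:

import torch

def normalize_dims(dims, rank):
  # Negative dims wrap around, as in _aten_mean_dim above.
  return [rank + d if d < 0 else d for d in dims]

assert normalize_dims([-1, 0], rank=4) == [3, 0]

# mean = sum over the reduced dims, divided by their element count.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
assert torch.equal(x.sum(dim=[1]) / x.shape[1], x.mean(dim=[1]))

# torch.div defaults to "true" division, hence the int -> float casts
# in _aten_div before stablehlo.divide.
assert torch.div(torch.tensor(3), torch.tensor(2)).item() == 1.5
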
ai_edge_torch/odml_torch/lowerings/_batch_norm.py
@@ -0,0 +1,65 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Provides lowering for core aten to mlir stablehlo op: batch norm."""
+
+ from typing import Optional
+
+ from ai_edge_torch.odml_torch.lowerings import utils
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+
+ from .registry import lower
+
+
+ # _native_batch_norm_legit_no_training(
+ #     Tensor input,
+ #     Tensor? weight,
+ #     Tensor? bias,
+ #     Tensor running_mean,
+ #     Tensor running_var,
+ #     float momentum,
+ #     float eps) -> (Tensor, Tensor, Tensor)
+ @lower(torch.ops.aten._native_batch_norm_legit_no_training)
+ def _native_batch_norm_legit_no_training(
+     lctx,
+     input_tensor: ir.Value,
+     weight: Optional[ir.Value],
+     bias: Optional[ir.Value],
+     running_mean: ir.Value,
+     running_var: ir.Value,
+     momentum: float,
+     eps: float,
+ ):
+   if weight is None:
+     weight = utils.splat(
+         1, running_mean.type.element_type, running_mean.type.shape
+     )
+   if bias is None:
+     bias = utils.splat(
+         0, running_mean.type.element_type, running_mean.type.shape
+     )
+
+   return [
+       stablehlo.batch_norm_inference(
+           input_tensor, weight, bias, running_mean, running_var, eps, 1
+       ),
+       utils.splat(
+           0, input_tensor.type.element_type
+       ),  # TODO: return empty array instead
+       utils.splat(
+           0, input_tensor.type.element_type
+       ),  # TODO: return empty array instead
+   ]
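
The `stablehlo.batch_norm_inference` call above passes `eps, 1`, i.e. epsilon and a feature index of 1 (channels-first). A reference sketch of the same per-channel arithmetic, y = (x - mean) / sqrt(var + eps) * weight + bias, checked against `torch.nn.functional.batch_norm` in inference mode; the helper below is illustrative, not part of the package:

import torch

def batch_norm_inference(x, weight, bias, mean, var, eps):
  # Broadcast the per-channel stats along feature dim 1 (channels-first).
  shape = [1, -1] + [1] * (x.dim() - 2)
  return (
      (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
      * weight.view(shape) + bias.view(shape)
  )

x = torch.randn(2, 3, 4, 4)
w, b = torch.ones(3), torch.zeros(3)
mean, var = torch.zeros(3), torch.ones(3)
expected = torch.nn.functional.batch_norm(x, mean, var, w, b, training=False)
assert torch.allclose(batch_norm_inference(x, w, b, mean, var, 1e-5), expected)
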
ai_edge_torch/odml_torch/lowerings/_convolution.py
@@ -0,0 +1,119 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Provides lowering for core aten to mlir stablehlo op: convolution."""
+
+ import math
+ from typing import Optional
+
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+
+ from .registry import lower
+
+
+ # convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
+ #     SymInt[] padding, SymInt[] dilation, bool transposed,
+ #     SymInt[] output_padding, SymInt groups) -> Tensor
+ # @lower(torch.ops.aten.convolution)
+ def _aten_convolution(
+     lctx,
+     lhs: ir.Value,
+     rhs: ir.Value,
+     bias: Optional[ir.Value],
+     stride: list[int],
+     padding: list[int],
+     dilation: list[int],
+     transposed: bool,
+     output_padding: list[int],
+     groups: int,
+ ):
+   if transposed:
+     raise NotImplementedError("Transposed convolution is not implemented.")
+
+   if bias is not None:
+     raise NotImplementedError("Bias on convolution is not implemented.")
+
+   # StableHLO allows separate start and end padding for each dimension, while
+   # aten only allows symmetric padding and so has one number per dimension.
+   def make_padding(padding):
+     return tuple((p, p) for p in padding)
+
+   def create_conv_dimension_numbers():
+     num_spatial_dims = len(lhs.type.shape) - 2
+     spatial_dimensions = []
+     for i in range(0, num_spatial_dims):
+       spatial_dimensions.append(i + 2)
+
+     dimension_numbers = stablehlo.ConvDimensionNumbers.get(
+         input_batch_dimension=0,
+         input_feature_dimension=1,
+         input_spatial_dimensions=spatial_dimensions,
+         kernel_input_feature_dimension=1,
+         kernel_output_feature_dimension=0,
+         kernel_spatial_dimensions=spatial_dimensions,
+         output_batch_dimension=0,
+         output_feature_dimension=1,
+         output_spatial_dimensions=spatial_dimensions,
+     )
+     return dimension_numbers
+
+   def infer_output_shape():
+     lhs_type: ir.RankedTensorType = lhs.type
+     lhs_shape: list[int] = lhs_type.shape
+     rhs_shape: list[int] = rhs.type.shape
+
+     # Input layout is (N)CHW and kernel layout is (O)IHW.
+     output_shape = [lhs_shape[0], rhs_shape[0]]
+     num_spatial_dims = len(lhs.type.shape) - 2
+
+     # Loop over the spatial dims (skipping the first 2 dims, which are
+     # batch and features).
+     for spatial_dim in range(0, num_spatial_dims):
+       dim_size = lhs_shape[spatial_dim + 2]
+       kernel_dim_size = rhs_shape[spatial_dim + 2]
+
+       # Dilation expands the effective kernel to d*(k-1)+1 taps.
+       kernel_dim_size = (kernel_dim_size - 1) * dilation[spatial_dim] + 1
+
+       # Padding is added to both sides.
+       dim_size += 2 * padding[spatial_dim]
+
+       output_dim_size = math.ceil(
+           (dim_size - kernel_dim_size + 1) / stride[spatial_dim]
+       )
+
+       output_shape.append(output_dim_size)
+
+     return output_shape
+
+   lhs_type: ir.RankedTensorType = lhs.type
+
+   op = stablehlo.ConvolutionOp(
+       result=ir.RankedTensorType.get(
+           infer_output_shape(), lhs_type.element_type
+       ),
+       lhs=lhs,
+       rhs=rhs,
+       dimension_numbers=create_conv_dimension_numbers(),
+       feature_group_count=groups,
+       batch_group_count=1,
+       window_strides=stride,
+       padding=make_padding(padding),
+       lhs_dilation=(1,) * len(stride),
+       rhs_dilation=dilation,
+   )
+
+   return op.result
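
The shape inference above can be checked against PyTorch. A small sketch of the same output-size arithmetic, where the effective kernel size accounts for dilation; the helper name and sample configuration are illustrative:

import math
import torch

def conv_out_size(in_size, kernel, stride, padding, dilation):
  # Dilation expands the effective kernel to d*(k-1)+1 taps.
  effective_kernel = (kernel - 1) * dilation + 1
  return math.ceil((in_size + 2 * padding - effective_kernel + 1) / stride)

x = torch.randn(1, 3, 17, 17)
conv = torch.nn.Conv2d(
    3, 8, kernel_size=3, stride=2, padding=1, dilation=2, bias=False
)
s = conv_out_size(17, 3, 2, 1, 2)
assert conv(x).shape == (1, 8, s, s)
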