ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  9. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  11. ai_edge_torch/generative/layers/attention.py +60 -63
  12. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  13. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  14. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  16. ai_edge_torch/generative/test/utils.py +54 -0
  17. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  18. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
  22. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  24. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  25. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  26. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  29. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  30. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  31. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  32. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  33. ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
20
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
23
21
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
22
+ from ai_edge_torch.generative.layers import kv_cache
23
+ from ai_edge_torch.generative.test import utils as test_utils
25
24
  import numpy as np
26
25
  import torch
27
26
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
49
48
  )
50
49
  def test_toy_model_with_kv_cache(self):
51
50
  config = toy_model_with_kv_cache.get_model_config()
52
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
53
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
51
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
52
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
54
53
  [10], dtype=torch.int64
55
54
  )
56
-
57
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
55
+ kv = kv_cache.KVCache.from_model_config(config)
56
+
57
+ edge_model = ai_edge_torch.convert(
58
+ pytorch_model,
59
+ sample_kwargs={
60
+ "tokens": tokens,
61
+ "input_pos": input_pos,
62
+ "kv_cache": kv,
63
+ },
64
+ )
58
65
  edge_model.set_interpreter_builder(
59
66
  self._interpreter_builder(edge_model.tflite_model())
60
67
  )
61
68
 
62
69
  self.assertTrue(
63
- model_coverage.compare_tflite_torch(
70
+ test_utils.compare_tflite_torch(
64
71
  edge_model,
65
72
  pytorch_model,
66
- (idx, input_pos),
67
- num_valid_inputs=1,
73
+ tokens,
74
+ input_pos,
75
+ kv,
76
+ signature_name="serving_default",
68
77
  atol=1e-5,
69
78
  rtol=1e-5,
70
79
  )
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
77
86
  def test_toy_model_with_kv_cache_with_hlfb(self):
78
87
  config = toy_model_with_kv_cache.get_model_config()
79
88
  config.enable_hlfb = True
80
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
81
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
89
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
90
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
82
91
  [10], dtype=torch.int64
83
92
  )
84
-
85
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
93
+ kv = kv_cache.KVCache.from_model_config(config)
94
+
95
+ edge_model = ai_edge_torch.convert(
96
+ pytorch_model,
97
+ sample_kwargs={
98
+ "tokens": tokens,
99
+ "input_pos": input_pos,
100
+ "kv_cache": kv,
101
+ },
102
+ )
86
103
  edge_model.set_interpreter_builder(
87
104
  self._interpreter_builder(edge_model.tflite_model())
88
105
  )
89
106
 
90
107
  self.assertTrue(
91
- model_coverage.compare_tflite_torch(
108
+ test_utils.compare_tflite_torch(
92
109
  edge_model,
93
110
  pytorch_model,
94
- (idx, input_pos),
95
- num_valid_inputs=1,
111
+ tokens,
112
+ input_pos,
113
+ kv,
114
+ signature_name="serving_default",
96
115
  atol=1e-5,
97
116
  rtol=1e-5,
98
117
  )
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
117
136
  decode_token = torch.tensor([[1]], dtype=torch.long)
118
137
  decode_input_pos = torch.tensor([5], dtype=torch.int64)
119
138
 
139
+ kv = kv_cache.KVCache.from_model_config(config)
140
+
120
141
  edge_model = (
121
142
  ai_edge_torch.signature(
122
- "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
143
+ "prefill",
144
+ pytorch_model,
145
+ sample_kwargs={
146
+ "tokens": prefill_tokens,
147
+ "input_pos": prefill_input_pos,
148
+ "kv_cache": kv,
149
+ },
150
+ )
151
+ .signature(
152
+ "decode",
153
+ pytorch_model,
154
+ sample_kwargs={
155
+ "tokens": decode_token,
156
+ "input_pos": decode_input_pos,
157
+ "kv_cache": kv,
158
+ },
123
159
  )
124
- .signature("decode", pytorch_model, (decode_token, decode_input_pos))
125
160
  .convert()
126
161
  )
127
162
  edge_model.set_interpreter_builder(
128
163
  self._interpreter_builder(edge_model.tflite_model())
129
164
  )
130
165
 
131
- copied_model = copy.deepcopy(pytorch_model)
132
- copied_edge = copy.deepcopy(edge_model)
133
-
134
166
  self.assertTrue(
135
- model_coverage.compare_tflite_torch(
167
+ test_utils.compare_tflite_torch(
136
168
  edge_model,
137
169
  pytorch_model,
138
- (prefill_tokens, prefill_input_pos),
170
+ prefill_tokens,
171
+ prefill_input_pos,
172
+ kv,
139
173
  signature_name="prefill",
140
- num_valid_inputs=1,
174
+ atol=1e-5,
175
+ rtol=1e-5,
141
176
  )
142
177
  )
143
178
 
144
179
  self.assertTrue(
145
- model_coverage.compare_tflite_torch(
146
- copied_edge,
147
- copied_model,
148
- (decode_token, decode_input_pos),
180
+ test_utils.compare_tflite_torch(
181
+ edge_model,
182
+ pytorch_model,
183
+ decode_token,
184
+ decode_input_pos,
185
+ kv,
149
186
  signature_name="decode",
150
- num_valid_inputs=1,
187
+ atol=1e-5,
188
+ rtol=1e-5,
151
189
  )
152
190
  )
153
191
 
@@ -12,16 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
23
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
20
+ from ai_edge_torch.generative.examples.gemma import gemma
21
+ from ai_edge_torch.generative.examples.gemma import gemma2
22
+ from ai_edge_torch.generative.examples.phi import phi2
23
+ from ai_edge_torch.generative.layers import kv_cache
24
+ from ai_edge_torch.generative.test import utils as test_utils
25
25
  import numpy as np
26
26
  import torch
27
27
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
55
55
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
56
56
  tokens[0, :4] = idx
57
57
  input_pos = torch.arange(0, 10)
58
-
59
- edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
58
+ kv = kv_cache.KVCache.from_model_config(config)
59
+
60
+ edge_model = ai_edge_torch.convert(
61
+ model,
62
+ sample_kwargs={
63
+ "tokens": tokens,
64
+ "input_pos": input_pos,
65
+ "kv_cache": kv,
66
+ },
67
+ )
60
68
  edge_model.set_interpreter_builder(
61
69
  self._interpreter_builder(edge_model.tflite_model())
62
70
  )
63
71
 
64
72
  self.assertTrue(
65
- model_coverage.compare_tflite_torch(
73
+ test_utils.compare_tflite_torch(
66
74
  edge_model,
67
75
  model,
68
- (tokens, input_pos),
69
- num_valid_inputs=1,
76
+ tokens,
77
+ input_pos,
78
+ kv,
79
+ signature_name="serving_default",
70
80
  atol=1e-2,
71
81
  rtol=1e-5,
72
82
  )
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
85
95
  prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
96
  prefill_tokens[0, :4] = idx
87
97
  prefill_input_pos = torch.arange(0, 10)
98
+ kv = kv_cache.KVCache.from_model_config(config)
88
99
 
89
100
  edge_model = ai_edge_torch.signature(
90
- "prefill", model, (prefill_tokens, prefill_input_pos)
101
+ "prefill",
102
+ model,
103
+ sample_kwargs={
104
+ "tokens": prefill_tokens,
105
+ "input_pos": prefill_input_pos,
106
+ "kv_cache": kv,
107
+ },
91
108
  ).convert()
92
109
  edge_model.set_interpreter_builder(
93
110
  self._interpreter_builder(edge_model.tflite_model())
94
111
  )
95
112
 
96
113
  self.assertTrue(
97
- model_coverage.compare_tflite_torch(
114
+ test_utils.compare_tflite_torch(
98
115
  edge_model,
99
116
  model,
100
- (prefill_tokens, prefill_input_pos),
117
+ prefill_tokens,
118
+ prefill_input_pos,
119
+ kv,
101
120
  signature_name="prefill",
102
- num_valid_inputs=1,
103
- atol=1e-2,
104
- rtol=1e-5,
121
+ atol=1e-1,
122
+ rtol=1e-3,
105
123
  )
106
124
  )
107
125
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
117
135
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
118
136
  tokens[0, :4] = idx
119
137
  input_pos = torch.arange(0, 10)
120
-
121
- edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
138
+ kv = kv_cache.KVCache.from_model_config(config)
139
+
140
+ edge_model = ai_edge_torch.convert(
141
+ pytorch_model,
142
+ sample_kwargs={
143
+ "tokens": tokens,
144
+ "input_pos": input_pos,
145
+ "kv_cache": kv,
146
+ },
147
+ )
122
148
  edge_model.set_interpreter_builder(
123
149
  self._interpreter_builder(edge_model.tflite_model())
124
150
  )
125
151
 
126
152
  self.assertTrue(
127
- model_coverage.compare_tflite_torch(
153
+ test_utils.compare_tflite_torch(
128
154
  edge_model,
129
155
  pytorch_model,
130
- (tokens, input_pos),
131
- num_valid_inputs=1,
156
+ tokens,
157
+ input_pos,
158
+ kv,
159
+ signature_name="serving_default",
132
160
  atol=1e-3,
133
161
  rtol=1e-3,
134
162
  )
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common utils for testing."""
17
+
18
+ from ai_edge_torch import model
19
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
+ from ai_edge_torch.lowertools import common_utils
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils import _pytree as pytree
24
+
25
+
26
+ def compare_tflite_torch(
27
+ edge_model: model.Model,
28
+ torch_model: torch.nn.Module,
29
+ tokens: torch.Tensor,
30
+ input_pos: torch.Tensor,
31
+ kv_cache: kv_utils.KVCache,
32
+ signature_name: str,
33
+ atol: float = 1e-5,
34
+ rtol: float = 1e-5,
35
+ ):
36
+ """Compares torch models and TFLite models."""
37
+ values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
38
+ flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
39
+ torch_output = torch_model(tokens, input_pos, kv_cache)
40
+
41
+ input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
42
+ edge_output = edge_model(
43
+ signature_name=signature_name,
44
+ tokens=tokens.numpy(),
45
+ input_pos=input_pos.numpy(),
46
+ **input_kv_flatten,
47
+ )
48
+
49
+ return np.allclose(
50
+ edge_output["logits"],
51
+ torch_output["logits"].detach().numpy(),
52
+ atol=atol,
53
+ rtol=rtol,
54
+ )
@@ -12,22 +12,171 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Provides lowering for coreaten to mlir stablehlo op: Convolution"""
15
+ """Provides lowering for coreaten to stablehlo for Convolution."""
16
16
 
17
17
  import math
18
18
  from typing import Optional
19
19
 
20
+ from ai_edge_torch.odml_torch.lowerings import registry
20
21
  from jax._src.lib.mlir import ir
21
22
  from jax._src.lib.mlir.dialects import hlo as stablehlo
22
23
  import torch
23
24
 
24
- from .registry import lower
25
+
26
+ def make_padding(padding):
27
+ """Change the padding from pytorch to stablehlo style.
28
+
29
+ Stablehlo allows start and end padding for each dimension while aten only
30
+ allows symmetric padding and so only has one number per dimension.
31
+
32
+ Args:
33
+ padding: The padding of the convolution
34
+
35
+ Returns:
36
+ The padding in stablehlo style
37
+ """
38
+ return tuple((p, p) for p in padding)
39
+
40
+
41
+ def create_conv_dimension_numbers(lhs, transposed: bool = False):
42
+ """Create the dimension numbers for the convolution.
43
+
44
+ Args:
45
+ lhs: The input tensor
46
+ transposed: Whether the convolution is transposed
47
+
48
+ Returns:
49
+ The dimension numbers for the convolution
50
+ """
51
+ num_spatial_dims = len(lhs.type.shape) - 2
52
+ spatial_dimensions = []
53
+ for i in range(0, num_spatial_dims):
54
+ spatial_dimensions.append(i + 2)
55
+
56
+ # Regular kernels are OIHW
57
+ # TransposedConv kernels are IOHW
58
+ dimension_numbers = stablehlo.ConvDimensionNumbers.get(
59
+ input_batch_dimension=0,
60
+ input_feature_dimension=1,
61
+ input_spatial_dimensions=spatial_dimensions,
62
+ kernel_input_feature_dimension=0 if transposed else 1,
63
+ kernel_output_feature_dimension=1 if transposed else 0,
64
+ kernel_spatial_dimensions=spatial_dimensions,
65
+ output_batch_dimension=0,
66
+ output_feature_dimension=1,
67
+ output_spatial_dimensions=spatial_dimensions,
68
+ )
69
+ return dimension_numbers
70
+
71
+
72
+ def infer_output_shape(
73
+ lhs,
74
+ rhs,
75
+ stride,
76
+ dilation,
77
+ padding,
78
+ transposed: bool = False,
79
+ output_padding: list[int] = 0,
80
+ ):
81
+ """Infer the output shape of the convolution.
82
+
83
+ Args:
84
+ lhs: The input tensor
85
+ rhs: The kernel tensor
86
+ stride: The stride of the convolution (dilation of input in transposed conv)
87
+ dilation: The kernel dilation of the convolution
88
+ padding: The padding of the convolution
89
+ transposed: Whether the convolution is transposed
90
+ output_padding: The output padding of the convolution
91
+
92
+ Returns:
93
+ The output shape of the convolution
94
+ """
95
+ lhs_type: ir.RankedTensorType = lhs.type
96
+ lhs_shape: list[int] = lhs_type.shape
97
+ rhs_shape: list[int] = rhs.type.shape
98
+
99
+ # Input layout is: (N)CHW and Kernel layout is: (O)IHW for regular conv
100
+ # Input layout is: (N)CHW and Kernel layout is: I(O)HW for transposed conv
101
+ output_shape = (
102
+ [lhs_shape[0], rhs_shape[1]]
103
+ if transposed
104
+ else [lhs_shape[0], rhs_shape[0]]
105
+ )
106
+ num_spatial_dims = len(lhs.type.shape) - 2
107
+
108
+ # looping over the spatial dims (skipping the first 2 dims which are
109
+ # batch and features)
110
+ for spatial_dim in range(0, num_spatial_dims):
111
+ dim = spatial_dim + 2
112
+ dim_size = lhs_shape[dim]
113
+ kernel_dim_size = rhs_shape[dim]
114
+
115
+ if transposed:
116
+ output_dim_size = (
117
+ (dim_size - 1) * stride[spatial_dim]
118
+ - 2 * padding[spatial_dim]
119
+ + dilation[spatial_dim] * (kernel_dim_size - 1)
120
+ + output_padding[spatial_dim]
121
+ + 1
122
+ )
123
+ else:
124
+ output_dim_size = math.floor(
125
+ (
126
+ (
127
+ dim_size
128
+ + 2 * padding[spatial_dim]
129
+ - dilation[spatial_dim] * (kernel_dim_size - 1)
130
+ - 1
131
+ )
132
+ / stride[spatial_dim]
133
+ )
134
+ + 1
135
+ )
136
+
137
+ output_shape.append(output_dim_size)
138
+
139
+ return output_shape
140
+
141
+
142
+ def build_transpose_conv(
143
+ lctx,
144
+ output_type: ir.RankedTensorType,
145
+ lhs: ir.Value,
146
+ rhs: ir.Value,
147
+ stride: list[int],
148
+ padding: list[int],
149
+ dilation: list[int],
150
+ output_padding: list[int],
151
+ groups: int,
152
+ ):
153
+ lhs_type: ir.RankedTensorType = lhs.type
154
+ num_spatial_dims = len(lhs_type.shape) - 2
155
+ rhs = stablehlo.reverse(rhs, list(range(2, 2 + num_spatial_dims)))
156
+
157
+ kernel_size = rhs.type.shape
158
+ # We need to additional padding on the input to get the right output size.
159
+ adjusted_padding = [
160
+ dilation[dim] * (kernel_size[dim + 2] - 1) - padding[dim]
161
+ for dim in range(num_spatial_dims)
162
+ ]
163
+ return stablehlo.convolution(
164
+ result=output_type,
165
+ lhs=lhs,
166
+ rhs=rhs,
167
+ dimension_numbers=create_conv_dimension_numbers(lhs, True),
168
+ feature_group_count=groups,
169
+ batch_group_count=1,
170
+ padding=make_padding(adjusted_padding),
171
+ lhs_dilation=stride,
172
+ rhs_dilation=dilation,
173
+ )
25
174
 
26
175
 
27
176
  # convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
28
177
  # SymInt[] padding, SymInt[] dilation, bool transposed,
29
178
  # SymInt[] output_padding, SymInt groups) -> Tensor
30
- # @lower(torch.ops.aten.convolution)
179
+ @registry.lower(torch.ops.aten.convolution)
31
180
  def _aten_convolution(
32
181
  lctx,
33
182
  lhs: ir.Value,
@@ -40,80 +189,53 @@ def _aten_convolution(
40
189
  output_padding: list[int],
41
190
  groups: int,
42
191
  ):
43
- if transposed:
44
- raise NotImplementedError("Transposed convolution is not implemented.")
45
192
 
46
- if bias is not None:
47
- raise NotImplementedError("Bias on convolution is not implemented.")
48
-
49
- # Stablehlo allows start and end padding for each dimension while aten only
50
- # allows symmetric padding and so only has one number per dimension.
51
- def make_padding(padding):
52
- return tuple((p, p) for p in padding)
53
-
54
- def create_conv_dimension_numbers():
55
- num_spatial_dims = len(lhs.type.shape) - 2
56
- spatial_dimensions = []
57
- for i in range(0, num_spatial_dims):
58
- spatial_dimensions.append(i + 2)
59
-
60
- dimension_numbers = stablehlo.ConvDimensionNumbers.get(
61
- input_batch_dimension=0,
62
- input_feature_dimension=1,
63
- input_spatial_dimensions=spatial_dimensions,
64
- kernel_input_feature_dimension=1,
65
- kernel_output_feature_dimension=0,
66
- kernel_spatial_dimensions=spatial_dimensions,
67
- output_batch_dimension=0,
68
- output_feature_dimension=1,
69
- output_spatial_dimensions=spatial_dimensions,
193
+ # TODO(b/365559296) Add support for output_padding
194
+ if any(output_padding):
195
+ raise NotImplementedError(
196
+ "Output padding on convolution is not implemented."
70
197
  )
71
- return dimension_numbers
72
-
73
- def infer_output_shape():
74
- lhs_type: ir.RankedTensorType = lhs.type
75
- lhs_shape: list[int] = lhs_type.shape
76
- rhs_shape: list[int] = rhs.type.shape
77
-
78
- # Input layout is: (N)CHW and Kernel layout is: (O)IHW
79
- output_shape = [lhs_shape[0], rhs_shape[0]]
80
- num_spatial_dims = len(lhs.type.shape) - 2
81
-
82
- # looping over the spatial dims (skipping the first 2 dims which are
83
- # batch and features)
84
- for spatial_dim in range(0, num_spatial_dims):
85
- dim_size = lhs_shape[spatial_dim + 2]
86
- kernel_dim_size = rhs_shape[spatial_dim + 2]
87
-
88
- # for example, a dilation of 2 increases the dimension size by 2
89
- dim_size *= dilation[spatial_dim]
90
-
91
- # padding added to both sides
92
- dim_size += 2 * padding[spatial_dim]
93
-
94
- output_dim_size = math.ceil(
95
- (dim_size - kernel_dim_size + 1) / stride[spatial_dim]
96
- )
97
-
98
- output_shape.append(output_dim_size)
99
-
100
- return output_shape
101
198
 
102
199
  lhs_type: ir.RankedTensorType = lhs.type
103
-
104
- op = stablehlo.ConvolutionOp(
105
- result=ir.RankedTensorType.get(
106
- infer_output_shape(), lhs_type.element_type
107
- ),
108
- lhs=lhs,
109
- rhs=rhs,
110
- dimension_numbers=create_conv_dimension_numbers(),
111
- feature_group_count=groups,
112
- batch_group_count=1,
113
- window_strides=stride,
114
- padding=make_padding(padding),
115
- lhs_dilation=(1,) * len(stride),
116
- rhs_dilation=dilation,
200
+ output_shape = infer_output_shape(
201
+ lhs, rhs, stride, dilation, padding, transposed, output_padding
202
+ )
203
+ output_type = ir.RankedTensorType.get(
204
+ output_shape,
205
+ lhs_type.element_type,
117
206
  )
118
207
 
119
- return op.result
208
+ if transposed:
209
+ res = build_transpose_conv(
210
+ lctx,
211
+ output_type,
212
+ lhs,
213
+ rhs,
214
+ stride,
215
+ padding,
216
+ dilation,
217
+ output_padding,
218
+ groups,
219
+ )
220
+ else:
221
+ res = stablehlo.convolution(
222
+ result=output_type,
223
+ lhs=lhs,
224
+ rhs=rhs,
225
+ dimension_numbers=create_conv_dimension_numbers(lhs),
226
+ feature_group_count=groups,
227
+ batch_group_count=1,
228
+ window_strides=stride,
229
+ padding=make_padding(padding),
230
+ rhs_dilation=dilation,
231
+ )
232
+
233
+ if bias is not None:
234
+ # broadcast [C] to [NCHW]
235
+ broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
236
+ res = stablehlo.add(
237
+ lhs=res,
238
+ rhs=broadcasted_bias,
239
+ )
240
+
241
+ return res
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.clamp.default)
105
105
  lower_by_torch_xla2(torch.ops.aten.clone)
106
106
  lower_by_torch_xla2(torch.ops.aten.clone.default)
107
107
  lower_by_torch_xla2(torch.ops.aten.constant_pad_nd)
108
- lower_by_torch_xla2(torch.ops.aten.convolution)
109
108
  lower_by_torch_xla2(torch.ops.aten.cos)
110
109
  lower_by_torch_xla2(torch.ops.aten.cosh)
111
110
  lower_by_torch_xla2(torch.ops.aten.cumsum)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240909"
16
+ __version__ = "0.3.0.dev20240911"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240909
3
+ Version: 0.3.0.dev20240911
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI