ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
     )
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.
-
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
         [10], dtype=torch.int64
     )
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
             atol=1e-5,
             rtol=1e-5,
         )
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.
-
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
        [10], dtype=torch.int64
     )
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
             atol=1e-5,
             rtol=1e-5,
         )
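
The two toy-model hunks above replace module-internal KV state with an explicit `kv_cache` input: the cache is built up front with `kv_cache.KVCache.from_model_config(config)` and traced through `ai_edge_torch.convert` via `sample_kwargs`. A minimal sketch of that flow outside the test harness, using only names that appear in this diff (the export call and path are illustrative):

    import ai_edge_torch
    import torch

    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
    from ai_edge_torch.generative.layers import kv_cache

    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()

    # The KV cache is now an explicit forward() argument, so conversion needs
    # a sample cache to trace with, sized from the model config.
    kv = kv_cache.KVCache.from_model_config(config)
    tokens = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([10], dtype=torch.int64)

    # Tracing inputs are passed by keyword, matching forward()'s signature.
    edge_model = ai_edge_torch.convert(
        pytorch_model,
        sample_kwargs={"tokens": tokens, "input_pos": input_pos, "kv_cache": kv},
    )
    edge_model.export("/tmp/toy_model_with_kv_cache.tflite")  # path illustrative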
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.long)
     decode_input_pos = torch.tensor([5], dtype=torch.int64)
 
+    kv = kv_cache.KVCache.from_model_config(config)
+
     edge_model = (
         ai_edge_torch.signature(
-            "prefill",
+            "prefill",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": prefill_tokens,
+                "input_pos": prefill_input_pos,
+                "kv_cache": kv,
+            },
+        )
+        .signature(
+            "decode",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": decode_token,
+                "input_pos": decode_input_pos,
+                "kv_cache": kv,
+            },
         )
-        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    copied_model = copy.deepcopy(pytorch_model)
-    copied_edge = copy.deepcopy(edge_model)
-
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
             signature_name="prefill",
-
+            atol=1e-5,
+            rtol=1e-5,
         )
     )
 
     self.assertTrue(
-
-
-
-
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            decode_token,
+            decode_input_pos,
+            kv,
             signature_name="decode",
-
+            atol=1e-5,
+            rtol=1e-5,
         )
     )
 
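
The multisignature hunk above chains `.signature()` calls so prefill and decode become separate entry points in one TFLite flatbuffer, sharing a single external KV cache. A sketch of the same pattern with the toy model from this diff (shapes are illustrative):

    import ai_edge_torch
    import torch

    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
    from ai_edge_torch.generative.layers import kv_cache

    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
    kv = kv_cache.KVCache.from_model_config(config)

    # Prefill consumes a chunk of prompt tokens; decode consumes one token.
    prefill_tokens = torch.zeros((1, 8), dtype=torch.long)
    prefill_input_pos = torch.arange(0, 8)
    decode_token = torch.tensor([[1]], dtype=torch.long)
    decode_input_pos = torch.tensor([8], dtype=torch.int64)

    # Each .signature() call adds one named entry point to the same model.
    edge_model = (
        ai_edge_torch.signature(
            "prefill",
            pytorch_model,
            sample_kwargs={
                "tokens": prefill_tokens,
                "input_pos": prefill_input_pos,
                "kv_cache": kv,
            },
        )
        .signature(
            "decode",
            pytorch_model,
            sample_kwargs={
                "tokens": decode_token,
                "input_pos": decode_input_pos,
                "kv_cache": kv,
            },
        )
        .convert()
    )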
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.
-from ai_edge_torch.
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
             atol=1e-2,
             rtol=1e-5,
         )
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
     prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     prefill_tokens[0, :4] = idx
     prefill_input_pos = torch.arange(0, 10)
+    kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
-        "prefill",
+        "prefill",
+        model,
+        sample_kwargs={
+            "tokens": prefill_tokens,
+            "input_pos": prefill_input_pos,
+            "kv_cache": kv,
+        },
     ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             model,
-
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
             signature_name="prefill",
-
-
-            rtol=1e-5,
+            atol=1e-1,
+            rtol=1e-3,
         )
     )
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
             atol=1e-3,
             rtol=1e-3,
         )
ai_edge_torch/generative/test/utils.py
ADDED
@@ -0,0 +1,54 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utils for testing."""
+
+from ai_edge_torch import model
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.lowertools import common_utils
+import numpy as np
+import torch
+from torch.utils import _pytree as pytree
+
+
+def compare_tflite_torch(
+    edge_model: model.Model,
+    torch_model: torch.nn.Module,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+    signature_name: str,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+):
+  """Compares torch models and TFLite models."""
+  values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
+  flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+  torch_output = torch_model(tokens, input_pos, kv_cache)
+
+  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  edge_output = edge_model(
+      signature_name=signature_name,
+      tokens=tokens.numpy(),
+      input_pos=input_pos.numpy(),
+      **input_kv_flatten,
+  )
+
+  return np.allclose(
+      edge_output["logits"],
+      torch_output["logits"].detach().numpy(),
+      atol=atol,
+      rtol=rtol,
+  )
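
This helper leans on KVCache being flattenable as a pytree: flattening `{"kv_cache": kv}` yields one plain tensor per per-layer cache buffer, and `common_utils.flat_dict_names` recovers the names under which the converter exposed those tensors as signature inputs. A small self-contained sketch of just that flattening step (the "one k and one v buffer per layer" comment is an assumption about how `KVCache.from_model_config` lays out the cache):

    from torch.utils import _pytree as pytree

    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
    from ai_edge_torch.generative.layers import kv_cache
    from ai_edge_torch.lowertools import common_utils

    config = toy_model_with_kv_cache.get_model_config()
    kv = kv_cache.KVCache.from_model_config(config)

    # KVCache flattens to plain tensors plus a structure spec; the derived
    # names are what the converted TFLite signature expects as cache inputs.
    values, spec = pytree.tree_flatten({"kv_cache": kv})
    flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
    for name, tensor in zip(flat_names, values):
      print(name, tuple(tensor.shape))  # one k and one v buffer per layer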
ai_edge_torch/odml_torch/lowerings/_convolution.py
CHANGED
@@ -12,22 +12,171 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Provides lowering for coreaten to
+"""Provides lowering for coreaten to stablehlo for Convolution."""
 
 import math
 from typing import Optional
 
+from ai_edge_torch.odml_torch.lowerings import registry
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import torch
 
-
+
+def make_padding(padding):
+  """Change the padding from pytorch to stablehlo style.
+
+  Stablehlo allows start and end padding for each dimension while aten only
+  allows symmetric padding and so only has one number per dimension.
+
+  Args:
+    padding: The padding of the convolution
+
+  Returns:
+    The padding in stablehlo style
+  """
+  return tuple((p, p) for p in padding)
+
+
+def create_conv_dimension_numbers(lhs, transposed: bool = False):
+  """Create the dimension numbers for the convolution.
+
+  Args:
+    lhs: The input tensor
+    transposed: Whether the convolution is transposed
+
+  Returns:
+    The dimension numbers for the convolution
+  """
+  num_spatial_dims = len(lhs.type.shape) - 2
+  spatial_dimensions = []
+  for i in range(0, num_spatial_dims):
+    spatial_dimensions.append(i + 2)
+
+  # Regular kernels are OIHW
+  # TransposedConv kernels are IOHW
+  dimension_numbers = stablehlo.ConvDimensionNumbers.get(
+      input_batch_dimension=0,
+      input_feature_dimension=1,
+      input_spatial_dimensions=spatial_dimensions,
+      kernel_input_feature_dimension=0 if transposed else 1,
+      kernel_output_feature_dimension=1 if transposed else 0,
+      kernel_spatial_dimensions=spatial_dimensions,
+      output_batch_dimension=0,
+      output_feature_dimension=1,
+      output_spatial_dimensions=spatial_dimensions,
+  )
+  return dimension_numbers
+
+
+def infer_output_shape(
+    lhs,
+    rhs,
+    stride,
+    dilation,
+    padding,
+    transposed: bool = False,
+    output_padding: list[int] = 0,
+):
+  """Infer the output shape of the convolution.
+
+  Args:
+    lhs: The input tensor
+    rhs: The kernel tensor
+    stride: The stride of the convolution (dilation of input in transposed conv)
+    dilation: The kernel dilation of the convolution
+    padding: The padding of the convolution
+    transposed: Whether the convolution is transposed
+    output_padding: The output padding of the convolution
+
+  Returns:
+    The output shape of the convolution
+  """
+  lhs_type: ir.RankedTensorType = lhs.type
+  lhs_shape: list[int] = lhs_type.shape
+  rhs_shape: list[int] = rhs.type.shape
+
+  # Input layout is: (N)CHW and Kernel layout is: (O)IHW for regular conv
+  # Input layout is: (N)CHW and Kernel layout is: I(O)HW for transposed conv
+  output_shape = (
+      [lhs_shape[0], rhs_shape[1]]
+      if transposed
+      else [lhs_shape[0], rhs_shape[0]]
+  )
+  num_spatial_dims = len(lhs.type.shape) - 2
+
+  # looping over the spatial dims (skipping the first 2 dims which are
+  # batch and features)
+  for spatial_dim in range(0, num_spatial_dims):
+    dim = spatial_dim + 2
+    dim_size = lhs_shape[dim]
+    kernel_dim_size = rhs_shape[dim]
+
+    if transposed:
+      output_dim_size = (
+          (dim_size - 1) * stride[spatial_dim]
+          - 2 * padding[spatial_dim]
+          + dilation[spatial_dim] * (kernel_dim_size - 1)
+          + output_padding[spatial_dim]
+          + 1
+      )
+    else:
+      output_dim_size = math.floor(
+          (
+              (
+                  dim_size
+                  + 2 * padding[spatial_dim]
+                  - dilation[spatial_dim] * (kernel_dim_size - 1)
+                  - 1
+              )
+              / stride[spatial_dim]
+          )
+          + 1
+      )
+
+    output_shape.append(output_dim_size)
+
+  return output_shape
+
+
+def build_transpose_conv(
+    lctx,
+    output_type: ir.RankedTensorType,
+    lhs: ir.Value,
+    rhs: ir.Value,
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
+    output_padding: list[int],
+    groups: int,
+):
+  lhs_type: ir.RankedTensorType = lhs.type
+  num_spatial_dims = len(lhs_type.shape) - 2
+  rhs = stablehlo.reverse(rhs, list(range(2, 2 + num_spatial_dims)))
+
+  kernel_size = rhs.type.shape
+  # We need to additional padding on the input to get the right output size.
+  adjusted_padding = [
+      dilation[dim] * (kernel_size[dim + 2] - 1) - padding[dim]
+      for dim in range(num_spatial_dims)
+  ]
+  return stablehlo.convolution(
+      result=output_type,
+      lhs=lhs,
+      rhs=rhs,
+      dimension_numbers=create_conv_dimension_numbers(lhs, True),
+      feature_group_count=groups,
+      batch_group_count=1,
+      padding=make_padding(adjusted_padding),
+      lhs_dilation=stride,
+      rhs_dilation=dilation,
+  )
 
 
 # convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
 # SymInt[] padding, SymInt[] dilation, bool transposed,
 # SymInt[] output_padding, SymInt groups) -> Tensor
-
+@registry.lower(torch.ops.aten.convolution)
 def _aten_convolution(
     lctx,
     lhs: ir.Value,
@@ -40,80 +189,53 @@ def _aten_convolution(
     output_padding: list[int],
     groups: int,
 ):
-  if transposed:
-    raise NotImplementedError("Transposed convolution is not implemented.")
 
-
-
-
-
-  # allows symmetric padding and so only has one number per dimension.
-  def make_padding(padding):
-    return tuple((p, p) for p in padding)
-
-  def create_conv_dimension_numbers():
-    num_spatial_dims = len(lhs.type.shape) - 2
-    spatial_dimensions = []
-    for i in range(0, num_spatial_dims):
-      spatial_dimensions.append(i + 2)
-
-    dimension_numbers = stablehlo.ConvDimensionNumbers.get(
-        input_batch_dimension=0,
-        input_feature_dimension=1,
-        input_spatial_dimensions=spatial_dimensions,
-        kernel_input_feature_dimension=1,
-        kernel_output_feature_dimension=0,
-        kernel_spatial_dimensions=spatial_dimensions,
-        output_batch_dimension=0,
-        output_feature_dimension=1,
-        output_spatial_dimensions=spatial_dimensions,
+  # TODO(b/365559296) Add support for output_padding
+  if any(output_padding):
+    raise NotImplementedError(
+        "Output padding on convolution is not implemented."
     )
-    return dimension_numbers
-
-  def infer_output_shape():
-    lhs_type: ir.RankedTensorType = lhs.type
-    lhs_shape: list[int] = lhs_type.shape
-    rhs_shape: list[int] = rhs.type.shape
-
-    # Input layout is: (N)CHW and Kernel layout is: (O)IHW
-    output_shape = [lhs_shape[0], rhs_shape[0]]
-    num_spatial_dims = len(lhs.type.shape) - 2
-
-    # looping over the spatial dims (skipping the first 2 dims which are
-    # batch and features)
-    for spatial_dim in range(0, num_spatial_dims):
-      dim_size = lhs_shape[spatial_dim + 2]
-      kernel_dim_size = rhs_shape[spatial_dim + 2]
-
-      # for example, a dilation of 2 increases the dimension size by 2
-      dim_size *= dilation[spatial_dim]
-
-      # padding added to both sides
-      dim_size += 2 * padding[spatial_dim]
-
-      output_dim_size = math.ceil(
-          (dim_size - kernel_dim_size + 1) / stride[spatial_dim]
-      )
-
-      output_shape.append(output_dim_size)
-
-    return output_shape
 
   lhs_type: ir.RankedTensorType = lhs.type
-
-
-
-
-
-
-      rhs=rhs,
-      dimension_numbers=create_conv_dimension_numbers(),
-      feature_group_count=groups,
-      batch_group_count=1,
-      window_strides=stride,
-      padding=make_padding(padding),
-      lhs_dilation=(1,) * len(stride),
-      rhs_dilation=dilation,
+  output_shape = infer_output_shape(
+      lhs, rhs, stride, dilation, padding, transposed, output_padding
+  )
+  output_type = ir.RankedTensorType.get(
+      output_shape,
+      lhs_type.element_type,
   )
 
-
+  if transposed:
+    res = build_transpose_conv(
+        lctx,
+        output_type,
+        lhs,
+        rhs,
+        stride,
+        padding,
+        dilation,
+        output_padding,
+        groups,
+    )
+  else:
+    res = stablehlo.convolution(
+        result=output_type,
+        lhs=lhs,
+        rhs=rhs,
+        dimension_numbers=create_conv_dimension_numbers(lhs),
+        feature_group_count=groups,
+        batch_group_count=1,
+        window_strides=stride,
+        padding=make_padding(padding),
+        rhs_dilation=dilation,
+    )
+
+  if bias is not None:
+    # broadcast [C] to [NCHW]
+    broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
+    res = stablehlo.add(
+        lhs=res,
+        rhs=broadcasted_bias,
+    )
+
+  return res
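
The two branches of infer_output_shape above implement the standard conv/deconv size arithmetic: floor((dim + 2p - d(k-1) - 1)/s) + 1 for the regular case and (dim-1)s - 2p + d(k-1) + output_padding + 1 for the transposed case. These match PyTorch's own modules, which makes for a quick sanity check (the sizes below are arbitrary):

    import torch

    x = torch.randn(1, 4, 5, 5)

    # Regular branch: floor((5 + 2*1 - 1*(3 - 1) - 1) / 2) + 1 = 3
    conv = torch.nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1)
    print(conv(x).shape)  # torch.Size([1, 8, 3, 3])

    # Transposed branch: (5 - 1)*2 - 2*1 + 1*(3 - 1) + 0 + 1 = 9
    deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, stride=2, padding=1)
    print(deconv(x).shape)  # torch.Size([1, 8, 9, 9])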
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
CHANGED
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.clamp.default)
 lower_by_torch_xla2(torch.ops.aten.clone)
 lower_by_torch_xla2(torch.ops.aten.clone.default)
 lower_by_torch_xla2(torch.ops.aten.constant_pad_nd)
-lower_by_torch_xla2(torch.ops.aten.convolution)
 lower_by_torch_xla2(torch.ops.aten.cos)
 lower_by_torch_xla2(torch.ops.aten.cosh)
 lower_by_torch_xla2(torch.ops.aten.cumsum)
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240909
+Version: 0.3.0.dev20240911
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI