ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
# Common normalization layers.
|
16
16
|
|
17
|
+
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
17
18
|
import torch
|
19
|
+
from torch import nn
|
20
|
+
import torch.nn.functional as F
|
18
21
|
|
19
22
|
|
20
23
|
# Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
|
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
|
|
58
61
|
return output * (1 + self.weight)
|
59
62
|
else:
|
60
63
|
return output * self.weight
|
64
|
+
|
65
|
+
|
66
|
+
class GroupNorm(torch.nn.Module):
|
67
|
+
|
68
|
+
def __init__(
|
69
|
+
self,
|
70
|
+
group_num: int,
|
71
|
+
dim: int,
|
72
|
+
eps: float = 1e-5,
|
73
|
+
enable_hlfb: bool = False,
|
74
|
+
):
|
75
|
+
"""Initialize the GroupNorm layer.
|
76
|
+
|
77
|
+
Args:
|
78
|
+
group_num (int): Number of groups to separate the channels into.
|
79
|
+
dim (int): Dimension of the input tensor.
|
80
|
+
eps (float): A small float value to ensure numerical stability (default:
|
81
|
+
1e-5).
|
82
|
+
enable_hlfb (bool): Whether to convert this normalization into a single
|
83
|
+
op.
|
84
|
+
"""
|
85
|
+
super().__init__()
|
86
|
+
self.enable_hlfb = enable_hlfb
|
87
|
+
self.group_num = group_num
|
88
|
+
self.eps = eps
|
89
|
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
90
|
+
self.bias = torch.nn.Parameter(torch.ones(dim))
|
91
|
+
|
92
|
+
def forward(self, x):
|
93
|
+
"""Running the forward pass of GroupNorm layer.
|
94
|
+
|
95
|
+
Args:
|
96
|
+
x (torch.Tensor): input tensor.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
torch.Tensor: output tensor after applying GroupNorm.
|
100
|
+
"""
|
101
|
+
if self.enable_hlfb:
|
102
|
+
return group_norm_with_hlfb(
|
103
|
+
x,
|
104
|
+
self.weight,
|
105
|
+
self.bias,
|
106
|
+
self.group_num,
|
107
|
+
self.eps,
|
108
|
+
)
|
109
|
+
else:
|
110
|
+
return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
|
111
|
+
|
112
|
+
|
113
|
+
class LayerNorm(torch.nn.Module):
|
114
|
+
|
115
|
+
def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
|
116
|
+
"""Initialize the LayerNorm layer.
|
117
|
+
|
118
|
+
Args:
|
119
|
+
dim (int): dimension of the input tensor.
|
120
|
+
eps (float): A small float value to ensure numerical stability (default:
|
121
|
+
1e-5).
|
122
|
+
enable_hlfb (bool): Whether to convert this normalization into a single
|
123
|
+
op.
|
124
|
+
"""
|
125
|
+
super().__init__()
|
126
|
+
self.enable_hlfb = enable_hlfb
|
127
|
+
self.eps = eps
|
128
|
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
129
|
+
self.bias = torch.nn.Parameter(torch.ones(dim))
|
130
|
+
|
131
|
+
def forward(self, x):
|
132
|
+
"""Running the forward pass of LayerNorm layer.
|
133
|
+
|
134
|
+
Args:
|
135
|
+
x (torch.Tensor): input tensor.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
torch.Tensor: output tensor after applying LayerNorm.
|
139
|
+
"""
|
140
|
+
if self.enable_hlfb:
|
141
|
+
return layer_norm_with_hlfb(
|
142
|
+
x,
|
143
|
+
self.weight,
|
144
|
+
self.bias,
|
145
|
+
self.eps,
|
146
|
+
)
|
147
|
+
else:
|
148
|
+
return F.layer_norm(
|
149
|
+
x,
|
150
|
+
x.shape,
|
151
|
+
self.weight.broadcast_to(x.shape),
|
152
|
+
self.bias.broadcast_to(x.shape),
|
153
|
+
self.eps,
|
154
|
+
)
|
155
|
+
|
156
|
+
|
157
|
+
def group_norm_with_hlfb(
|
158
|
+
x: torch.Tensor,
|
159
|
+
w: torch.Tensor,
|
160
|
+
b: torch.Tensor,
|
161
|
+
num_groups: int,
|
162
|
+
eps: float,
|
163
|
+
):
|
164
|
+
"""Group Normalization with high-level function boundary enabled.
|
165
|
+
|
166
|
+
Args:
|
167
|
+
x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
|
168
|
+
w (torch.Tensor): The weight tensor for the normalization.
|
169
|
+
b (torch.Tensor): The bias tensor for the normalization.
|
170
|
+
num_groups (int): Number of groups to separate the channels into.
|
171
|
+
eps (float): A small float value to ensure numerical stability.
|
172
|
+
|
173
|
+
Returns:
|
174
|
+
The output tensor of Group Normalization.
|
175
|
+
"""
|
176
|
+
x = torch.permute(x, (0, 2, 3, 1))
|
177
|
+
|
178
|
+
builder = StableHLOCompositeBuilder(
|
179
|
+
name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
|
180
|
+
)
|
181
|
+
x, w, b = builder.mark_inputs(x, w, b)
|
182
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
183
|
+
y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
|
184
|
+
y = torch.permute(y, (0, 2, 3, 1))
|
185
|
+
y = builder.mark_outputs(y)
|
186
|
+
|
187
|
+
y = torch.permute(y, (0, 3, 1, 2))
|
188
|
+
return y
|
189
|
+
|
190
|
+
|
191
|
+
def layer_norm_with_hlfb(
|
192
|
+
x: torch.Tensor,
|
193
|
+
w: torch.Tensor,
|
194
|
+
b: torch.Tensor,
|
195
|
+
eps: float,
|
196
|
+
):
|
197
|
+
"""Layer Normalization with high-level function boundary enabled.
|
198
|
+
|
199
|
+
Args:
|
200
|
+
x (torch.Tensor): Input tensor for Layer Normalization.
|
201
|
+
w (torch.Tensor): The weight tensor for the normalization.
|
202
|
+
b (torch.Tensor): The bias tensor for the normalization.
|
203
|
+
eps (float): A small float value to ensure numerical stability.
|
204
|
+
|
205
|
+
Returns:
|
206
|
+
The output tensor of Layer Normalization.
|
207
|
+
"""
|
208
|
+
builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
|
209
|
+
x, w, b = builder.mark_inputs(x, w, b)
|
210
|
+
y = F.layer_norm(
|
211
|
+
x,
|
212
|
+
x.shape,
|
213
|
+
weight=w.broadcast_to(x.shape),
|
214
|
+
bias=b.broadcast_to(x.shape),
|
215
|
+
eps=eps,
|
216
|
+
)
|
217
|
+
y = builder.mark_outputs(y)
|
218
|
+
return y
|
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
|
|
122
122
|
config.attention_batch_size,
|
123
123
|
config.dim,
|
124
124
|
config.attention_config,
|
125
|
-
0,
|
126
125
|
enable_hlfb=config.enable_hlfb,
|
127
126
|
)
|
128
127
|
|
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
|
|
180
179
|
config.query_dim,
|
181
180
|
config.cross_dim,
|
182
181
|
config.attention_config,
|
183
|
-
0,
|
184
182
|
enable_hlfb=config.enable_hlfb,
|
185
183
|
)
|
186
184
|
|
@@ -12,19 +12,17 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
# A suite of tests to validate experimental external KV Cache layers and models.
|
16
15
|
|
17
|
-
|
18
|
-
|
19
|
-
from ai_edge_torch.generative.
|
20
|
-
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
16
|
+
"""A suite of tests to validate KV Cache layer."""
|
17
|
+
|
18
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
22
20
|
import torch
|
23
21
|
|
24
22
|
from absl.testing import absltest as googletest
|
25
23
|
|
26
24
|
|
27
|
-
class
|
25
|
+
class TestKVLayers(googletest.TestCase):
|
28
26
|
|
29
27
|
def _get_test_config(
|
30
28
|
self, num_layers, head_dim, num_query_groups, kv_cache_max_len
|
@@ -54,7 +52,7 @@ class TestExternalKVLayers(googletest.TestCase):
|
|
54
52
|
num_query_groups=NUM_QG,
|
55
53
|
kv_cache_max_len=KV_LEN,
|
56
54
|
)
|
57
|
-
kv = kv_utils.
|
55
|
+
kv = kv_utils.KVCache.from_model_config(config)
|
58
56
|
entry = kv.caches[0]
|
59
57
|
# single-slice update
|
60
58
|
input_pos = torch.tensor([1])
|
@@ -88,14 +86,14 @@ class TestExternalKVLayers(googletest.TestCase):
|
|
88
86
|
def test_serialization(self):
|
89
87
|
class TestModel(torch.nn.Module):
|
90
88
|
|
91
|
-
def forward(self, kv: kv_utils.
|
89
|
+
def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
|
92
90
|
updated_kv_entries = [
|
93
91
|
kv_utils.KVCacheEntry(
|
94
92
|
torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
|
95
93
|
)
|
96
94
|
for entry in kv.caches
|
97
95
|
]
|
98
|
-
return kv_utils.
|
96
|
+
return kv_utils.KVCache(updated_kv_entries)
|
99
97
|
|
100
98
|
N = 1
|
101
99
|
HEAD_DIM = 2
|
@@ -107,7 +105,7 @@ class TestExternalKVLayers(googletest.TestCase):
|
|
107
105
|
num_query_groups=NUM_QG,
|
108
106
|
kv_cache_max_len=KV_LEN,
|
109
107
|
)
|
110
|
-
kv = kv_utils.
|
108
|
+
kv = kv_utils.KVCache.from_model_config(config)
|
111
109
|
model = TestModel()
|
112
110
|
exported_program = torch.export.export(model, (kv,))
|
113
111
|
input_specs = exported_program.graph_signature.input_specs
|
@@ -116,17 +114,5 @@ class TestExternalKVLayers(googletest.TestCase):
|
|
116
114
|
self.assertEqual(input_specs[1].arg.name, "kv_v_0")
|
117
115
|
|
118
116
|
|
119
|
-
class TestExternalKVModels(googletest.TestCase):
|
120
|
-
|
121
|
-
def test_can_build_gemma(self):
|
122
|
-
gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
|
123
|
-
|
124
|
-
def test_can_build_phi2(self):
|
125
|
-
phi2.define_and_run(checkpoint_path=None, test_model=True)
|
126
|
-
|
127
|
-
def test_can_build_tinyllama(self):
|
128
|
-
tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
|
129
|
-
|
130
|
-
|
131
117
|
if __name__ == "__main__":
|
132
118
|
googletest.main()
|
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
|
|
71
71
|
safetensors.torch.save_file(test_weights, file_path)
|
72
72
|
cfg = tiny_llama.get_model_config()
|
73
73
|
cfg.num_layers = 1
|
74
|
-
model = tiny_llama.
|
74
|
+
model = tiny_llama.TinyLlama(cfg)
|
75
75
|
|
76
76
|
loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
|
77
77
|
# If this returns successfully, it means all the tensors were initialized.
|
@@ -12,16 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
15
|
+
|
16
|
+
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
|
-
from ai_edge_torch.generative.examples.
|
21
|
-
from ai_edge_torch.generative.examples.phi2 import phi2
|
22
|
-
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
20
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
|
23
21
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
24
|
-
from ai_edge_torch.
|
22
|
+
from ai_edge_torch.generative.layers import kv_cache
|
23
|
+
from ai_edge_torch.generative.test import utils as test_utils
|
25
24
|
import numpy as np
|
26
25
|
import torch
|
27
26
|
|
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
|
|
49
48
|
)
|
50
49
|
def test_toy_model_with_kv_cache(self):
|
51
50
|
config = toy_model_with_kv_cache.get_model_config()
|
52
|
-
pytorch_model = toy_model_with_kv_cache.
|
53
|
-
|
51
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
52
|
+
tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
54
53
|
[10], dtype=torch.int64
|
55
54
|
)
|
56
|
-
|
57
|
-
|
55
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
56
|
+
|
57
|
+
edge_model = ai_edge_torch.convert(
|
58
|
+
pytorch_model,
|
59
|
+
sample_kwargs={
|
60
|
+
"tokens": tokens,
|
61
|
+
"input_pos": input_pos,
|
62
|
+
"kv_cache": kv,
|
63
|
+
},
|
64
|
+
)
|
58
65
|
edge_model.set_interpreter_builder(
|
59
66
|
self._interpreter_builder(edge_model.tflite_model())
|
60
67
|
)
|
61
68
|
|
62
69
|
self.assertTrue(
|
63
|
-
|
70
|
+
test_utils.compare_tflite_torch(
|
64
71
|
edge_model,
|
65
72
|
pytorch_model,
|
66
|
-
|
67
|
-
|
73
|
+
tokens,
|
74
|
+
input_pos,
|
75
|
+
kv,
|
76
|
+
signature_name="serving_default",
|
68
77
|
atol=1e-5,
|
69
78
|
rtol=1e-5,
|
70
79
|
)
|
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
|
|
77
86
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
78
87
|
config = toy_model_with_kv_cache.get_model_config()
|
79
88
|
config.enable_hlfb = True
|
80
|
-
pytorch_model = toy_model_with_kv_cache.
|
81
|
-
|
89
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
90
|
+
tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
82
91
|
[10], dtype=torch.int64
|
83
92
|
)
|
84
|
-
|
85
|
-
|
93
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
94
|
+
|
95
|
+
edge_model = ai_edge_torch.convert(
|
96
|
+
pytorch_model,
|
97
|
+
sample_kwargs={
|
98
|
+
"tokens": tokens,
|
99
|
+
"input_pos": input_pos,
|
100
|
+
"kv_cache": kv,
|
101
|
+
},
|
102
|
+
)
|
86
103
|
edge_model.set_interpreter_builder(
|
87
104
|
self._interpreter_builder(edge_model.tflite_model())
|
88
105
|
)
|
89
106
|
|
90
107
|
self.assertTrue(
|
91
|
-
|
108
|
+
test_utils.compare_tflite_torch(
|
92
109
|
edge_model,
|
93
110
|
pytorch_model,
|
94
|
-
|
95
|
-
|
111
|
+
tokens,
|
112
|
+
input_pos,
|
113
|
+
kv,
|
114
|
+
signature_name="serving_default",
|
96
115
|
atol=1e-5,
|
97
116
|
rtol=1e-5,
|
98
117
|
)
|
@@ -104,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
|
|
104
123
|
)
|
105
124
|
def test_tiny_llama_multisig(self):
|
106
125
|
config = tiny_llama.get_fake_model_config()
|
107
|
-
pytorch_model = tiny_llama.
|
126
|
+
pytorch_model = tiny_llama.TinyLlama(config).eval()
|
108
127
|
|
109
128
|
# prefill
|
110
129
|
seq_len = 10
|
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
|
|
117
136
|
decode_token = torch.tensor([[1]], dtype=torch.long)
|
118
137
|
decode_input_pos = torch.tensor([5], dtype=torch.int64)
|
119
138
|
|
139
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
140
|
+
|
120
141
|
edge_model = (
|
121
142
|
ai_edge_torch.signature(
|
122
|
-
"prefill",
|
143
|
+
"prefill",
|
144
|
+
pytorch_model,
|
145
|
+
sample_kwargs={
|
146
|
+
"tokens": prefill_tokens,
|
147
|
+
"input_pos": prefill_input_pos,
|
148
|
+
"kv_cache": kv,
|
149
|
+
},
|
150
|
+
)
|
151
|
+
.signature(
|
152
|
+
"decode",
|
153
|
+
pytorch_model,
|
154
|
+
sample_kwargs={
|
155
|
+
"tokens": decode_token,
|
156
|
+
"input_pos": decode_input_pos,
|
157
|
+
"kv_cache": kv,
|
158
|
+
},
|
123
159
|
)
|
124
|
-
.signature("decode", pytorch_model, (decode_token, decode_input_pos))
|
125
160
|
.convert()
|
126
161
|
)
|
127
162
|
edge_model.set_interpreter_builder(
|
128
163
|
self._interpreter_builder(edge_model.tflite_model())
|
129
164
|
)
|
130
165
|
|
131
|
-
copied_model = copy.deepcopy(pytorch_model)
|
132
|
-
copied_edge = copy.deepcopy(edge_model)
|
133
|
-
|
134
166
|
self.assertTrue(
|
135
|
-
|
167
|
+
test_utils.compare_tflite_torch(
|
136
168
|
edge_model,
|
137
169
|
pytorch_model,
|
138
|
-
|
170
|
+
prefill_tokens,
|
171
|
+
prefill_input_pos,
|
172
|
+
kv,
|
139
173
|
signature_name="prefill",
|
140
|
-
|
174
|
+
atol=1e-5,
|
175
|
+
rtol=1e-5,
|
141
176
|
)
|
142
177
|
)
|
143
178
|
|
144
179
|
self.assertTrue(
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
180
|
+
test_utils.compare_tflite_torch(
|
181
|
+
edge_model,
|
182
|
+
pytorch_model,
|
183
|
+
decode_token,
|
184
|
+
decode_input_pos,
|
185
|
+
kv,
|
149
186
|
signature_name="decode",
|
150
|
-
|
187
|
+
atol=1e-5,
|
188
|
+
rtol=1e-5,
|
151
189
|
)
|
152
190
|
)
|
153
191
|
|
@@ -12,16 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
15
|
+
|
16
|
+
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
|
-
from ai_edge_torch.generative.examples.gemma import gemma
|
21
|
-
from ai_edge_torch.generative.examples.
|
22
|
-
from ai_edge_torch.generative.examples.
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.
|
20
|
+
from ai_edge_torch.generative.examples.gemma import gemma
|
21
|
+
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.phi import phi2
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache
|
24
|
+
from ai_edge_torch.generative.test import utils as test_utils
|
25
25
|
import numpy as np
|
26
26
|
import torch
|
27
27
|
|
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
|
|
55
55
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
56
56
|
tokens[0, :4] = idx
|
57
57
|
input_pos = torch.arange(0, 10)
|
58
|
-
|
59
|
-
|
58
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
59
|
+
|
60
|
+
edge_model = ai_edge_torch.convert(
|
61
|
+
model,
|
62
|
+
sample_kwargs={
|
63
|
+
"tokens": tokens,
|
64
|
+
"input_pos": input_pos,
|
65
|
+
"kv_cache": kv,
|
66
|
+
},
|
67
|
+
)
|
60
68
|
edge_model.set_interpreter_builder(
|
61
69
|
self._interpreter_builder(edge_model.tflite_model())
|
62
70
|
)
|
63
71
|
|
64
72
|
self.assertTrue(
|
65
|
-
|
73
|
+
test_utils.compare_tflite_torch(
|
66
74
|
edge_model,
|
67
75
|
model,
|
68
|
-
|
69
|
-
|
76
|
+
tokens,
|
77
|
+
input_pos,
|
78
|
+
kv,
|
79
|
+
signature_name="serving_default",
|
70
80
|
atol=1e-2,
|
71
81
|
rtol=1e-5,
|
72
82
|
)
|
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
|
|
85
95
|
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
86
96
|
prefill_tokens[0, :4] = idx
|
87
97
|
prefill_input_pos = torch.arange(0, 10)
|
98
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
88
99
|
|
89
100
|
edge_model = ai_edge_torch.signature(
|
90
|
-
"prefill",
|
101
|
+
"prefill",
|
102
|
+
model,
|
103
|
+
sample_kwargs={
|
104
|
+
"tokens": prefill_tokens,
|
105
|
+
"input_pos": prefill_input_pos,
|
106
|
+
"kv_cache": kv,
|
107
|
+
},
|
91
108
|
).convert()
|
92
109
|
edge_model.set_interpreter_builder(
|
93
110
|
self._interpreter_builder(edge_model.tflite_model())
|
94
111
|
)
|
95
112
|
|
96
113
|
self.assertTrue(
|
97
|
-
|
114
|
+
test_utils.compare_tflite_torch(
|
98
115
|
edge_model,
|
99
116
|
model,
|
100
|
-
|
117
|
+
prefill_tokens,
|
118
|
+
prefill_input_pos,
|
119
|
+
kv,
|
101
120
|
signature_name="prefill",
|
102
|
-
|
103
|
-
|
104
|
-
rtol=1e-5,
|
121
|
+
atol=1e-1,
|
122
|
+
rtol=1e-3,
|
105
123
|
)
|
106
124
|
)
|
107
125
|
|
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
|
|
117
135
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
118
136
|
tokens[0, :4] = idx
|
119
137
|
input_pos = torch.arange(0, 10)
|
120
|
-
|
121
|
-
|
138
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
139
|
+
|
140
|
+
edge_model = ai_edge_torch.convert(
|
141
|
+
pytorch_model,
|
142
|
+
sample_kwargs={
|
143
|
+
"tokens": tokens,
|
144
|
+
"input_pos": input_pos,
|
145
|
+
"kv_cache": kv,
|
146
|
+
},
|
147
|
+
)
|
122
148
|
edge_model.set_interpreter_builder(
|
123
149
|
self._interpreter_builder(edge_model.tflite_model())
|
124
150
|
)
|
125
151
|
|
126
152
|
self.assertTrue(
|
127
|
-
|
153
|
+
test_utils.compare_tflite_torch(
|
128
154
|
edge_model,
|
129
155
|
pytorch_model,
|
130
|
-
|
131
|
-
|
156
|
+
tokens,
|
157
|
+
input_pos,
|
158
|
+
kv,
|
159
|
+
signature_name="serving_default",
|
132
160
|
atol=1e-3,
|
133
161
|
rtol=1e-3,
|
134
162
|
)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utils for testing."""
|
17
|
+
|
18
|
+
from ai_edge_torch import model
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.lowertools import common_utils
|
21
|
+
import numpy as np
|
22
|
+
import torch
|
23
|
+
from torch.utils import _pytree as pytree
|
24
|
+
|
25
|
+
|
26
|
+
def compare_tflite_torch(
|
27
|
+
edge_model: model.Model,
|
28
|
+
torch_model: torch.nn.Module,
|
29
|
+
tokens: torch.Tensor,
|
30
|
+
input_pos: torch.Tensor,
|
31
|
+
kv_cache: kv_utils.KVCache,
|
32
|
+
signature_name: str,
|
33
|
+
atol: float = 1e-5,
|
34
|
+
rtol: float = 1e-5,
|
35
|
+
):
|
36
|
+
"""Compares torch models and TFLite models."""
|
37
|
+
values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
|
38
|
+
flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
|
39
|
+
torch_output = torch_model(tokens, input_pos, kv_cache)
|
40
|
+
|
41
|
+
input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
|
42
|
+
edge_output = edge_model(
|
43
|
+
signature_name=signature_name,
|
44
|
+
tokens=tokens.numpy(),
|
45
|
+
input_pos=input_pos.numpy(),
|
46
|
+
**input_kv_flatten,
|
47
|
+
)
|
48
|
+
|
49
|
+
return np.allclose(
|
50
|
+
edge_output["logits"],
|
51
|
+
torch_output["logits"].detach().numpy(),
|
52
|
+
atol=atol,
|
53
|
+
rtol=rtol,
|
54
|
+
)
|
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
|
|
167
167
|
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
|
168
168
|
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
|
169
169
|
lower_by_torch_xla2(torch.ops.aten.native_group_norm)
|
170
|
-
lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
|
171
170
|
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
|
172
171
|
lower_by_torch_xla2(torch.ops.aten.ne)
|
173
172
|
lower_by_torch_xla2(torch.ops.aten.neg)
|