ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
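Across these files, the central change is that the generative stack's KV cache moves from module-internal state (the old "experimental external KV cache") to an explicit `kv_cache` input that is threaded through `ai_edge_torch.convert` and `ai_edge_torch.signature` as a keyword sample input. A minimal sketch of the new conversion flow, pieced together from the test diffs below; the `tiny_llama` module, `get_fake_model_config()`, and `TinyLlama` are real names used by the tests, but the end-to-end flow and the output path are illustrative, not part of this diff:

```python
# Sketch of the externalized-KV-cache conversion flow, inferred from the
# test diffs below. The export path is made up for illustration.
import ai_edge_torch
import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache

config = tiny_llama.get_fake_model_config()
pytorch_model = tiny_llama.TinyLlama(config).eval()

tokens = torch.full((1, 10), 0, dtype=torch.int)
input_pos = torch.arange(0, 10, dtype=torch.int)
kv = kv_cache.KVCache.from_model_config(config)  # the cache is now an explicit input

edge_model = ai_edge_torch.convert(
    pytorch_model,
    sample_kwargs={"tokens": tokens, "input_pos": input_pos, "kv_cache": kv},
)
edge_model.export("/tmp/tiny_llama.tflite")
```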
ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}

```diff
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A suite of tests to validate experimental external KV Cache layers and models.
 
-
-
-from ai_edge_torch.generative.
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+"""A suite of tests to validate KV Cache layer."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
 from absl.testing import absltest as googletest
 
 
-class TestExternalKVLayers(googletest.TestCase):
+class TestKVLayers(googletest.TestCase):
 
   def _get_test_config(
       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -32,14 +30,16 @@ class TestExternalKVLayers(googletest.TestCase):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
     )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
     config = cfg.ModelConfig(
         kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
-
+        block_configs=block_config,
         num_layers=num_layers,
         max_seq_len=None,
         vocab_size=None,
-        ff_config=None,
     )
     return config
 
@@ -54,7 +54,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
@@ -88,14 +88,14 @@ class TestExternalKVLayers(googletest.TestCase):
   def test_serialization(self):
     class TestModel(torch.nn.Module):
 
-      def forward(self, kv: kv_utils.
+      def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
         updated_kv_entries = [
             kv_utils.KVCacheEntry(
                 torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
             )
             for entry in kv.caches
         ]
-        return kv_utils.
+        return kv_utils.KVCache(updated_kv_entries)
 
     N = 1
     HEAD_DIM = 2
@@ -107,7 +107,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     model = TestModel()
     exported_program = torch.export.export(model, (kv,))
     input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +116,5 @@ class TestExternalKVLayers(googletest.TestCase):
     self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
 
-class TestExternalKVModels(googletest.TestCase):
-
-  def test_can_build_gemma(self):
-    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-  def test_can_build_phi2(self):
-    phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-  def test_can_build_tinyllama(self):
-    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
 if __name__ == "__main__":
   googletest.main()
```
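The serialization test above works because the new `KVCache` flattens under `torch.export` into per-layer tensor inputs named `kv_k_<i>`/`kv_v_<i>`. A standalone sketch mirroring that test; the config values are the test's fakes and `ZeroKV` is an illustrative name for the test's inner module:

```python
# Sketch mirroring test_serialization above: export a module that takes the
# new KVCache and observe the flattened kv_k_0/kv_v_0 input names.
import torch
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.layers import kv_cache as kv_utils

attn_config = cfg.AttentionConfig(num_heads=1, head_dim=2, num_query_groups=1)
block_config = cfg.TransformerBlockConfig(attn_config=attn_config, ff_config=None)
config = cfg.ModelConfig(
    kv_cache_max_len=8,
    embedding_dim=2,
    block_configs=block_config,
    num_layers=1,
    max_seq_len=None,
    vocab_size=None,
)


class ZeroKV(torch.nn.Module):

  def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
    # Return a fresh cache so the exported graph has real outputs.
    entries = [
        kv_utils.KVCacheEntry(
            torch.zeros_like(e.k_cache), torch.zeros_like(e.v_cache)
        )
        for e in kv.caches
    ]
    return kv_utils.KVCache(entries)


kv = kv_utils.KVCache.from_model_config(config)
ep = torch.export.export(ZeroKV(), (kv,))
print([spec.arg.name for spec in ep.graph_signature.input_specs])
# e.g. ['kv_k_0', 'kv_v_0']
```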
ai_edge_torch/generative/test/test_loader.py

```diff
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.
+    model = tiny_llama.TinyLlama(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
```
ai_edge_torch/generative/test/test_model_conversion.py

```diff
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -43,28 +42,32 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-
-
-
-
-
-
-
-
-
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
+    )
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
     )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
        )
@@ -74,83 +77,95 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-
-
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-  config =
-
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)
 
+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.
-    decode_input_pos = torch.tensor([5], dtype=torch.
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)
+
+    kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = (
         ai_edge_torch.signature(
-            "prefill",
+            "prefill",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": prefill_tokens,
+                "input_pos": prefill_input_pos,
+                "kv_cache": kv,
+            },
+        )
+        .signature(
+            "decode",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": decode_token,
+                "input_pos": decode_input_pos,
+                "kv_cache": kv,
+            },
         )
-        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    copied_model = copy.deepcopy(pytorch_model)
-    copied_edge = copy.deepcopy(edge_model)
-
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-
+            atol=atol,
+            rtol=atol,
        )
    )
 
     self.assertTrue(
-
-
-
-
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            decode_token,
+            decode_input_pos,
+            kv,
            signature_name="decode",
-
+            atol=atol,
+            rtol=atol,
        )
    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+
 
 if __name__ == "__main__":
   googletest.main()
```
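The `_test_multisig_model` helper above is the clearest statement of the new multi-signature API: chain `signature()` calls, each with its own `sample_kwargs`, then call `convert()` once. Extracted from the hunk into a condensed, runnable sketch; TinyLlama stands in for any of the example models, and the shapes and dtypes simply mirror the test:

```python
# Sketch of the chained prefill/decode signature conversion used by
# _test_multisig_model above. Treat the flow as illustrative.
import ai_edge_torch
import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache

config = tiny_llama.get_fake_model_config()
model = tiny_llama.TinyLlama(config).eval()

prefill_tokens = torch.full((1, 10), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, 10, dtype=torch.int)
decode_token = torch.tensor([[1]], dtype=torch.int)
decode_input_pos = torch.tensor([5], dtype=torch.int)
kv = kv_cache.KVCache.from_model_config(config)

edge_model = (
    ai_edge_torch.signature(
        "prefill",
        model,
        sample_kwargs={
            "tokens": prefill_tokens,
            "input_pos": prefill_input_pos,
            "kv_cache": kv,
        },
    )
    .signature(
        "decode",
        model,
        sample_kwargs={
            "tokens": decode_token,
            "input_pos": decode_input_pos,
            "kv_cache": kv,
        },
    )
    .convert()
)
```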
ai_edge_torch/generative/test/test_model_conversion_large.py

```diff
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -43,32 +45,36 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config)
 
-    edge_model = ai_edge_torch.
+    edge_model = ai_edge_torch.signature(
+        signature_name,
+        model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-
-
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
        )
    )
 
@@ -76,34 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-  config =
-
-
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill", model, (prefill_tokens, prefill_input_pos)
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
 
-
-
-
-
-
-
-
-
-        rtol=1e-5,
-    )
-  )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -112,27 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
-
-
-
-
-
-
-
-
-
-
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
 
 if __name__ == "__main__":
```
ai_edge_torch/generative/test/test_quantize.py

```diff
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
```
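The quantization tests only switch their sample inputs to `dtype=torch.int`, but they also document the quantized-conversion entry point. A sketch of that flow, using the toy model and recipe named in the hunks above; everything here appears in the diff, only the standalone arrangement is illustrative:

```python
# Sketch: dynamic int8 quantized conversion, as exercised by the quantize
# tests above. Only the dtype=torch.int inputs are new in this release.
import ai_edge_torch
import torch
from ai_edge_torch.generative.examples.test_models import toy_model
from ai_edge_torch.generative.quantize import quant_recipes

config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModel(config)
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
input_pos = torch.arange(0, 100, dtype=torch.int)

quant_config = quant_recipes.full_int8_dynamic_recipe()
quantized_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_config
)
```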
ai_edge_torch/generative/test/utils.py

```diff
@@ -0,0 +1,54 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utils for testing."""
+
+from ai_edge_torch import model
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.lowertools import common_utils
+import numpy as np
+import torch
+from torch.utils import _pytree as pytree
+
+
+def compare_tflite_torch(
+    edge_model: model.Model,
+    torch_model: torch.nn.Module,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+    signature_name: str,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+):
+  """Compares torch models and TFLite models."""
+  values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
+  flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+  torch_output = torch_model(tokens, input_pos, kv_cache)
+
+  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  edge_output = edge_model(
+      signature_name=signature_name,
+      tokens=tokens.numpy(),
+      input_pos=input_pos.numpy(),
+      **input_kv_flatten,
+  )
+
+  return np.allclose(
+      edge_output["logits"],
+      torch_output["logits"].detach().numpy(),
+      atol=atol,
+      rtol=rtol,
+  )
```
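A usage sketch for the new helper, mirroring how the conversion tests call it. It reuses the `edge_model`, `pytorch_model`, `tokens`, `input_pos`, and `kv` objects built in the conversion sketch after the file list above, so it is not standalone:

```python
# Sketch: validating a converted model with the new test utility.
# Assumes the variables from the conversion sketch near the top of this diff.
from ai_edge_torch.generative.test import utils as test_utils

assert test_utils.compare_tflite_torch(
    edge_model,
    pytorch_model,
    tokens,
    input_pos,
    kv,
    signature_name="serving_default",
    atol=1e-5,
    rtol=1e-5,
)
```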