ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/kv_cache.py
CHANGED
@@ -12,72 +12,181 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.
 
-
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-
+import torch.utils._pytree as pytree
 
 
-
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.
 
-
-
-
-    """Initializes the KVCache layer.
+  The caches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """
 
-
-
-      kv_cache_max (int): the max length of KV cache.
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor
 
-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (
+        1,  # Batch dimension.
+        config.kv_cache_max,
+        config.attn_config.num_query_groups,
+        config.attn_config.head_dim,
+    )
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj
 
-    Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.
 
-
-
-
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""
 
-
-    updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
-    # Here we need a clone otherwise dynamo export will fail.
-    return torch.clone(updated_k), torch.clone(updated_v)
+  caches: Tuple[KVCacheEntry, ...]
 
-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.
 
     Args:
-
-
-
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.
 
     Returns:
-
+      KVCache: The created cache object.
     """
+    caches = [
+        KVCacheEntry.from_model_config(config, dtype, device)
+        for _ in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj
 
-
-
-    )
-
-
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
     )
-
-
-
-
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos, k_slice)
+  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos, k_slice)
+  v = v_cache.index_copy(1, input_pos, v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
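The rewritten kv_cache.py drops the stateful `nn.Module` buffer in favor of plain dataclasses plus an out-of-place `update()` helper. Below is a minimal sketch of the resulting call pattern, using only names visible in the hunk above; the tensor sizes and `enable_hlfb=False` (to stay on the plain `index_copy` path outside export) are illustrative choices, not part of the package.

```python
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils

# One layer's cache, shaped (batch=1, kv_cache_max, num_query_groups, head_dim);
# the sizes here are arbitrary.
entry = kv_utils.KVCacheEntry(
    k_cache=torch.zeros(1, 16, 4, 8),
    v_cache=torch.zeros(1, 16, 4, 8),
)

# Out-of-place update: write a one-token K/V slice at position 3 and get a
# fresh entry back; the original buffers stay untouched.
updated = kv_utils.update(
    entry,
    input_pos=torch.tensor([3]),
    k_slice=torch.ones(1, 1, 4, 8),
    v_slice=torch.ones(1, 1, 4, 8),
    enable_hlfb=False,  # plain index_copy path, for eager execution
)
assert updated.k_cache[0, 3].sum() == 4 * 8
assert entry.k_cache.sum() == 0  # unchanged: the update is out of place
```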
ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}
RENAMED
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A suite of tests to validate experimental external KV Cache layers and models.
 
-
-
-from ai_edge_torch.generative.
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+"""A suite of tests to validate KV Cache layer."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
 from absl.testing import absltest as googletest
 
 
-class TestExternalKVLayers(googletest.TestCase):
+class TestKVLayers(googletest.TestCase):
 
   def _get_test_config(
       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -54,7 +52,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
@@ -88,14 +86,14 @@ class TestExternalKVLayers(googletest.TestCase):
   def test_serialization(self):
     class TestModel(torch.nn.Module):
 
-      def forward(self, kv: kv_utils.
+      def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
         updated_kv_entries = [
            kv_utils.KVCacheEntry(
                torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
            )
            for entry in kv.caches
        ]
-        return kv_utils.
+        return kv_utils.KVCache(updated_kv_entries)
 
     N = 1
     HEAD_DIM = 2
@@ -107,7 +105,7 @@ class TestExternalKVLayers(googletest.TestCase):
        num_query_groups=NUM_QG,
        kv_cache_max_len=KV_LEN,
    )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
    model = TestModel()
    exported_program = torch.export.export(model, (kv,))
    input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +114,5 @@ class TestExternalKVLayers(googletest.TestCase):
    self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
 
-class TestExternalKVModels(googletest.TestCase):
-
-  def test_can_build_gemma(self):
-    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-  def test_can_build_phi2(self):
-    phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-  def test_can_build_tinyllama(self):
-    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
 if __name__ == "__main__":
   googletest.main()
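The `register_pytree_node` call in kv_cache.py is what makes the serialization test above work: a `KVCache` passed to `torch.export.export` decomposes into per-layer tensor leaves named `k_0`, `v_0`, ..., prefixed with the argument name, hence the `kv_k_0`/`kv_v_0` input specs. A small round-trip sketch, with arbitrary shapes:

```python
import torch
import torch.utils._pytree as pytree

# Importing the module runs the register_pytree_node(...) call above.
from ai_edge_torch.generative.layers import kv_cache as kv_utils

entry = kv_utils.KVCacheEntry(
    k_cache=torch.zeros(1, 8, 2, 4), v_cache=torch.zeros(1, 8, 2, 4)
)
kv = kv_utils.KVCache(caches=(entry,))

# Two leaves per layer, ordered k_0, v_0, k_1, v_1, ...
leaves, spec = pytree.tree_flatten(kv)
assert len(leaves) == 2

# ...and the treespec alone is enough to rebuild the dataclass.
restored = pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, kv_utils.KVCache)
```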
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.
-
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
        [10], dtype=torch.int64
    )
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
        )
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
  def test_toy_model_with_kv_cache_with_hlfb(self):
    config = toy_model_with_kv_cache.get_model_config()
    config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.
-
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
        [10], dtype=torch.int64
    )
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
        )
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
    decode_token = torch.tensor([[1]], dtype=torch.long)
    decode_input_pos = torch.tensor([5], dtype=torch.int64)
 
+    kv = kv_cache.KVCache.from_model_config(config)
+
    edge_model = (
        ai_edge_torch.signature(
-            "prefill",
+            "prefill",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": prefill_tokens,
+                "input_pos": prefill_input_pos,
+                "kv_cache": kv,
+            },
+        )
+        .signature(
+            "decode",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": decode_token,
+                "input_pos": decode_input_pos,
+                "kv_cache": kv,
+            },
        )
-        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
        .convert()
    )
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
-    copied_model = copy.deepcopy(pytorch_model)
-    copied_edge = copy.deepcopy(edge_model)
-
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-
+            atol=1e-5,
+            rtol=1e-5,
        )
    )
 
    self.assertTrue(
-
-
-
-
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            decode_token,
+            decode_input_pos,
+            kv,
            signature_name="decode",
-
+            atol=1e-5,
+            rtol=1e-5,
        )
    )
 
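These tests exercise the keyword-based conversion path, where sample inputs are passed as `sample_kwargs` rather than a positional tuple. A self-contained sketch of that call pattern with a hypothetical stand-in module; only `convert(..., sample_kwargs=...)` and the numpy-kwarg invocation style come from this diff:

```python
import torch

import ai_edge_torch


class Scale(torch.nn.Module):
  """Hypothetical two-input module standing in for the suite's models."""

  def forward(self, x: torch.Tensor, gain: torch.Tensor) -> torch.Tensor:
    return x * gain


pytorch_model = Scale().eval()
sample_x = torch.randn(1, 4)
sample_gain = torch.tensor([2.0])

# Keyword-based conversion, as in the tests above: every sample input is
# passed by name instead of as a positional tuple.
edge_model = ai_edge_torch.convert(
    pytorch_model,
    sample_kwargs={"x": sample_x, "gain": sample_gain},
)

# Invocation mirrors the conversion: numpy inputs keyed by the same names.
out = edge_model(x=sample_x.numpy(), gain=sample_gain.numpy())
```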
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.
-from ai_edge_torch.
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-2,
            rtol=1e-5,
        )
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
    prefill_tokens[0, :4] = idx
    prefill_input_pos = torch.arange(0, 10)
+    kv = kv_cache.KVCache.from_model_config(config)
 
    edge_model = ai_edge_torch.signature(
-        "prefill",
+        "prefill",
+        model,
+        sample_kwargs={
+            "tokens": prefill_tokens,
+            "input_pos": prefill_input_pos,
+            "kv_cache": kv,
+        },
    ).convert()
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-
-
-            rtol=1e-5,
+            atol=1e-1,
+            rtol=1e-3,
        )
    )
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)
-
-
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
 
    self.assertTrue(
-
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-
-
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-3,
            rtol=1e-3,
        )
ai_edge_torch/generative/test/utils.py
ADDED
@@ -0,0 +1,54 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utils for testing."""
+
+from ai_edge_torch import model
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.lowertools import common_utils
+import numpy as np
+import torch
+from torch.utils import _pytree as pytree
+
+
+def compare_tflite_torch(
+    edge_model: model.Model,
+    torch_model: torch.nn.Module,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+    signature_name: str,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+):
+  """Compares torch models and TFLite models."""
+  values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
+  flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+  torch_output = torch_model(tokens, input_pos, kv_cache)
+
+  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  edge_output = edge_model(
+      signature_name=signature_name,
+      tokens=tokens.numpy(),
+      input_pos=input_pos.numpy(),
+      **input_kv_flatten,
+  )
+
+  return np.allclose(
+      edge_output["logits"],
+      torch_output["logits"].detach().numpy(),
+      atol=atol,
+      rtol=rtol,
+  )
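A usage sketch of the helper above; `edge_model`, `pytorch_model`, and `config` stand for any converted/eager model pair from these suites:

```python
import torch

from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.test import utils as test_utils


def check_parity(edge_model, pytorch_model, config):
  """Asserts eager/TFLite parity for one model pair (args are stand-ins)."""
  kv = kv_cache.KVCache.from_model_config(config)  # zero-initialized cache
  assert test_utils.compare_tflite_torch(
      edge_model,
      pytorch_model,
      torch.tensor([[1]], dtype=torch.long),  # tokens
      torch.tensor([10], dtype=torch.int64),  # input_pos
      kv,
      signature_name="serving_default",
      atol=1e-5,
      rtol=1e-5,
  )
```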
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240910
+Version: 0.3.0.dev20240911
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI