ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  8. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  9. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  10. ai_edge_torch/generative/layers/attention.py +60 -63
  11. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  12. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  13. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  15. ai_edge_torch/generative/test/utils.py +54 -0
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
  19. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  20. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  21. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  22. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  24. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  25. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  26. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  28. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  29. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  30. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
@@ -12,72 +12,181 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # `nn.Module` which implements a KV cache.
16
15
 
17
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
16
+ """Utility functions for externalized KV Cache."""
17
+
18
+ import dataclasses
19
+ from typing import List, Tuple
20
+
21
+ from ai_edge_torch import hlfb
22
+ from ai_edge_torch.generative.layers import model_config
18
23
  import torch
19
- from torch import nn
24
+ import torch.utils._pytree as pytree
20
25
 
21
26
 
22
- class KVCache(nn.Module):
27
+ @dataclasses.dataclass
28
+ class KVCacheEntry:
29
+ """A single cache entry that includes K and V caches.
23
30
 
24
- def __init__(
25
- self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
26
- ):
27
- """Initializes the KVCache layer.
31
+ The chaches are built based on the provided config with the shape of
32
+ (batch_size=1, kv_cache_max, num_query_groups, head_dim).
33
+ """
28
34
 
29
- Args:
30
- batch_size (int): batch size. Currently only batch size 1 is supported.
31
- kv_cache_max (int): the max length of KV cache.
32
- n_heads (int): number of kv heads.
33
- head_dim (int): the head dimension size.
34
- enable_hlfb (bool): whether hlfb is enabled or not.
35
- """
36
- super().__init__()
37
- cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
38
- self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
39
- self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
40
- self.enable_hlfb = enable_hlfb
41
- self.kv_cache_max = kv_cache_max
35
+ k_cache: torch.Tensor
36
+ v_cache: torch.Tensor
42
37
 
43
- def update_cache(self, input_pos, k_val, v_val):
44
- """Update an entry in the KV cache.
38
+ @classmethod
39
+ def from_model_config(
40
+ cls,
41
+ config: model_config.ModelConfig,
42
+ dtype: torch.dtype = torch.float32,
43
+ device: torch.device = None,
44
+ ) -> "KVCacheEntry":
45
+ """Build an instance of the class based on model config."""
46
+ shape = (
47
+ 1, # Batch dimmension.
48
+ config.kv_cache_max,
49
+ config.attn_config.num_query_groups,
50
+ config.attn_config.head_dim,
51
+ )
52
+ k = torch.zeros(shape, dtype=dtype, device=device)
53
+ v = torch.zeros(shape, dtype=dtype, device=device)
54
+ obj = cls(k_cache=k, v_cache=v)
55
+ return obj
45
56
 
46
- Args:
47
- input_pos (torch.Tensor): the input position.
48
- k_val (torch.Tensor): the new `key` value.
49
- v_val (torch.Tensor): the new `value` value.
50
57
 
51
- Returns:
52
- The updated key and value tensor.
53
- """
54
- if self.enable_hlfb:
55
- return self.update_cache_with_hlfb(input_pos, k_val, v_val)
58
+ @dataclasses.dataclass
59
+ class KVCache:
60
+ """A utility class for holding KV cache entries per layer."""
56
61
 
57
- updated_k = self.k_cache.index_copy_(1, input_pos, k_val)
58
- updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
59
- # Here we need a clone otherwise dynamo export will fail.
60
- return torch.clone(updated_k), torch.clone(updated_v)
62
+ caches: Tuple[KVCacheEntry, ...]
61
63
 
62
- def update_cache_with_hlfb(self, input_pos, k_val, v_val):
63
- """Update an entry in the KV cache and enable high-level function boundary.
64
+ @classmethod
65
+ def from_model_config(
66
+ cls,
67
+ config: model_config.ModelConfig,
68
+ dtype: torch.dtype = torch.float32,
69
+ device: torch.device = None,
70
+ ) -> "KVCache":
71
+ """Build an instance of the class based on model config.
64
72
 
65
73
  Args:
66
- input_pos (torch.Tensor): the input position.
67
- k_val (torch.Tensor): the new `key` value.
68
- v_val (torch.Tensor): the new `value` value.
74
+ config (ModelConfig): Model config used for building the cache.
75
+ dtype (torch.dtype, optional): The data type of the cache tensor.
76
+ Defaults to torch.float32.
77
+ device (torch.device, optional): The device placement of the cache
78
+ tensors. Defaults to None.
69
79
 
70
80
  Returns:
71
- The updated key and value tensor.
81
+ KVCache: The created cache object.
72
82
  """
83
+ caches = [
84
+ KVCacheEntry.from_model_config(config, dtype, device)
85
+ for _ in range(config.num_layers)
86
+ ]
87
+ obj = cls(caches=tuple(caches))
88
+ return obj
73
89
 
74
- builder = StableHLOCompositeBuilder(
75
- name="odml.update_kv_cache", attr={"kv_cache_max": self.kv_cache_max}
76
- )
77
- k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
78
- self.k_cache, self.v_cache, input_pos, k_val, v_val
90
+ def flatten(self) -> List[torch.Tensor]:
91
+ """Flatten the cache entries into a list of tensors with order k_i, v_i."""
92
+ flattened, _ = _flatten_kvc(self)
93
+ return flattened
94
+
95
+
96
+ def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
97
+ flattened = []
98
+ flat_names = []
99
+ none_names = []
100
+ for i, kv_entry in enumerate(kvc.caches):
101
+ flattened.append(kv_entry.k_cache)
102
+ flat_names.append(f"k_{i}")
103
+ flattened.append(kv_entry.v_cache)
104
+ flat_names.append(f"v_{i}")
105
+ return flattened, [flat_names, none_names]
106
+
107
+
108
+ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
109
+ flattened, (flat_names, none_names) = _flatten_kvc(kvc)
110
+ return [
111
+ (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
112
+ ], flat_names
113
+
114
+
115
+ def _unflatten_kvc(
116
+ values: List[torch.Tensor], context: Tuple[List, List]
117
+ ) -> KVCache:
118
+ assert len(values) % 2 == 0, "Found odd number of K and V entries."
119
+ num_layers = len(values) // 2
120
+ flat_names = context[0]
121
+ kv_entries = []
122
+ for i in range(num_layers):
123
+ k_cache_idx = flat_names.index(f"k_{i}")
124
+ v_cache_idx = flat_names.index(f"v_{i}")
125
+ kv_entries.append(
126
+ KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
79
127
  )
80
- updated_k = k_cache.index_copy_(1, input_pos, k_val)
81
- updated_v = v_cache.index_copy_(1, input_pos, v_val)
82
- updated_k, updated_v = builder.mark_outputs(updated_k, updated_v)
83
- return updated_k, updated_v
128
+ obj = KVCache(tuple(kv_entries))
129
+ return obj
130
+
131
+
132
+ pytree.register_pytree_node(
133
+ KVCache,
134
+ _flatten_kvc,
135
+ _unflatten_kvc,
136
+ flatten_with_keys_fn=_flatten_kvc_with_keys,
137
+ serialized_type_name="",
138
+ )
139
+
140
+
141
+ def update(
142
+ cache: KVCacheEntry,
143
+ input_pos: torch.Tensor,
144
+ k_slice: torch.Tensor,
145
+ v_slice: torch.Tensor,
146
+ enable_hlfb: bool = True,
147
+ ) -> KVCacheEntry:
148
+ """Out of place update of Cache buffer.
149
+
150
+ Args:
151
+ cache (KVCacheEntry): The original cache buffer.
152
+ input_pos (torch.Tensor): The update slice positions.
153
+ k_slice (torch.Tensor): The K slice to be updated in the new cache.
154
+ v_slice (torch.Tensor): The V slice to be updated in the new cache.
155
+ enable_hlfb (bool, optional): Whether the op is annotated for export with
156
+ High Level Function Boundary. Defaults to True.
157
+
158
+ Returns:
159
+ KVCacheEntry: The updated KVCache entry based on the passed inputs.
160
+ """
161
+ update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
162
+ return update_func(cache, input_pos, k_slice, v_slice)
163
+
164
+
165
+ def _update_kv_base_impl(
166
+ cache: KVCacheEntry,
167
+ input_pos: torch.Tensor,
168
+ k_slice: torch.Tensor,
169
+ v_slice: torch.Tensor,
170
+ ) -> KVCacheEntry:
171
+ """Update the cache buffer without High Level Function Boundary annotation."""
172
+ k = cache.k_cache.index_copy(1, input_pos, k_slice)
173
+ v = cache.v_cache.index_copy(1, input_pos, v_slice)
174
+ updated_cache = KVCacheEntry(k, v)
175
+ return updated_cache
176
+
177
+
178
+ def _update_kv_hlfb_impl(
179
+ cache: KVCacheEntry,
180
+ input_pos: torch.Tensor,
181
+ k_slice: torch.Tensor,
182
+ v_slice: torch.Tensor,
183
+ ) -> KVCacheEntry:
184
+ """Update the cache buffer with High Level Function Boundary annotation."""
185
+ builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
186
+ k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
187
+ cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
188
+ )
189
+ k = k_cache.index_copy(1, input_pos, k_slice)
190
+ v = v_cache.index_copy(1, input_pos, v_slice)
191
+ k, v = builder.mark_outputs(k, v)
192
+ return KVCacheEntry(k, v)
@@ -12,19 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # A suite of tests to validate experimental external KV Cache layers and models.
16
15
 
17
- from ai_edge_torch.generative.examples.experimental.gemma import gemma
18
- from ai_edge_torch.generative.examples.experimental.phi import phi2
19
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
20
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
16
+ """A suite of tests to validate KV Cache layer."""
17
+
18
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
19
  import ai_edge_torch.generative.layers.model_config as cfg
22
20
  import torch
23
21
 
24
22
  from absl.testing import absltest as googletest
25
23
 
26
24
 
27
- class TestExternalKVLayers(googletest.TestCase):
25
+ class TestKVLayers(googletest.TestCase):
28
26
 
29
27
  def _get_test_config(
30
28
  self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -54,7 +52,7 @@ class TestExternalKVLayers(googletest.TestCase):
54
52
  num_query_groups=NUM_QG,
55
53
  kv_cache_max_len=KV_LEN,
56
54
  )
57
- kv = kv_utils.EKVCache.from_model_config(config)
55
+ kv = kv_utils.KVCache.from_model_config(config)
58
56
  entry = kv.caches[0]
59
57
  # single-slice update
60
58
  input_pos = torch.tensor([1])
@@ -88,14 +86,14 @@ class TestExternalKVLayers(googletest.TestCase):
88
86
  def test_serialization(self):
89
87
  class TestModel(torch.nn.Module):
90
88
 
91
- def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
89
+ def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
92
90
  updated_kv_entries = [
93
91
  kv_utils.KVCacheEntry(
94
92
  torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
95
93
  )
96
94
  for entry in kv.caches
97
95
  ]
98
- return kv_utils.EKVCache(updated_kv_entries)
96
+ return kv_utils.KVCache(updated_kv_entries)
99
97
 
100
98
  N = 1
101
99
  HEAD_DIM = 2
@@ -107,7 +105,7 @@ class TestExternalKVLayers(googletest.TestCase):
107
105
  num_query_groups=NUM_QG,
108
106
  kv_cache_max_len=KV_LEN,
109
107
  )
110
- kv = kv_utils.EKVCache.from_model_config(config)
108
+ kv = kv_utils.KVCache.from_model_config(config)
111
109
  model = TestModel()
112
110
  exported_program = torch.export.export(model, (kv,))
113
111
  input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +114,5 @@ class TestExternalKVLayers(googletest.TestCase):
116
114
  self.assertEqual(input_specs[1].arg.name, "kv_v_0")
117
115
 
118
116
 
119
- class TestExternalKVModels(googletest.TestCase):
120
-
121
- def test_can_build_gemma(self):
122
- gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
123
-
124
- def test_can_build_phi2(self):
125
- phi2.define_and_run(checkpoint_path=None, test_model=True)
126
-
127
- def test_can_build_tinyllama(self):
128
- tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
129
-
130
-
131
117
  if __name__ == "__main__":
132
118
  googletest.main()
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
20
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
23
21
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
22
+ from ai_edge_torch.generative.layers import kv_cache
23
+ from ai_edge_torch.generative.test import utils as test_utils
25
24
  import numpy as np
26
25
  import torch
27
26
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
49
48
  )
50
49
  def test_toy_model_with_kv_cache(self):
51
50
  config = toy_model_with_kv_cache.get_model_config()
52
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
53
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
51
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
52
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
54
53
  [10], dtype=torch.int64
55
54
  )
56
-
57
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
55
+ kv = kv_cache.KVCache.from_model_config(config)
56
+
57
+ edge_model = ai_edge_torch.convert(
58
+ pytorch_model,
59
+ sample_kwargs={
60
+ "tokens": tokens,
61
+ "input_pos": input_pos,
62
+ "kv_cache": kv,
63
+ },
64
+ )
58
65
  edge_model.set_interpreter_builder(
59
66
  self._interpreter_builder(edge_model.tflite_model())
60
67
  )
61
68
 
62
69
  self.assertTrue(
63
- model_coverage.compare_tflite_torch(
70
+ test_utils.compare_tflite_torch(
64
71
  edge_model,
65
72
  pytorch_model,
66
- (idx, input_pos),
67
- num_valid_inputs=1,
73
+ tokens,
74
+ input_pos,
75
+ kv,
76
+ signature_name="serving_default",
68
77
  atol=1e-5,
69
78
  rtol=1e-5,
70
79
  )
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
77
86
  def test_toy_model_with_kv_cache_with_hlfb(self):
78
87
  config = toy_model_with_kv_cache.get_model_config()
79
88
  config.enable_hlfb = True
80
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
81
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
89
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
90
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
82
91
  [10], dtype=torch.int64
83
92
  )
84
-
85
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
93
+ kv = kv_cache.KVCache.from_model_config(config)
94
+
95
+ edge_model = ai_edge_torch.convert(
96
+ pytorch_model,
97
+ sample_kwargs={
98
+ "tokens": tokens,
99
+ "input_pos": input_pos,
100
+ "kv_cache": kv,
101
+ },
102
+ )
86
103
  edge_model.set_interpreter_builder(
87
104
  self._interpreter_builder(edge_model.tflite_model())
88
105
  )
89
106
 
90
107
  self.assertTrue(
91
- model_coverage.compare_tflite_torch(
108
+ test_utils.compare_tflite_torch(
92
109
  edge_model,
93
110
  pytorch_model,
94
- (idx, input_pos),
95
- num_valid_inputs=1,
111
+ tokens,
112
+ input_pos,
113
+ kv,
114
+ signature_name="serving_default",
96
115
  atol=1e-5,
97
116
  rtol=1e-5,
98
117
  )
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
117
136
  decode_token = torch.tensor([[1]], dtype=torch.long)
118
137
  decode_input_pos = torch.tensor([5], dtype=torch.int64)
119
138
 
139
+ kv = kv_cache.KVCache.from_model_config(config)
140
+
120
141
  edge_model = (
121
142
  ai_edge_torch.signature(
122
- "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
143
+ "prefill",
144
+ pytorch_model,
145
+ sample_kwargs={
146
+ "tokens": prefill_tokens,
147
+ "input_pos": prefill_input_pos,
148
+ "kv_cache": kv,
149
+ },
150
+ )
151
+ .signature(
152
+ "decode",
153
+ pytorch_model,
154
+ sample_kwargs={
155
+ "tokens": decode_token,
156
+ "input_pos": decode_input_pos,
157
+ "kv_cache": kv,
158
+ },
123
159
  )
124
- .signature("decode", pytorch_model, (decode_token, decode_input_pos))
125
160
  .convert()
126
161
  )
127
162
  edge_model.set_interpreter_builder(
128
163
  self._interpreter_builder(edge_model.tflite_model())
129
164
  )
130
165
 
131
- copied_model = copy.deepcopy(pytorch_model)
132
- copied_edge = copy.deepcopy(edge_model)
133
-
134
166
  self.assertTrue(
135
- model_coverage.compare_tflite_torch(
167
+ test_utils.compare_tflite_torch(
136
168
  edge_model,
137
169
  pytorch_model,
138
- (prefill_tokens, prefill_input_pos),
170
+ prefill_tokens,
171
+ prefill_input_pos,
172
+ kv,
139
173
  signature_name="prefill",
140
- num_valid_inputs=1,
174
+ atol=1e-5,
175
+ rtol=1e-5,
141
176
  )
142
177
  )
143
178
 
144
179
  self.assertTrue(
145
- model_coverage.compare_tflite_torch(
146
- copied_edge,
147
- copied_model,
148
- (decode_token, decode_input_pos),
180
+ test_utils.compare_tflite_torch(
181
+ edge_model,
182
+ pytorch_model,
183
+ decode_token,
184
+ decode_input_pos,
185
+ kv,
149
186
  signature_name="decode",
150
- num_valid_inputs=1,
187
+ atol=1e-5,
188
+ rtol=1e-5,
151
189
  )
152
190
  )
153
191
 
@@ -12,16 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
23
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
20
+ from ai_edge_torch.generative.examples.gemma import gemma
21
+ from ai_edge_torch.generative.examples.gemma import gemma2
22
+ from ai_edge_torch.generative.examples.phi import phi2
23
+ from ai_edge_torch.generative.layers import kv_cache
24
+ from ai_edge_torch.generative.test import utils as test_utils
25
25
  import numpy as np
26
26
  import torch
27
27
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
55
55
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
56
56
  tokens[0, :4] = idx
57
57
  input_pos = torch.arange(0, 10)
58
-
59
- edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
58
+ kv = kv_cache.KVCache.from_model_config(config)
59
+
60
+ edge_model = ai_edge_torch.convert(
61
+ model,
62
+ sample_kwargs={
63
+ "tokens": tokens,
64
+ "input_pos": input_pos,
65
+ "kv_cache": kv,
66
+ },
67
+ )
60
68
  edge_model.set_interpreter_builder(
61
69
  self._interpreter_builder(edge_model.tflite_model())
62
70
  )
63
71
 
64
72
  self.assertTrue(
65
- model_coverage.compare_tflite_torch(
73
+ test_utils.compare_tflite_torch(
66
74
  edge_model,
67
75
  model,
68
- (tokens, input_pos),
69
- num_valid_inputs=1,
76
+ tokens,
77
+ input_pos,
78
+ kv,
79
+ signature_name="serving_default",
70
80
  atol=1e-2,
71
81
  rtol=1e-5,
72
82
  )
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
85
95
  prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
96
  prefill_tokens[0, :4] = idx
87
97
  prefill_input_pos = torch.arange(0, 10)
98
+ kv = kv_cache.KVCache.from_model_config(config)
88
99
 
89
100
  edge_model = ai_edge_torch.signature(
90
- "prefill", model, (prefill_tokens, prefill_input_pos)
101
+ "prefill",
102
+ model,
103
+ sample_kwargs={
104
+ "tokens": prefill_tokens,
105
+ "input_pos": prefill_input_pos,
106
+ "kv_cache": kv,
107
+ },
91
108
  ).convert()
92
109
  edge_model.set_interpreter_builder(
93
110
  self._interpreter_builder(edge_model.tflite_model())
94
111
  )
95
112
 
96
113
  self.assertTrue(
97
- model_coverage.compare_tflite_torch(
114
+ test_utils.compare_tflite_torch(
98
115
  edge_model,
99
116
  model,
100
- (prefill_tokens, prefill_input_pos),
117
+ prefill_tokens,
118
+ prefill_input_pos,
119
+ kv,
101
120
  signature_name="prefill",
102
- num_valid_inputs=1,
103
- atol=1e-2,
104
- rtol=1e-5,
121
+ atol=1e-1,
122
+ rtol=1e-3,
105
123
  )
106
124
  )
107
125
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
117
135
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
118
136
  tokens[0, :4] = idx
119
137
  input_pos = torch.arange(0, 10)
120
-
121
- edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
138
+ kv = kv_cache.KVCache.from_model_config(config)
139
+
140
+ edge_model = ai_edge_torch.convert(
141
+ pytorch_model,
142
+ sample_kwargs={
143
+ "tokens": tokens,
144
+ "input_pos": input_pos,
145
+ "kv_cache": kv,
146
+ },
147
+ )
122
148
  edge_model.set_interpreter_builder(
123
149
  self._interpreter_builder(edge_model.tflite_model())
124
150
  )
125
151
 
126
152
  self.assertTrue(
127
- model_coverage.compare_tflite_torch(
153
+ test_utils.compare_tflite_torch(
128
154
  edge_model,
129
155
  pytorch_model,
130
- (tokens, input_pos),
131
- num_valid_inputs=1,
156
+ tokens,
157
+ input_pos,
158
+ kv,
159
+ signature_name="serving_default",
132
160
  atol=1e-3,
133
161
  rtol=1e-3,
134
162
  )
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common utils for testing."""
17
+
18
+ from ai_edge_torch import model
19
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
+ from ai_edge_torch.lowertools import common_utils
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils import _pytree as pytree
24
+
25
+
26
+ def compare_tflite_torch(
27
+ edge_model: model.Model,
28
+ torch_model: torch.nn.Module,
29
+ tokens: torch.Tensor,
30
+ input_pos: torch.Tensor,
31
+ kv_cache: kv_utils.KVCache,
32
+ signature_name: str,
33
+ atol: float = 1e-5,
34
+ rtol: float = 1e-5,
35
+ ):
36
+ """Compares torch models and TFLite models."""
37
+ values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
38
+ flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
39
+ torch_output = torch_model(tokens, input_pos, kv_cache)
40
+
41
+ input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
42
+ edge_output = edge_model(
43
+ signature_name=signature_name,
44
+ tokens=tokens.numpy(),
45
+ input_pos=input_pos.numpy(),
46
+ **input_kv_flatten,
47
+ )
48
+
49
+ return np.allclose(
50
+ edge_output["logits"],
51
+ torch_output["logits"].detach().numpy(),
52
+ atol=atol,
53
+ rtol=rtol,
54
+ )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240910"
16
+ __version__ = "0.3.0.dev20240911"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240910
3
+ Version: 0.3.0.dev20240911
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI