ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
  12. ai_edge_torch/generative/layers/attention.py +60 -63
  13. ai_edge_torch/generative/layers/builder.py +4 -2
  14. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  15. ai_edge_torch/generative/layers/model_config.py +1 -0
  16. ai_edge_torch/generative/layers/normalization.py +158 -0
  17. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  18. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  19. ai_edge_torch/generative/test/test_loader.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  21. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  22. ai_edge_torch/generative/test/utils.py +54 -0
  23. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  24. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  25. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  26. ai_edge_torch/version.py +1 -1
  27. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
  29. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  30. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  31. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  32. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  33. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  34. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  35. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  36. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  37. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  38. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  39. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
  42. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def convert_gemma_to_tflite(
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """Converting a Gemma 2 2B model to multi-signature
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.

   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
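The conversion now registers two signatures ('prefill' and 'decode') over the same weights, with the KV cache passed explicitly as a sample input. Below is a minimal sketch of loading the exported artifact and discovering each signature's inputs, assuming TensorFlow is installed; the names of the flattened KV-cache tensors are converter-specific and are not shown in this diff.

    # Sketch: inspect the two signatures of the exported model.
    # The model path matches the default export name from the diff above.
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(
        model_path='/tmp/gemma2_q8_seq512_ekv1024.tflite'
    )
    prefill = interpreter.get_signature_runner('prefill')
    decode = interpreter.get_signature_runner('decode')

    # Each runner is called with keyword arguments named after the signature
    # inputs ('tokens', 'input_pos', plus the flattened kv_cache tensors,
    # whose exact names are assigned by the converter).
    print(prefill.get_input_details().keys())
    print(decode.get_input_details().keys())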
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a Gemma model.
+
+"""Example of building a Gemma model."""

 import os
-from pathlib import Path
+import pathlib

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
     )
     self.config = config

-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]

     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}


 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model


-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")

   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )


 if __name__ == "__main__":
-  define_and_run_2b()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
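Across the example models in this diff, forward() now takes the KV cache as an argument and returns a dict carrying both logits and an updated cache, instead of mutating buffers internally. A sketch of what eager decoding against this contract looks like; the checkpoint path, prompt, and greedy sampling are illustrative, not from the diff.

    # Sketch: token-by-token decoding with the functional KV cache.
    import torch
    from ai_edge_torch.generative.examples.gemma import gemma
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
    kv = kv_utils.KVCache.from_model_config(model.config)

    prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
    out = model.forward(prompt, torch.arange(0, 4), kv)
    kv = out['kv_cache']  # thread the updated cache back in

    pos = 4
    for _ in range(16):
        next_tok = out['logits'][:, -1, :].argmax(dim=-1, keepdim=True)
        out = model.forward(next_tok, torch.tensor([pos]), kv)
        kv = out['kv_cache']
        pos += 1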
ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building the Gemma2 2B model.
+
+"""Example of building a Gemma2 model."""

 import os
-from pathlib import Path
+import pathlib
 from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.

     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
-      output activation from this transformer block.
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """

     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv


 class Gemma2(nn.Module):
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
     return self.mask_cache.index_select(2, input_pos)

   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)

     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       mask = self.get_attention_mask(i, input_pos)
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-    return res
+
+    return {"logits": res, "kv_cache": updated_kv_cache}


 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model


-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-  out = model.forward(tokens, input_pos)
-  out_final = out[0, 8, :]
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)


 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-  define_and_run_2b()
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
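The unchanged context lines just before the new return statement are Gemma2's final-logit soft-capping, worth spelling out since the rewritten forward() now returns right after it. The three `res` assignments are equivalent to the one-liner below, which smoothly bounds every logit to the open interval (-softcap, +softcap).

    # Equivalent form of the soft-capping kept in the hunk above.
    import torch

    def soft_cap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
      # tanh saturates at +/-1, so the output saturates at +/-softcap.
      return softcap * torch.tanh(logits / softcap)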
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""

 import os
-from pathlib import Path
+import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.phi import phi2
-from ai_edge_torch.generative.layers.experimental import ekv_cache
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Phi-2 model to multi-signature
+  """Converts a Phi-2 model to multi-signature tflite model.

-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
ai_edge_torch/generative/examples/phi/phi2.py
@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""

 import os
-from pathlib import Path
-from typing import Tuple
+import pathlib

+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn

-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -89,13 +85,17 @@ class Phi2(nn.Module):
     self,
     tokens: torch.Tensor,
     input_pos: torch.Tensor,
-    kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +111,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config


-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model


-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""

+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )


 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
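Note that build_model() no longer takes a test_model flag, so callers that relied on it (such as the tests renamed in this diff) can construct the tiny fake-config model directly. A sketch using only names visible in this diff; the random-weight behavior when no checkpoint is loaded is an assumption:

    # Sketch: build a small test model without loading a checkpoint.
    from ai_edge_torch.generative.examples.phi import phi2

    config = phi2.get_fake_model_config(kv_cache_max_len=128)
    model = phi2.Phi2(config).eval()  # no loader call; weights stay random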
ai_edge_torch/generative/examples/smallm/convert_to_tflite.py
@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.

+"""Example of converting SmalLM model to multi-signature tflite model."""

 import os
-from pathlib import Path
+import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.gemma import gemma
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def convert_gemma_to_tflite(
+def convert_smallm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Gemma 2B model to multi-signature
+  """Converts SmalLM model to multi-signature tflite model.

-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = gemma.build_2b_model(
+  pytorch_model = smallm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)
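Every converter above now builds its cache through kv_utils.KVCache.from_model_config() and each model's forward() indexes per-layer entries via .caches. A rough structural sketch of that container, as far as it can be inferred from its usage in this diff; the field names inside KVCacheEntry and the tensor layouts are assumptions, since kv_cache.py itself (rewritten with +160/-51) is not shown here.

    # Sketch of the external KV cache structure implied by this diff.
    from dataclasses import dataclass
    from typing import Tuple
    import torch

    @dataclass
    class KVCacheEntry:
      k_cache: torch.Tensor  # assumed per-layer key buffer
      v_cache: torch.Tensor  # assumed per-layer value buffer

    @dataclass
    class KVCache:
      caches: Tuple[KVCacheEntry, ...]  # one entry per transformer block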