ai-edge-torch-nightly 0.3.0.dev20240910-py3-none-any.whl → 0.3.0.dev20240911-py3-none-any.whl

Files changed (33)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  8. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  9. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  10. ai_edge_torch/generative/layers/attention.py +60 -63
  11. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  12. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  13. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  15. ai_edge_torch/generative/test/utils.py +54 -0
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
  19. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  20. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  21. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  22. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  24. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  25. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  26. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  28. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  29. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  30. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
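The diffs below promote the external-KV-cache ("ekv") examples out of experimental/ and switch both the conversion scripts and the example models to an explicit, caller-owned KVCache. As a rough orientation before the per-file hunks, the new conversion flow looks roughly like the sketch below, condensed from the Gemma2 hunks; the checkpoint path, sequence length, and the zero-filled prefill tensors are illustrative placeholders rather than values taken from the repo.

import ai_edge_torch
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import torch

# Build the PyTorch model and an external KV cache sized from its config.
pytorch_model = gemma2.build_2b_model('/path/to/gemma2-2b', kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

# Illustrative prefill sample inputs (not shown in the hunks below).
prefill_tokens = torch.full((1, 512), 0, dtype=torch.long)
prefill_input_pos = torch.arange(0, 512)

# Each signature now receives its sample inputs as keyword arguments,
# including the cache, instead of a positional (tokens, input_pos) tuple.
edge_model = ai_edge_torch.signature(
    'prefill',
    pytorch_model,
    sample_kwargs={
        'tokens': prefill_tokens,
        'input_pos': prefill_input_pos,
        'kv_cache': kv,
    },
).convert(quant_config=None)
edge_model.export('/tmp/gemma2_f32_seq512_ekv1024.tflite')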
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

@@ -13,32 +13,35 @@
  # limitations under the License.
  # ==============================================================================
 
+ """Example of converting a Gemma2 model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib
 
  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch
 
 
- def convert_gemma_to_tflite(
+ def convert_gemma2_to_tflite(
      checkpoint_path: str,
      prefill_seq_len: int = 512,
      kv_cache_max_len: int = 1024,
      quantize: bool = True,
  ):
-   """Converting a Gemma 2 2B model to multi-signature
-   tflite model.
+   """Converts a Gemma2 2B model to multi-signature tflite model.
 
    Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+     checkpoint_path (str): The filepath to the model checkpoint, or directory
+       holding the checkpoint.
      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
        Defaults to 512.
      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
        including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized.
-       Defaults to True.
+     quantize (bool, optional): Whether the model should be quanized. Defaults
+       to True.
    """
    pytorch_model = gemma2.build_2b_model(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )
 
 
  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+   convert_gemma2_to_tflite(path)
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py

@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================
 
+ """Example of converting a Gemma model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib
 
  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )
 
 
  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+   convert_gemma_to_tflite(path)
ai_edge_torch/generative/examples/gemma/gemma.py

@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a Gemma model.
+
+ """Example of building a Gemma model."""
 
  import os
- from pathlib import Path
+ import pathlib
 
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
      )
      self.config = config
 
-   # The model's forward function takes in additional k/v cache tensors
-   # and returns the updated k/v cache tensors to the caller.
-   # This can be eliminated if we handle k/v cache updates inside the model itself.
    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
 
      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
      mask = mask[:, :, :, : self.config.kv_cache_max]
 
      # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(idx)
+     x = self.tok_embedding(tokens)
      x = x * (self.config.embedding_dim**0.5)
 
-     for _, block in enumerate(self.transformer_blocks):
-       x = block(x, (cos, sin), mask, input_pos)
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
      x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    return model
 
 
- def define_and_run_2b() -> None:
+ def define_and_run_2b(checkpoint_path: str) -> None:
    """Instantiates and runs a Gemma 2B model."""
 
-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   lm_logits = model.forward(tokens, input_pos)
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
    print("comparing with goldens..")
    assert torch.allclose(
-       gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+       gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
    )
 
 
  if __name__ == "__main__":
-   define_and_run_2b()
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+   )
+   define_and_run_2b(input_checkpoint_path)
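With this change, forward no longer keeps cache state inside the model: the caller builds a KVCache, passes it in, and receives an updated one in the returned dict. A minimal sketch of the prefill-then-decode pattern this enables (the checkpoint path, prompt tokens, step count, and greedy-argmax sampling are illustrative, not taken from the repo):

import torch
from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill: run the whole prompt once and keep the returned cache.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
out = model.forward(prompt, torch.arange(0, 4), kv)
kv = out["kv_cache"]
next_token = out["logits"][0, -1, :].argmax().reshape(1, 1)

# Decode: feed one token at a time, threading the cache through each step.
for pos in range(4, 8):
    out = model.forward(next_token, torch.tensor([pos]), kv)
    kv = out["kv_cache"]
    next_token = out["logits"][0, -1, :].argmax().reshape(1, 1)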
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -12,14 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building the Gemma2 2B model.
+
+ """Example of building a Gemma2 model."""
 
  import os
- from pathlib import Path
+ import pathlib
  from typing import Optional, Tuple
 
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+     kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
    """Forward function of the Gemma2Block.
 
    Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
+     kv_cache (KVCacheEntry): the optional kv cache entry.
 
    Returns:
-     output activation from this transformer block.
+     output activation from this transformer block, and updated kv cache (if
+     passed in).
    """
 
    x_norm = self.pre_atten_norm(x)
-   attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+   attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
    attn_out_norm = self.post_atten_norm(attn_out)
    x = x + attn_out_norm
    output = x + self.ff(x)
-   return output
+   return output, kv
 
 
  class Gemma2(nn.Module):
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
      return self.mask_cache.index_select(2, input_pos)
 
    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
 
      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
      sin = sin.index_select(0, input_pos)
 
      # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(idx)
+     x = self.tok_embedding(tokens)
      x = x * (self.config.embedding_dim**0.5)
 
+     updated_kv_entires = []
      for i, block in enumerate(self.transformer_blocks):
        mask = self.get_attention_mask(i, input_pos)
-       x = block(x, (cos, sin), mask, input_pos)
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
      x = self.final_norm(x)
      res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
      res = res / self.config.final_logit_softcap
      res = torch.tanh(res)
      res = res * self.config.final_logit_softcap
-     return res
+
+     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    return model
 
 
- def define_and_run_2b() -> None:
+ def define_and_run_2b(checkpoint_path: str) -> None:
    """Instantiates and runs a Gemma2 2B model."""
 
-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
    print("Running GEMMA 2")
    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    toks = torch.from_numpy(
        np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :9] = toks
    input_pos = torch.arange(0, kv_cache_max_len)
-   out = model.forward(tokens, input_pos)
-   out_final = out[0, 8, :]
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   out = model.forward(tokens, input_pos, kv)
+   out_final = out["logits"][0, 8, :]
    assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
  if __name__ == "__main__":
    torch.set_printoptions(sci_mode=True)
-   define_and_run_2b()
+   path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+   define_and_run_2b(path)
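One detail worth noting in the Gemma2 hunks: the unchanged final_logit_softcap lines implement logit soft-capping, squashing the final logits into the range (-c, c) as c * tanh(x / c). A tiny standalone illustration (the cap value 30.0 is only an example; the real value comes from the model config):

import torch

def softcap(x: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Same three steps as in Gemma2.forward: divide, tanh, re-scale.
    return cap * torch.tanh(x / cap)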
ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py

@@ -12,16 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- #
- # Note: This is an experimental version of phi2 with external KV cache.
- # Please use with caution.
+
+ """Example of converting a Phi-2 model to multi-signature tflite model."""
 
  import os
- from pathlib import Path
+ import pathlib
 
  import ai_edge_torch
- from ai_edge_torch.generative.examples.experimental.phi import phi2
- from ai_edge_torch.generative.layers.experimental import ekv_cache
+ from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
      kv_cache_max_len: int = 1024,
      quantize: bool = True,
  ):
-   """An example method for converting a Phi-2 model to multi-signature
+   """Converts a Phi-2 model to multi-signature tflite model.
 
-   tflite model.
    Args:
      checkpoint_path (str): The filepath to the model checkpoint, or directory
        holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
-   kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+   kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
        )
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+       f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )
 
 
  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-   convert_phi2_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+   convert_phi2_to_tflite(path)
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

@@ -12,26 +12,22 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building phi-2 model from the Edge Generative API layers.
- #
- # Note: This is an experimental version of phi2 with external KV cache.
- # Please use with caution.
+
+ """Example of building a Phi-2 model."""
 
  import os
- from pathlib import Path
- from typing import Tuple
+ import pathlib
 
+ from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import attention
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
  from torch import nn
 
-
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.fc1",
      ff_down_proj="model.layers.{}.mlp.fc2",
@@ -89,13 +85,17 @@ class Phi2(nn.Module):
        self,
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
-       kv_cache: kv_utils.EKVCache,
-   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
      _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
 
      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
@@ -111,11 +111,11 @@ class Phi2(nn.Module):
        x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
        if kv_entry:
          updated_kv_entires.append(kv_entry)
-     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
      x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res, updated_kv_cache
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
    return config
 
 
- def build_model(
-     checkpoint_path: str, test_model: bool = False, **kwargs
- ) -> nn.Module:
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    """Instantiates the model instance and load checkpoint if provided."""
-   config = (
-       get_fake_model_config(**kwargs)
-       if test_model
-       else get_model_config(**kwargs)
-   )
+   config = get_model_config(**kwargs)
    model = Phi2(config)
-   if checkpoint_path is not None:
-     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-     loader.load(model)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   loader.load(model)
    model.eval()
    return model
 
 
- def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+ def define_and_run(checkpoint_path: str) -> None:
    """Instantiates and runs a Phi-2 model."""
 
+   current_dir = pathlib.Path(__file__).parent.resolve()
+   phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
    kv_cache_max_len = 1024
-   model = build_model(
-       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-   )
+   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   kv = kv_utils.EKVCache.from_model_config(model.config)
-   print("running an inference")
-   print(model.forward(tokens, input_pos, kv))
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
+   print("comparing with goldens..")
+   assert torch.allclose(
+       phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+   )
 
 
  if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/phi2"
+   )
    define_and_run(input_checkpoint_path)
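Across all of these examples the cache threaded through forward is simply a container of per-layer entries: each transformer block consumes its own KVCacheEntry and returns an updated one, and the model reassembles them into a new KVCache (hence the assertion that the number of blocks matches len(kv_cache.caches)). A condensed sketch of that shared update loop, with the block call signature taken from the hunks above and everything else schematic:

from ai_edge_torch.generative.layers import kv_cache as kv_utils

def run_blocks(x, rope, mask, input_pos, kv_cache, transformer_blocks):
    # One KVCacheEntry per transformer block; collect the updated entries.
    updated_entries = []
    for i, block in enumerate(transformer_blocks):
        entry = kv_cache.caches[i] if kv_cache else None
        x, entry = block(x, rope, mask, input_pos, entry)
        if entry:
            updated_entries.append(entry)
    return x, kv_utils.KVCache(tuple(updated_entries))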