ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  9. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  11. ai_edge_torch/generative/layers/attention.py +60 -63
  12. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  13. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  14. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  16. ai_edge_torch/generative/test/utils.py +54 -0
  17. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  18. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
  22. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  24. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  25. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  26. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  29. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  30. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  31. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  32. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  33. ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
25
25
  import numpy as np
26
26
  import tensorflow as tf
27
27
  import torch
28
+ from torch import nn
28
29
  import torchvision
29
30
 
30
31
  from absl.testing import absltest as googletest
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
51
52
  def test_convert_add(self):
52
53
  """Tests conversion of a simple Add module."""
53
54
 
54
- class Add(torch.nn.Module):
55
+ class Add(nn.Module):
55
56
 
56
57
  def forward(self, a, b):
57
58
  return a + b
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
70
71
  def test_convert_dot_add(self):
71
72
  """Tests conversion of a matrix multiplication followed by an add."""
72
73
 
73
- class DotAdd(torch.nn.Module):
74
+ class DotAdd(nn.Module):
74
75
 
75
76
  def forward(self, a, b, c):
76
77
  return a @ b + c
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
99
100
  def test_signature_args_ordering(self):
100
101
  """Tests conversion of a model with more than 10 arguments."""
101
102
 
102
- class AddChainWith11Args(torch.nn.Module):
103
+ class AddChainWith11Args(nn.Module):
103
104
  """A model with 11 arguments."""
104
105
 
105
106
  def forward(
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
152
153
  def test_multi_output_model(self):
153
154
  """Tests conversion of a model that returns multiple outputs."""
154
155
 
155
- class BasicAddModelWithMultipleOutputs(torch.nn.Module):
156
+ class BasicAddModelWithMultipleOutputs(nn.Module):
156
157
  """A model that returns multiple outputs."""
157
158
 
158
159
  def forward(self, arg0, arg1):
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
176
177
  def test_12_outputs_model(self):
177
178
  """Tests conversion of a model that returns more than 10 outputs."""
178
179
 
179
- class BasicAddModelWithMultipleOutputs(torch.nn.Module):
180
+ class BasicAddModelWithMultipleOutputs(nn.Module):
180
181
  """A model that returns multiple outputs."""
181
182
 
182
183
  def forward(self, arg0, arg1):
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
245
246
  def test_convert_add_converter_flags(self):
246
247
  """Tests conversion of an add module setting a tflite converter flag."""
247
248
 
248
- class Add(torch.nn.Module):
249
+ class Add(nn.Module):
249
250
 
250
251
  def forward(self, a, b):
251
252
  return a + b
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
267
268
  )
268
269
  self.assertTrue(os.path.isdir(ir_dump_path))
269
270
 
271
+ def test_convert_conv_transpose_batch_norm(self):
272
+ """Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
273
+
274
+ channels = 2
275
+ size = 2
276
+ torch_model = nn.Sequential(
277
+ nn.ConvTranspose2d(
278
+ channels, channels, 1, stride=2, dilation=1, bias=False
279
+ ),
280
+ nn.BatchNorm2d(channels),
281
+ )
282
+
283
+ torch_model.eval()
284
+ sample_input = (torch.rand(1, channels, size, size),)
285
+ edge_model = ai_edge_torch.convert(torch_model, sample_input)
286
+
287
+ result = model_coverage.compare_tflite_torch(
288
+ edge_model, torch_model, sample_input
289
+ )
290
+ self.assertTrue(result)
291
+
270
292
  @googletest.skipIf(
271
293
  not config.Config.use_torch_xla,
272
294
  reason="Shape polymorphism is not yet support with odml_torch.",
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
274
296
  def test_convert_model_with_dynamic_batch(self):
275
297
  """Test converting a simple model with dynamic batch size."""
276
298
 
277
- class SampleModel(torch.nn.Module):
299
+ class SampleModel(nn.Module):
278
300
 
279
301
  def __init__(self):
280
302
  super().__init__()
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
304
326
  def test_convert_model_with_kwargs(self):
305
327
  """Test converting a simple model with sample_kwargs."""
306
328
 
307
- class SampleModel(torch.nn.Module):
329
+ class SampleModel(nn.Module):
308
330
 
309
331
  def forward(self, x, y):
310
332
  return x + y
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
323
345
  def test_convert_model_with_args_kwargs(self):
324
346
  """Test converting a simple model with both sample_args and sample_kwargs."""
325
347
 
326
- class SampleModel(torch.nn.Module):
348
+ class SampleModel(nn.Module):
327
349
 
328
350
  def forward(self, x, y):
329
351
  return x + y
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
343
365
  def test_convert_model_with_args_nested_kwargs_1(self):
344
366
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
345
367
 
346
- class SampleModel(torch.nn.Module):
368
+ class SampleModel(nn.Module):
347
369
 
348
370
  def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
349
371
  return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
370
392
  def test_convert_model_with_args_nested_kwargs_2(self):
371
393
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
372
394
 
373
- class SampleModel(torch.nn.Module):
395
+ class SampleModel(nn.Module):
374
396
 
375
397
  def forward(self, x, y, z):
376
398
  return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
397
419
  def test_convert_model_with_args_nested_kwargs_3(self):
398
420
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
399
421
 
400
- class SampleModel(torch.nn.Module):
422
+ class SampleModel(nn.Module):
401
423
 
402
424
  def forward(self, x, y, z):
403
425
  return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
424
446
  def test_convert_model_non_flat_output_dict(self):
425
447
  """Test converting a model with non-flat output structure."""
426
448
 
427
- class SampleModel(torch.nn.Module):
449
+ class SampleModel(nn.Module):
428
450
 
429
451
  def forward(self, x, y, z):
430
452
  return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
@@ -13,32 +13,35 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma2 model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma2
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
24
27
 
25
- def convert_gemma_to_tflite(
28
+ def convert_gemma2_to_tflite(
26
29
  checkpoint_path: str,
27
30
  prefill_seq_len: int = 512,
28
31
  kv_cache_max_len: int = 1024,
29
32
  quantize: bool = True,
30
33
  ):
31
- """Converting a Gemma 2 2B model to multi-signature
32
- tflite model.
34
+ """Converts a Gemma2 2B model to multi-signature tflite model.
33
35
 
34
36
  Args:
35
- checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
37
+ checkpoint_path (str): The filepath to the model checkpoint, or directory
38
+ holding the checkpoint.
36
39
  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
37
40
  Defaults to 512.
38
41
  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
39
42
  including both prefill and decode. Defaults to 1024.
40
- quantize (bool, optional): Whether the model should be quanized.
41
- Defaults to True.
43
+ quantize (bool, optional): Whether the model should be quanized. Defaults
44
+ to True.
42
45
  """
43
46
  pytorch_model = gemma2.build_2b_model(
44
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
48
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
49
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
50
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
86
+ convert_gemma2_to_tflite(path)
@@ -13,11 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
48
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
49
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
50
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
86
+ convert_gemma_to_tflite(path)
@@ -12,13 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Example of building a Gemma model.
15
+
16
+ """Example of building a Gemma model."""
16
17
 
17
18
  import os
18
- from pathlib import Path
19
+ import pathlib
19
20
 
20
21
  from ai_edge_torch.generative.layers import attention
21
22
  from ai_edge_torch.generative.layers import builder
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
24
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
25
  import ai_edge_torch.generative.layers.model_config as cfg
24
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
84
86
  )
85
87
  self.config = config
86
88
 
87
- # The model's forward function takes in additional k/v cache tensors
88
- # and returns the updated k/v cache tensors to the caller.
89
- # This can be eliminated if we handle k/v cache updates inside the model itself.
90
89
  @torch.inference_mode
91
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
92
- _, seq_len = idx.size()
90
+ def forward(
91
+ self,
92
+ tokens: torch.Tensor,
93
+ input_pos: torch.Tensor,
94
+ kv_cache: kv_utils.KVCache,
95
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
96
+ _, seq_len = tokens.size()
93
97
  assert self.config.max_seq_len >= seq_len, (
94
98
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
95
99
  f" {self.config.max_seq_len}"
96
100
  )
101
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
102
+ "The number of transformer blocks and the number of KV cache entries"
103
+ " must be the same."
104
+ )
97
105
 
98
106
  cos, sin = self.rope_cache
99
107
  cos = cos.index_select(0, input_pos)
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
102
110
  mask = mask[:, :, :, : self.config.kv_cache_max]
103
111
 
104
112
  # token embeddings of shape (b, t, n_embd)
105
- x = self.tok_embedding(idx)
113
+ x = self.tok_embedding(tokens)
106
114
  x = x * (self.config.embedding_dim**0.5)
107
115
 
108
- for _, block in enumerate(self.transformer_blocks):
109
- x = block(x, (cos, sin), mask, input_pos)
116
+ updated_kv_entires = []
117
+ for i, block in enumerate(self.transformer_blocks):
118
+ kv_entry = kv_cache.caches[i] if kv_cache else None
119
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
120
+ if kv_entry:
121
+ updated_kv_entires.append(kv_entry)
122
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
110
123
 
111
124
  x = self.final_norm(x)
112
- res = self.lm_head(x) # (b, t, vocab_size)
113
- return res
125
+ logits = self.lm_head(x) # (b, t, vocab_size)
126
+ return {"logits": logits, "kv_cache": updated_kv_cache}
114
127
 
115
128
 
116
129
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
177
190
  return model
178
191
 
179
192
 
180
- def define_and_run_2b() -> None:
193
+ def define_and_run_2b(checkpoint_path: str) -> None:
181
194
  """Instantiates and runs a Gemma 2B model."""
182
195
 
183
- current_dir = Path(__file__).parent.resolve()
196
+ current_dir = pathlib.Path(__file__).parent.resolve()
184
197
  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
185
198
 
186
199
  kv_cache_max_len = 1024
187
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
188
200
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
189
201
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
190
202
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
191
203
  tokens[0, :4] = idx
192
204
  input_pos = torch.arange(0, kv_cache_max_len)
193
- lm_logits = model.forward(tokens, input_pos)
205
+ kv = kv_utils.KVCache.from_model_config(model.config)
206
+ output = model.forward(tokens, input_pos, kv)
194
207
  print("comparing with goldens..")
195
208
  assert torch.allclose(
196
- gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
209
+ gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
197
210
  )
198
211
 
199
212
 
200
213
  if __name__ == "__main__":
201
- define_and_run_2b()
214
+ input_checkpoint_path = os.path.join(
215
+ pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
216
+ )
217
+ define_and_run_2b(input_checkpoint_path)
@@ -12,14 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Example of building the Gemma2 2B model.
15
+
16
+ """Example of building a Gemma2 model."""
16
17
 
17
18
  import os
18
- from pathlib import Path
19
+ import pathlib
19
20
  from typing import Optional, Tuple
20
21
 
21
22
  from ai_edge_torch.generative.layers import attention
22
23
  from ai_edge_torch.generative.layers import builder
24
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
25
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
26
  import ai_edge_torch.generative.layers.model_config as cfg
25
27
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
51
53
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
52
54
  mask: Optional[torch.Tensor] = None,
53
55
  input_pos: Optional[torch.Tensor] = None,
54
- ) -> torch.Tensor:
56
+ kv_cache: kv_utils.KVCacheEntry = None,
57
+ ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
55
58
  """Forward function of the Gemma2Block.
56
59
 
57
60
  Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
62
65
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
63
66
  mask (torch.Tensor): the optional mask tensor.
64
67
  input_pos (torch.Tensor): the optional input position tensor.
68
+ kv_cache (KVCacheEntry): the optional kv cache entry.
65
69
 
66
70
  Returns:
67
- output activation from this transformer block.
71
+ output activation from this transformer block, and updated kv cache (if
72
+ passed in).
68
73
  """
69
74
 
70
75
  x_norm = self.pre_atten_norm(x)
71
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
76
+ attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
72
77
  attn_out_norm = self.post_atten_norm(attn_out)
73
78
  x = x + attn_out_norm
74
79
  output = x + self.ff(x)
75
- return output
80
+ return output, kv
76
81
 
77
82
 
78
83
  class Gemma2(nn.Module):
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
138
143
  return self.mask_cache.index_select(2, input_pos)
139
144
 
140
145
  @torch.inference_mode
141
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
142
- _, seq_len = idx.size()
146
+ def forward(
147
+ self,
148
+ tokens: torch.Tensor,
149
+ input_pos: torch.Tensor,
150
+ kv_cache: kv_utils.KVCache,
151
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
152
+ _, seq_len = tokens.size()
143
153
  assert self.config.max_seq_len >= seq_len, (
144
154
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
145
155
  f" {self.config.max_seq_len}"
146
156
  )
157
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
158
+ "The number of transformer blocks and the number of KV cache entries"
159
+ " must be the same."
160
+ )
147
161
 
148
162
  cos, sin = self.rope_cache
149
163
  cos = cos.index_select(0, input_pos)
150
164
  sin = sin.index_select(0, input_pos)
151
165
 
152
166
  # token embeddings of shape (b, t, n_embd)
153
- x = self.tok_embedding(idx)
167
+ x = self.tok_embedding(tokens)
154
168
  x = x * (self.config.embedding_dim**0.5)
155
169
 
170
+ updated_kv_entires = []
156
171
  for i, block in enumerate(self.transformer_blocks):
157
172
  mask = self.get_attention_mask(i, input_pos)
158
- x = block(x, (cos, sin), mask, input_pos)
173
+ kv_entry = kv_cache.caches[i] if kv_cache else None
174
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
175
+ if kv_entry:
176
+ updated_kv_entires.append(kv_entry)
177
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
159
178
 
160
179
  x = self.final_norm(x)
161
180
  res = self.lm_head(x) # (b, t, vocab_size)
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
163
182
  res = res / self.config.final_logit_softcap
164
183
  res = torch.tanh(res)
165
184
  res = res * self.config.final_logit_softcap
166
- return res
185
+
186
+ return {"logits": res, "kv_cache": updated_kv_cache}
167
187
 
168
188
 
169
189
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
243
263
  return model
244
264
 
245
265
 
246
- def define_and_run_2b() -> None:
266
+ def define_and_run_2b(checkpoint_path: str) -> None:
247
267
  """Instantiates and runs a Gemma2 2B model."""
248
268
 
249
- current_dir = Path(__file__).parent.resolve()
269
+ current_dir = pathlib.Path(__file__).parent.resolve()
250
270
  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
251
271
  print("Running GEMMA 2")
252
272
  kv_cache_max_len = 1024
253
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
254
273
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
255
274
  toks = torch.from_numpy(
256
275
  np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
258
277
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
259
278
  tokens[0, :9] = toks
260
279
  input_pos = torch.arange(0, kv_cache_max_len)
261
- out = model.forward(tokens, input_pos)
262
- out_final = out[0, 8, :]
280
+ kv = kv_utils.KVCache.from_model_config(model.config)
281
+ out = model.forward(tokens, input_pos, kv)
282
+ out_final = out["logits"][0, 8, :]
263
283
  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
264
284
 
265
285
 
266
286
  if __name__ == "__main__":
267
287
  torch.set_printoptions(sci_mode=True)
268
- define_and_run_2b()
288
+ path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
289
+ define_and_run_2b(path)
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- #
16
- # Note: This is an experimental version of phi2 with external KV cache.
17
- # Please use with caution.
15
+
16
+ """Example of converting a Phi-2 model to multi-signature tflite model."""
18
17
 
19
18
  import os
20
- from pathlib import Path
19
+ import pathlib
21
20
 
22
21
  import ai_edge_torch
23
- from ai_edge_torch.generative.examples.experimental.phi import phi2
24
- from ai_edge_torch.generative.layers.experimental import ekv_cache
22
+ from ai_edge_torch.generative.examples.phi import phi2
23
+ from ai_edge_torch.generative.layers import kv_cache
25
24
  from ai_edge_torch.generative.quantize import quant_recipes
26
25
  import torch
27
26
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
32
31
  kv_cache_max_len: int = 1024,
33
32
  quantize: bool = True,
34
33
  ):
35
- """An example method for converting a Phi-2 model to multi-signature
34
+ """Converts a Phi-2 model to multi-signature tflite model.
36
35
 
37
- tflite model.
38
36
  Args:
39
37
  checkpoint_path (str): The filepath to the model checkpoint, or directory
40
38
  holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
53
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
54
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
55
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
56
- kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
54
+ kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
57
55
 
58
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
59
57
  edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
77
75
  )
78
76
  .convert(quant_config=quant_config)
79
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
80
79
  edge_model.export(
81
- f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
80
+ f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
82
81
  )
83
82
 
84
83
 
85
84
  if __name__ == '__main__':
86
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
87
- convert_phi2_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
86
+ convert_phi2_to_tflite(path)