ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (36)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  9. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  11. ai_edge_torch/generative/layers/attention.py +60 -63
  12. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  13. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  14. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  16. ai_edge_torch/generative/test/utils.py +54 -0
  17. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  18. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
  22. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  24. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  25. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  26. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  29. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  30. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  31. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  32. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  33. ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
25
25
  import numpy as np
26
26
  import tensorflow as tf
27
27
  import torch
28
+ from torch import nn
28
29
  import torchvision
29
30
 
30
31
  from absl.testing import absltest as googletest
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
51
52
  def test_convert_add(self):
52
53
  """Tests conversion of a simple Add module."""
53
54
 
54
- class Add(torch.nn.Module):
55
+ class Add(nn.Module):
55
56
 
56
57
  def forward(self, a, b):
57
58
  return a + b
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
70
71
  def test_convert_dot_add(self):
71
72
  """Tests conversion of a matrix multiplication followed by an add."""
72
73
 
73
- class DotAdd(torch.nn.Module):
74
+ class DotAdd(nn.Module):
74
75
 
75
76
  def forward(self, a, b, c):
76
77
  return a @ b + c
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
99
100
  def test_signature_args_ordering(self):
100
101
  """Tests conversion of a model with more than 10 arguments."""
101
102
 
102
- class AddChainWith11Args(torch.nn.Module):
103
+ class AddChainWith11Args(nn.Module):
103
104
  """A model with 11 arguments."""
104
105
 
105
106
  def forward(
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
152
153
  def test_multi_output_model(self):
153
154
  """Tests conversion of a model that returns multiple outputs."""
154
155
 
155
- class BasicAddModelWithMultipleOutputs(torch.nn.Module):
156
+ class BasicAddModelWithMultipleOutputs(nn.Module):
156
157
  """A model that returns multiple outputs."""
157
158
 
158
159
  def forward(self, arg0, arg1):
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
176
177
  def test_12_outputs_model(self):
177
178
  """Tests conversion of a model that returns more than 10 outputs."""
178
179
 
179
- class BasicAddModelWithMultipleOutputs(torch.nn.Module):
180
+ class BasicAddModelWithMultipleOutputs(nn.Module):
180
181
  """A model that returns multiple outputs."""
181
182
 
182
183
  def forward(self, arg0, arg1):
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
245
246
  def test_convert_add_converter_flags(self):
246
247
  """Tests conversion of an add module setting a tflite converter flag."""
247
248
 
248
- class Add(torch.nn.Module):
249
+ class Add(nn.Module):
249
250
 
250
251
  def forward(self, a, b):
251
252
  return a + b
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
267
268
  )
268
269
  self.assertTrue(os.path.isdir(ir_dump_path))
269
270
 
271
+ def test_convert_conv_transpose_batch_norm(self):
272
+ """Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
273
+
274
+ channels = 2
275
+ size = 2
276
+ torch_model = nn.Sequential(
277
+ nn.ConvTranspose2d(
278
+ channels, channels, 1, stride=2, dilation=1, bias=False
279
+ ),
280
+ nn.BatchNorm2d(channels),
281
+ )
282
+
283
+ torch_model.eval()
284
+ sample_input = (torch.rand(1, channels, size, size),)
285
+ edge_model = ai_edge_torch.convert(torch_model, sample_input)
286
+
287
+ result = model_coverage.compare_tflite_torch(
288
+ edge_model, torch_model, sample_input
289
+ )
290
+ self.assertTrue(result)
291
+
270
292
  @googletest.skipIf(
271
293
  not config.Config.use_torch_xla,
272
294
  reason="Shape polymorphism is not yet support with odml_torch.",
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
274
296
  def test_convert_model_with_dynamic_batch(self):
275
297
  """Test converting a simple model with dynamic batch size."""
276
298
 
277
- class SampleModel(torch.nn.Module):
299
+ class SampleModel(nn.Module):
278
300
 
279
301
  def __init__(self):
280
302
  super().__init__()
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
304
326
  def test_convert_model_with_kwargs(self):
305
327
  """Test converting a simple model with sample_kwargs."""
306
328
 
307
- class SampleModel(torch.nn.Module):
329
+ class SampleModel(nn.Module):
308
330
 
309
331
  def forward(self, x, y):
310
332
  return x + y
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
323
345
  def test_convert_model_with_args_kwargs(self):
324
346
  """Test converting a simple model with both sample_args and sample_kwargs."""
325
347
 
326
- class SampleModel(torch.nn.Module):
348
+ class SampleModel(nn.Module):
327
349
 
328
350
  def forward(self, x, y):
329
351
  return x + y
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
343
365
  def test_convert_model_with_args_nested_kwargs_1(self):
344
366
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
345
367
 
346
- class SampleModel(torch.nn.Module):
368
+ class SampleModel(nn.Module):
347
369
 
348
370
  def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
349
371
  return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
370
392
  def test_convert_model_with_args_nested_kwargs_2(self):
371
393
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
372
394
 
373
- class SampleModel(torch.nn.Module):
395
+ class SampleModel(nn.Module):
374
396
 
375
397
  def forward(self, x, y, z):
376
398
  return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
397
419
  def test_convert_model_with_args_nested_kwargs_3(self):
398
420
  """Test converting a simple model with both sample_args and nested sample_kwargs."""
399
421
 
400
- class SampleModel(torch.nn.Module):
422
+ class SampleModel(nn.Module):
401
423
 
402
424
  def forward(self, x, y, z):
403
425
  return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
424
446
  def test_convert_model_non_flat_output_dict(self):
425
447
  """Test converting a model with non-flat output structure."""
426
448
 
427
- class SampleModel(torch.nn.Module):
449
+ class SampleModel(nn.Module):
428
450
 
429
451
  def forward(self, x, y, z):
430
452
  return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
@@ -13,32 +13,35 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma2 model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma2
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
24
27
 
25
- def convert_gemma_to_tflite(
28
+ def convert_gemma2_to_tflite(
26
29
  checkpoint_path: str,
27
30
  prefill_seq_len: int = 512,
28
31
  kv_cache_max_len: int = 1024,
29
32
  quantize: bool = True,
30
33
  ):
31
- """Converting a Gemma 2 2B model to multi-signature
32
- tflite model.
34
+ """Converts a Gemma2 2B model to multi-signature tflite model.
33
35
 
34
36
  Args:
35
- checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
37
+ checkpoint_path (str): The filepath to the model checkpoint, or directory
38
+ holding the checkpoint.
36
39
  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
37
40
  Defaults to 512.
38
41
  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
39
42
  including both prefill and decode. Defaults to 1024.
40
- quantize (bool, optional): Whether the model should be quanized.
41
- Defaults to True.
43
+ quantize (bool, optional): Whether the model should be quanized. Defaults
44
+ to True.
42
45
  """
43
46
  pytorch_model = gemma2.build_2b_model(
44
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
48
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
49
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
50
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
86
+ convert_gemma2_to_tflite(path)
@@ -13,11 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
48
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
49
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
50
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
86
+ convert_gemma_to_tflite(path)
@@ -12,13 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Example of building a Gemma model.
15
+
16
+ """Example of building a Gemma model."""
16
17
 
17
18
  import os
18
- from pathlib import Path
19
+ import pathlib
19
20
 
20
21
  from ai_edge_torch.generative.layers import attention
21
22
  from ai_edge_torch.generative.layers import builder
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
24
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
25
  import ai_edge_torch.generative.layers.model_config as cfg
24
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
84
86
  )
85
87
  self.config = config
86
88
 
87
- # The model's forward function takes in additional k/v cache tensors
88
- # and returns the updated k/v cache tensors to the caller.
89
- # This can be eliminated if we handle k/v cache updates inside the model itself.
90
89
  @torch.inference_mode
91
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
92
- _, seq_len = idx.size()
90
+ def forward(
91
+ self,
92
+ tokens: torch.Tensor,
93
+ input_pos: torch.Tensor,
94
+ kv_cache: kv_utils.KVCache,
95
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
96
+ _, seq_len = tokens.size()
93
97
  assert self.config.max_seq_len >= seq_len, (
94
98
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
95
99
  f" {self.config.max_seq_len}"
96
100
  )
101
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
102
+ "The number of transformer blocks and the number of KV cache entries"
103
+ " must be the same."
104
+ )
97
105
 
98
106
  cos, sin = self.rope_cache
99
107
  cos = cos.index_select(0, input_pos)
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
102
110
  mask = mask[:, :, :, : self.config.kv_cache_max]
103
111
 
104
112
  # token embeddings of shape (b, t, n_embd)
105
- x = self.tok_embedding(idx)
113
+ x = self.tok_embedding(tokens)
106
114
  x = x * (self.config.embedding_dim**0.5)
107
115
 
108
- for _, block in enumerate(self.transformer_blocks):
109
- x = block(x, (cos, sin), mask, input_pos)
116
+ updated_kv_entires = []
117
+ for i, block in enumerate(self.transformer_blocks):
118
+ kv_entry = kv_cache.caches[i] if kv_cache else None
119
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
120
+ if kv_entry:
121
+ updated_kv_entires.append(kv_entry)
122
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
110
123
 
111
124
  x = self.final_norm(x)
112
- res = self.lm_head(x) # (b, t, vocab_size)
113
- return res
125
+ logits = self.lm_head(x) # (b, t, vocab_size)
126
+ return {"logits": logits, "kv_cache": updated_kv_cache}
114
127
 
115
128
 
116
129
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
177
190
  return model
178
191
 
179
192
 
180
- def define_and_run_2b() -> None:
193
+ def define_and_run_2b(checkpoint_path: str) -> None:
181
194
  """Instantiates and runs a Gemma 2B model."""
182
195
 
183
- current_dir = Path(__file__).parent.resolve()
196
+ current_dir = pathlib.Path(__file__).parent.resolve()
184
197
  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
185
198
 
186
199
  kv_cache_max_len = 1024
187
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
188
200
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
189
201
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
190
202
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
191
203
  tokens[0, :4] = idx
192
204
  input_pos = torch.arange(0, kv_cache_max_len)
193
- lm_logits = model.forward(tokens, input_pos)
205
+ kv = kv_utils.KVCache.from_model_config(model.config)
206
+ output = model.forward(tokens, input_pos, kv)
194
207
  print("comparing with goldens..")
195
208
  assert torch.allclose(
196
- gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
209
+ gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
197
210
  )
198
211
 
199
212
 
200
213
  if __name__ == "__main__":
201
- define_and_run_2b()
214
+ input_checkpoint_path = os.path.join(
215
+ pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
216
+ )
217
+ define_and_run_2b(input_checkpoint_path)
@@ -12,14 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Example of building the Gemma2 2B model.
15
+
16
+ """Example of building a Gemma2 model."""
16
17
 
17
18
  import os
18
- from pathlib import Path
19
+ import pathlib
19
20
  from typing import Optional, Tuple
20
21
 
21
22
  from ai_edge_torch.generative.layers import attention
22
23
  from ai_edge_torch.generative.layers import builder
24
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
25
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
26
  import ai_edge_torch.generative.layers.model_config as cfg
25
27
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
51
53
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
52
54
  mask: Optional[torch.Tensor] = None,
53
55
  input_pos: Optional[torch.Tensor] = None,
54
- ) -> torch.Tensor:
56
+ kv_cache: kv_utils.KVCacheEntry = None,
57
+ ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
55
58
  """Forward function of the Gemma2Block.
56
59
 
57
60
  Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
62
65
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
63
66
  mask (torch.Tensor): the optional mask tensor.
64
67
  input_pos (torch.Tensor): the optional input position tensor.
68
+ kv_cache (KVCacheEntry): the optional kv cache entry.
65
69
 
66
70
  Returns:
67
- output activation from this transformer block.
71
+ output activation from this transformer block, and updated kv cache (if
72
+ passed in).
68
73
  """
69
74
 
70
75
  x_norm = self.pre_atten_norm(x)
71
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
76
+ attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
72
77
  attn_out_norm = self.post_atten_norm(attn_out)
73
78
  x = x + attn_out_norm
74
79
  output = x + self.ff(x)
75
- return output
80
+ return output, kv
76
81
 
77
82
 
78
83
  class Gemma2(nn.Module):
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
138
143
  return self.mask_cache.index_select(2, input_pos)
139
144
 
140
145
  @torch.inference_mode
141
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
142
- _, seq_len = idx.size()
146
+ def forward(
147
+ self,
148
+ tokens: torch.Tensor,
149
+ input_pos: torch.Tensor,
150
+ kv_cache: kv_utils.KVCache,
151
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
152
+ _, seq_len = tokens.size()
143
153
  assert self.config.max_seq_len >= seq_len, (
144
154
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
145
155
  f" {self.config.max_seq_len}"
146
156
  )
157
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
158
+ "The number of transformer blocks and the number of KV cache entries"
159
+ " must be the same."
160
+ )
147
161
 
148
162
  cos, sin = self.rope_cache
149
163
  cos = cos.index_select(0, input_pos)
150
164
  sin = sin.index_select(0, input_pos)
151
165
 
152
166
  # token embeddings of shape (b, t, n_embd)
153
- x = self.tok_embedding(idx)
167
+ x = self.tok_embedding(tokens)
154
168
  x = x * (self.config.embedding_dim**0.5)
155
169
 
170
+ updated_kv_entires = []
156
171
  for i, block in enumerate(self.transformer_blocks):
157
172
  mask = self.get_attention_mask(i, input_pos)
158
- x = block(x, (cos, sin), mask, input_pos)
173
+ kv_entry = kv_cache.caches[i] if kv_cache else None
174
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
175
+ if kv_entry:
176
+ updated_kv_entires.append(kv_entry)
177
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
159
178
 
160
179
  x = self.final_norm(x)
161
180
  res = self.lm_head(x) # (b, t, vocab_size)
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
163
182
  res = res / self.config.final_logit_softcap
164
183
  res = torch.tanh(res)
165
184
  res = res * self.config.final_logit_softcap
166
- return res
185
+
186
+ return {"logits": res, "kv_cache": updated_kv_cache}
167
187
 
168
188
 
169
189
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
243
263
  return model
244
264
 
245
265
 
246
- def define_and_run_2b() -> None:
266
+ def define_and_run_2b(checkpoint_path: str) -> None:
247
267
  """Instantiates and runs a Gemma2 2B model."""
248
268
 
249
- current_dir = Path(__file__).parent.resolve()
269
+ current_dir = pathlib.Path(__file__).parent.resolve()
250
270
  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
251
271
  print("Running GEMMA 2")
252
272
  kv_cache_max_len = 1024
253
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
254
273
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
255
274
  toks = torch.from_numpy(
256
275
  np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
258
277
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
259
278
  tokens[0, :9] = toks
260
279
  input_pos = torch.arange(0, kv_cache_max_len)
261
- out = model.forward(tokens, input_pos)
262
- out_final = out[0, 8, :]
280
+ kv = kv_utils.KVCache.from_model_config(model.config)
281
+ out = model.forward(tokens, input_pos, kv)
282
+ out_final = out["logits"][0, 8, :]
263
283
  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
264
284
 
265
285
 
266
286
  if __name__ == "__main__":
267
287
  torch.set_printoptions(sci_mode=True)
268
- define_and_run_2b()
288
+ path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
289
+ define_and_run_2b(path)
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- #
16
- # Note: This is an experimental version of phi2 with external KV cache.
17
- # Please use with caution.
15
+
16
+ """Example of converting a Phi-2 model to multi-signature tflite model."""
18
17
 
19
18
  import os
20
- from pathlib import Path
19
+ import pathlib
21
20
 
22
21
  import ai_edge_torch
23
- from ai_edge_torch.generative.examples.experimental.phi import phi2
24
- from ai_edge_torch.generative.layers.experimental import ekv_cache
22
+ from ai_edge_torch.generative.examples.phi import phi2
23
+ from ai_edge_torch.generative.layers import kv_cache
25
24
  from ai_edge_torch.generative.quantize import quant_recipes
26
25
  import torch
27
26
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
32
31
  kv_cache_max_len: int = 1024,
33
32
  quantize: bool = True,
34
33
  ):
35
- """An example method for converting a Phi-2 model to multi-signature
34
+ """Converts a Phi-2 model to multi-signature tflite model.
36
35
 
37
- tflite model.
38
36
  Args:
39
37
  checkpoint_path (str): The filepath to the model checkpoint, or directory
40
38
  holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
53
51
  prefill_input_pos = torch.arange(0, prefill_seq_len)
54
52
  decode_token = torch.tensor([[0]], dtype=torch.long)
55
53
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
56
- kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
54
+ kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
57
55
 
58
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
59
57
  edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
77
75
  )
78
76
  .convert(quant_config=quant_config)
79
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
80
79
  edge_model.export(
81
- f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
80
+ f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
82
81
  )
83
82
 
84
83
 
85
84
  if __name__ == '__main__':
86
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
87
- convert_phi2_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
86
+ convert_phi2_to_tflite(path)