ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -13,32 +13,35 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting a Gemma2 model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch


- def convert_gemma_to_tflite(
+ def convert_gemma2_to_tflite(
      checkpoint_path: str,
      prefill_seq_len: int = 512,
      kv_cache_max_len: int = 1024,
      quantize: bool = True,
  ):
-   """Converting a Gemma 2 2B model to multi-signature
-   tflite model.
+   """Converts a Gemma2 2B model to multi-signature tflite model.

    Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+     checkpoint_path (str): The filepath to the model checkpoint, or directory
+       holding the checkpoint.
      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
        Defaults to 512.
      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
        including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized.
-       Defaults to True.
+     quantize (bool, optional): Whether the model should be quanized. Defaults
+       to True.
    """
    pytorch_model = gemma2.build_2b_model(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+   convert_gemma2_to_tflite(path)
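The converter now builds a KVCache up front and registers the 'prefill' and 'decode' signatures through sample_kwargs instead of positional sample inputs; the Gemma converter below follows the same pattern. A minimal usage sketch of the updated Gemma2 entry point; the checkpoint location is illustrative only and is not shipped with the package:

    import os
    import pathlib

    from ai_edge_torch.generative.examples.gemma import convert_gemma2_to_tflite as converter

    # Illustrative checkpoint directory; point this at your own Gemma2 2B weights.
    checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')

    # With the defaults shown in the diff this writes
    # /tmp/gemma2_q8_seq512_ekv1024.tflite containing 'prefill' and 'decode'
    # signatures.
    converter.convert_gemma2_to_tflite(
        checkpoint, prefill_seq_len=512, kv_cache_max_len=1024, quantize=True
    )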
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting a Gemma model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+   convert_gemma_to_tflite(path)
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a Gemma model.
+
+ """Example of building a Gemma model."""

  import os
- from pathlib import Path
+ import pathlib

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
    def __init__(self, config: cfg.ModelConfig):
      super().__init__()

-     self.config = config
      # Construct model layers.
      self.tok_embedding = nn.Embedding(
          config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
      )
      # Gemma re-uses the embedding as the head projection layer.
      self.lm_head.weight.data = self.tok_embedding.weight.data
+     # Gemma has only one block config.
+     block_config = config.block_config(0)
      self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(config) for _ in range(config.num_layers)
+         attention.TransformerBlock(block_config, config)
+         for _ in range(config.num_layers)
      )
      self.final_norm = builder.build_norm(
          config.embedding_dim,
          config.final_norm_config,
      )
+     attn_config = block_config.attn_config
      self.rope_cache = attn_utils.build_rope_cache(
          size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
          base=10_000,
          condense_ratio=1,
          dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
      )
      self.config = config

-   # The model's forward function takes in additional k/v cache tensors
-   # and returns the updated k/v cache tensors to the caller.
-   # This can be eliminated if we handle k/v cache updates inside the model itself.
    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )

      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
      mask = mask[:, :, :, : self.config.kv_cache_max]

      # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(idx)
+     x = self.tok_embedding(tokens)
      x = x * (self.config.embedding_dim**0.5)

-     for _, block in enumerate(self.transformer_blocks):
-       x = block(x, (cos, sin), mask, input_pos)
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

      x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}


  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
        epsilon=1e-6,
        zero_centered=True,
    )
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
    config = cfg.ModelConfig(
        vocab_size=256000,
        num_layers=18,
        max_seq_len=8192,
        embedding_dim=2048,
        kv_cache_max_len=kv_cache_max_len,
-       attn_config=attn_config,
-       ff_config=ff_config,
-       pre_attention_norm_config=norm_config,
-       post_attention_norm_config=norm_config,
+       block_configs=block_config,
        final_norm_config=norm_config,
-       parallel_residual=False,
        lm_head_use_bias=False,
        enable_hlfb=True,
    )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
    config = get_model_config_2b(kv_cache_max_len)
-   config.ff_config.intermediate_size = 128
+   # Gemma has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 128
    config.vocab_size = 128
    config.num_layers = 2
    config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    config = get_model_config_2b(**kwargs)
    model = Gemma(config)
    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   # since embedding and lm-head use the same weight, we need to set strict
+   # Since embedding and lm-head use the same weight, we need to set strict
    # to False.
    loader.load(model, strict=False)
    model.eval()
    return model


- def define_and_run_2b() -> None:
+ def define_and_run_2b(checkpoint_path: str) -> None:
    """Instantiates and runs a Gemma 2B model."""

-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")

    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   lm_logits = model.forward(tokens, input_pos)
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
    print("comparing with goldens..")
    assert torch.allclose(
-       gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+       gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
    )


  if __name__ == "__main__":
-   define_and_run_2b()
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+   )
+   define_and_run_2b(input_checkpoint_path)
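gemma.py now threads an external KV cache through forward() and returns a dict of logits plus the updated cache instead of a bare tensor. A minimal sketch of the new calling convention, assuming a locally downloaded gemma-2b checkpoint (the path below is a placeholder, not part of the package):

    import torch

    from ai_edge_torch.generative.examples.gemma import gemma
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    # Placeholder checkpoint directory.
    model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
    kv = kv_utils.KVCache.from_model_config(model.config)

    tokens = torch.zeros((1, 1024), dtype=torch.long)
    input_pos = torch.arange(0, 1024)

    # forward() now returns a dict; the caller keeps the updated cache and
    # passes it back in on the next call.
    output = model.forward(tokens, input_pos, kv)
    logits, kv = output["logits"], output["kv_cache"]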
ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -12,14 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building the Gemma2 2B model.
+
+ """Example of building a Gemma2 model."""

  import os
- from pathlib import Path
+ import pathlib
  from typing import Optional, Tuple

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
        rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
-   ) -> torch.Tensor:
+       kv_cache: kv_utils.KVCacheEntry = None,
+   ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
      """Forward function of the Gemma2Block.

      Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
        rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
        mask (torch.Tensor): the optional mask tensor.
        input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntry): the optional kv cache entry.

      Returns:
-       output activation from this transformer block.
+       output activation from this transformer block, and updated kv cache (if
+       passed in).
      """

      x_norm = self.pre_atten_norm(x)
-     attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+     attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
      attn_out_norm = self.post_atten_norm(attn_out)
      x = x + attn_out_norm
      output = x + self.ff(x)
-     return output
+     return output, kv


  class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
    def __init__(self, config: cfg.ModelConfig):
      super().__init__()

-     self.config = config
      # Construct model layers.
      self.tok_embedding = nn.Embedding(
          config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
          config.vocab_size,
          bias=config.lm_head_use_bias,
      )
-     # Gemma re-uses the embedding as the head projection layer.
+     # Gemma2 re-uses the embedding as the head projection layer.
      self.lm_head.weight.data = self.tok_embedding.weight.data
      self.transformer_blocks = nn.ModuleList(
-         Gemma2Block(config) for _ in range(config.num_layers)
+         Gemma2Block(config.block_config(idx), config)
+         for idx in range(config.num_layers)
      )
      self.final_norm = builder.build_norm(
          config.embedding_dim,
          config.final_norm_config,
      )
+     # Gemma2 has same hyper parameters for each layer except for attention
+     # types. Use the first layer.
+     attn_config = config.block_config(0).attn_config
      self.rope_cache = attn_utils.build_rope_cache(
          size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
          base=10_000,
          condense_ratio=1,
          dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
          dtype=torch.float32,
          device=torch.device("cpu"),
      )
-
      self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
          size=config.kv_cache_max,
-         window_size=self.config.attn_config.sliding_window_size,
+         window_size=attn_config.sliding_window_size,
          dtype=torch.float32,
          device=torch.device("cpu"),
      )
-
      self.config = config

    def get_attention_mask(
-       self, idx: int, input_pos: torch.Tensor
+       self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
    ) -> torch.Tensor:
-     if self.config.attn_config.attn_types:
-       if (
-           self.config.attn_config.attn_types[idx]
-           == cfg.AttentionType.LOCAL_SLIDING
-       ):
-         return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+       return self.sliding_window_mask_cache.index_select(2, input_pos)
      return self.mask_cache.index_select(2, input_pos)

    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )

      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
      sin = sin.index_select(0, input_pos)

      # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(idx)
+     x = self.tok_embedding(tokens)
      x = x * (self.config.embedding_dim**0.5)

+     updated_kv_entires = []
      for i, block in enumerate(self.transformer_blocks):
-       mask = self.get_attention_mask(i, input_pos)
-       x = block(x, (cos, sin), mask, input_pos)
+       mask = self.get_attention_mask(
+           block.config.attn_config.attn_type, input_pos
+       )
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

      x = self.final_norm(x)
      res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
      res = res / self.config.final_logit_softcap
      res = torch.tanh(res)
      res = res * self.config.final_logit_softcap
-     return res
+
+     return {"logits": res, "kv_cache": updated_kv_cache}


  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
    Returns:
      The model config for a Gemma 2B model.
    """
-   attn_config = cfg.AttentionConfig(
-       num_heads=8,
-       head_dim=256,
-       num_query_groups=4,
-       rotary_percentage=1.0,
-       qkv_transpose_before_split=True,
-       logit_softcap=50.0,
-       sliding_window_size=4096,
-       attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-       * 13,
-   )
-
    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM,
        epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
        pre_ff_norm_config=norm_config,
        post_ff_norm_config=norm_config,
    )
+
+   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+     attn_config = cfg.AttentionConfig(
+         num_heads=8,
+         head_dim=256,
+         num_query_groups=4,
+         rotary_percentage=1.0,
+         qkv_transpose_before_split=True,
+         logit_softcap=50.0,
+         sliding_window_size=4096,
+         attn_type=(
+             cfg.AttentionType.GLOBAL
+             if idx % 2 == 0
+             else cfg.AttentionType.LOCAL_SLIDING
+         ),
+     )
+     return cfg.TransformerBlockConfig(
+         attn_config=attn_config,
+         ff_config=ff_config,
+         pre_attention_norm_config=norm_config,
+         post_attention_norm_config=norm_config,
+     )
+
+   num_layers = 26
    config = cfg.ModelConfig(
        vocab_size=256000,
-       num_layers=26,
+       num_layers=num_layers,
        max_seq_len=8192,
        embedding_dim=2304,
        kv_cache_max_len=kv_cache_max_len,
-       attn_config=attn_config,
-       ff_config=ff_config,
-       pre_attention_norm_config=norm_config,
-       post_attention_norm_config=norm_config,
+       block_configs=[get_block_config(i) for i in range(num_layers)],
        final_norm_config=norm_config,
-       parallel_residual=False,
        lm_head_use_bias=False,
        enable_hlfb=True,
        final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
    config = get_model_config_2b(kv_cache_max_len)
-   config.attn_config.num_heads = 4
-   config.attn_config.head_dim = 64
-   config.attn_config.sliding_window_size = 64
-   config.ff_config.intermediate_size = 128
    config.vocab_size = 128
    config.num_layers = 2
    config.max_seq_len = 2 * kv_cache_max_len
    config.embedding_dim = 128
+   config.block_configs = config.block_configs[: config.num_layers]
+   for block_config in config.block_configs:
+     block_config.attn_config.num_heads = 4
+     block_config.attn_config.head_dim = 64
+     block_config.attn_config.sliding_window_size = 64
+     block_config.ff_config.intermediate_size = 128
    return config


@@ -236,21 +262,20 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    config = get_model_config_2b(**kwargs)
    model = Gemma2(config)
    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   # since embedding and lm-head use the same weight, we need to set strict
+   # Since embedding and lm-head use the same weight, we need to set strict
    # to False.
    loader.load(model, strict=False)
    model.eval()
    return model


- def define_and_run_2b() -> None:
+ def define_and_run_2b(checkpoint_path: str) -> None:
    """Instantiates and runs a Gemma2 2B model."""

-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
    print("Running GEMMA 2")
    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    toks = torch.from_numpy(
        np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +283,13 @@ def define_and_run_2b() -> None:
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :9] = toks
    input_pos = torch.arange(0, kv_cache_max_len)
-   out = model.forward(tokens, input_pos)
-   out_final = out[0, 8, :]
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   out = model.forward(tokens, input_pos, kv)
+   out_final = out["logits"][0, 8, :]
    assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)


  if __name__ == "__main__":
    torch.set_printoptions(sci_mode=True)
-   define_and_run_2b()
+   path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+   define_and_run_2b(path)
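gemma2.py replaces the single attn_types list with one TransformerBlockConfig per layer, alternating global and local-sliding attention. A short sketch of reading the per-layer attention types back out of the new config, assuming block_config(idx) works as shown above:

    import ai_edge_torch.generative.layers.model_config as cfg
    from ai_edge_torch.generative.examples.gemma import gemma2

    config = gemma2.get_model_config_2b()
    for idx in range(config.num_layers):
      attn_type = config.block_config(idx).attn_config.attn_type
      # Per get_block_config above, even layers use GLOBAL attention and odd
      # layers use LOCAL_SLIDING.
      expected = (
          cfg.AttentionType.GLOBAL
          if idx % 2 == 0
          else cfg.AttentionType.LOCAL_SLIDING
      )
      assert attn_type == expected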
ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py
@@ -12,16 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- #
- # Note: This is an experimental version of phi2 with external KV cache.
- # Please use with caution.
+
+ """Example of converting a Phi-2 model to multi-signature tflite model."""

  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
- from ai_edge_torch.generative.examples.experimental.phi import phi2
- from ai_edge_torch.generative.layers.experimental import ekv_cache
+ from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
      kv_cache_max_len: int = 1024,
      quantize: bool = True,
  ):
-   """An example method for converting a Phi-2 model to multi-signature
+   """Converts a Phi-2 model to multi-signature tflite model.

-   tflite model.
    Args:
      checkpoint_path (str): The filepath to the model checkpoint, or directory
        holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
-   kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+   kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
        )
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+       f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-   convert_phi2_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+   convert_phi2_to_tflite(path)