ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_gemma_to_tflite(
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """Converting a Gemma 2 2B model to multi-signature
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
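
Since the converter now wires both signatures through sample_kwargs and a shared KVCache, here is a minimal usage sketch of the updated entry point; the checkpoint location and the non-default sizes are placeholders, not part of this change.

import os
import pathlib

from ai_edge_torch.generative.examples.gemma import convert_gemma2_to_tflite

# Placeholder checkpoint directory; point this at the Gemma2 2B weights.
checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')

# With these settings the script writes /tmp/gemma2_f32_seq1024_ekv2048.tflite.
convert_gemma2_to_tflite.convert_gemma2_to_tflite(
    checkpoint,
    prefill_seq_len=1024,
    kv_cache_max_len=2048,
    quantize=False,
)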

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
      ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)

ai_edge_torch/generative/examples/gemma/gemma.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a Gemma model.
+
+"""Example of building a Gemma model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
     )
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Gemma has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       epsilon=1e-6,
       zero_centered=True,
   )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
      final_norm_config=norm_config,
-      parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
  )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.ff_config.intermediate_size = 128
+  # Gemma has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # since embedding and lm-head use the same weight, we need to set strict
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-  define_and_run_2b()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
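
With the cache externalized, callers build a KVCache up front and thread the returned one into the next call. Below is a rough prefill-plus-greedy-decode sketch against the new forward() contract; the checkpoint path and the greedy sampling are assumptions for illustration, not part of this diff.

import torch

from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint directory.
model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill the prompt, then reuse the returned cache for each decode step.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
kv = out['kv_cache']

pos = prompt.shape[1]
for _ in range(8):
  # Greedy pick of the next token from the last position's logits.
  next_token = out['logits'][:, -1, :].argmax(dim=-1, keepdim=True)
  out = model.forward(next_token, torch.tensor([pos]), kv)
  kv = out['kv_cache']
  pos += 1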

ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building the Gemma2 2B model.
+
+"""Example of building a Gemma2 model."""
 
 import os
-from pathlib import Path
+import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block.
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
 
     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv
 
 
 class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
         config.vocab_size,
         bias=config.lm_head_use_bias,
     )
-    # Gemma re-uses the embedding as the head projection layer.
+    # Gemma2 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        Gemma2Block(config) for _ in range(config.num_layers)
+        Gemma2Block(config.block_config(idx), config)
+        for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
    )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
-        window_size=self.config.attn_config.sliding_window_size,
+        window_size=attn_config.sliding_window_size,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
-
     self.config = config
 
   def get_attention_mask(
-      self, idx: int, input_pos: torch.Tensor
+      self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
   ) -> torch.Tensor:
-    if self.config.attn_config.attn_types:
-      if (
-          self.config.attn_config.attn_types[idx]
-          == cfg.AttentionType.LOCAL_SLIDING
-      ):
-        return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+      return self.sliding_window_mask_cache.index_select(2, input_pos)
     return self.mask_cache.index_select(2, input_pos)
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(i, input_pos)
-      x = block(x, (cos, sin), mask, input_pos)
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-    return res
+
+    return {"logits": res, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   Returns:
     The model config for a Gemma 2B model.
   """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-      qkv_transpose_before_split=True,
-      logit_softcap=50.0,
-      sliding_window_size=4096,
-      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-      * 13,
-  )
-
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_ff_norm_config=norm_config,
       post_ff_norm_config=norm_config,
   )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
   config = cfg.ModelConfig(
       vocab_size=256000,
-      num_layers=26,
+      num_layers=num_layers,
       max_seq_len=8192,
       embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
-      parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
      final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.attn_config.num_heads = 4
-  config.attn_config.head_dim = 64
-  config.attn_config.sliding_window_size = 64
-  config.ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 4
+    block_config.attn_config.head_dim = 64
+    block_config.attn_config.sliding_window_size = 64
+    block_config.ff_config.intermediate_size = 128
   return config
 
 
@@ -236,21 +262,20 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # since embedding and lm-head use the same weight, we need to set strict
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +283,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-  out = model.forward(tokens, input_pos)
-  out_final = out[0, 8, :]
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-  define_and_run_2b()
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
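
The per-layer block_configs replace the old attn_types list, with even layers global and odd layers local-sliding. A quick sanity check of that alternation through the new accessors (a sketch; it relies only on the helpers shown in this diff):

from ai_edge_torch.generative.examples.gemma import gemma2
import ai_edge_torch.generative.layers.model_config as cfg

config = gemma2.get_model_config_2b()
for idx in range(config.num_layers):
  # block_config(idx) returns the TransformerBlockConfig for layer idx.
  attn_type = config.block_config(idx).attn_config.attn_type
  expected = (
      cfg.AttentionType.GLOBAL
      if idx % 2 == 0
      else cfg.AttentionType.LOCAL_SLIDING
  )
  assert attn_type == expected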

ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.phi import phi2
-from ai_edge_torch.generative.layers.experimental import ekv_cache
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Phi-2 model to multi-signature
+  """Converts a Phi-2 model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
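
For downstream scripts that used the experimental Phi-2 example, the change above amounts to an import rename. A migration sketch follows; the build_model helper name and the checkpoint path are assumptions shown for illustration, not confirmed by this diff.

# Old (0.3.0.dev20240910) imports, kept as comments for reference:
#   from ai_edge_torch.generative.examples.experimental.phi import phi2
#   from ai_edge_torch.generative.layers.experimental import ekv_cache
#   kv = ekv_cache.EKVCache.from_model_config(model.config)
# New (0.3.0.dev20240913) imports:
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache

# build_model and the checkpoint path are assumed here for illustration.
model = phi2.build_model('/path/to/phi2', kv_cache_max_len=1024)
kv = kv_cache.KVCache.from_model_config(model.config)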