ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building the Gemma2 2B model.
+
+"""Example of building a Gemma2 model."""
 
 import os
-from pathlib import Path
+import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block.
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
 
     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv
 
 
 class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
         config.vocab_size,
         bias=config.lm_head_use_bias,
     )
-    # Gemma re-uses the embedding as the head projection layer.
+    # Gemma2 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        Gemma2Block(config) for _ in range(config.num_layers)
+        Gemma2Block(config.block_config(idx), config)
+        for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
-        window_size=self.config.attn_config.sliding_window_size,
+        window_size=attn_config.sliding_window_size,
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.config = config
 
   def get_attention_mask(
-      self, idx: int, input_pos: torch.Tensor
+      self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
   ) -> torch.Tensor:
-    if self.config.attn_config.attn_types:
-      if (
-          self.config.attn_config.attn_types[idx]
-          == cfg.AttentionType.LOCAL_SLIDING
-      ):
-        return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+      return self.sliding_window_mask_cache.index_select(2, input_pos)
     return self.mask_cache.index_select(2, input_pos)
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
    assert self.config.max_seq_len >= seq_len, (
        f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(i, input_pos)
-      x = block(x, (cos, sin), mask, input_pos)
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-    return res
+
+    return {"logits": res, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   Returns:
     The model config for a Gemma 2B model.
   """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-      qkv_transpose_before_split=True,
-      logit_softcap=50.0,
-      sliding_window_size=4096,
-      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-      * 13,
-  )
-
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_ff_norm_config=norm_config,
       post_ff_norm_config=norm_config,
   )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
   config = cfg.ModelConfig(
       vocab_size=256000,
-      num_layers=26,
+      num_layers=num_layers,
       max_seq_len=8192,
       embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
-      parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
      final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.attn_config.num_heads = 4
-  config.attn_config.head_dim = 64
-  config.attn_config.sliding_window_size = 64
-  config.ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 4
+    block_config.attn_config.head_dim = 64
+    block_config.attn_config.sliding_window_size = 64
+    block_config.ff_config.intermediate_size = 128
   return config
 
 
@@ -236,21 +262,20 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # since embedding and lm-head use the same weight, we need to set strict
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +283,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-  out = model.forward(tokens, input_pos)
-  out_final = out[0, 8, :]
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-  define_and_run_2b()
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
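
Note (not part of the diff): based only on the new signatures above, the sketch below shows how a caller might drive the refactored Gemma2 forward pass with the external KV cache introduced in this release. The checkpoint path and prompt token IDs are placeholders.

import torch

from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint location; substitute a real Gemma2 2B checkpoint.
model = gemma2.build_2b_model("/path/to/gemma2-2b", kv_cache_max_len=1024)

# Prefill: run the prompt once and keep the returned cache.
prompt = torch.tensor([[2, 651, 9456, 576, 573]], dtype=torch.long)
input_pos = torch.arange(0, prompt.shape[1])
kv = kv_utils.KVCache.from_model_config(model.config)
out = model.forward(prompt, input_pos, kv)
logits, kv = out["logits"], out["kv_cache"]

# Decode: one token at a time, threading the updated cache back in.
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
decode_pos = torch.tensor([prompt.shape[1]])
out = model.forward(next_token, decode_pos, kv)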

ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.phi import phi2
-from ai_edge_torch.generative.layers.experimental import ekv_cache
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Phi-2 model to multi-signature
+  """Converts a Phi-2 model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
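
Note (not part of the diff): a sketch of requesting an un-quantized export with the new naming scheme above, assuming the same local checkpoint location that the script's __main__ block uses; with these arguments it would write /tmp/phi2_f32_seq256_ekv1024.tflite.

import os
import pathlib

from ai_edge_torch.generative.examples.phi import convert_to_tflite

# Assumed checkpoint location, mirroring the __main__ block above.
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')

# Float32 export with a shorter prefill signature; the file name follows the
# phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len} scheme above.
convert_to_tflite.convert_phi2_to_tflite(path, prefill_seq_len=256, quantize=False)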

ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py
@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-from pathlib import Path
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # Phi-2 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    _, seq_len = tokens.size()
    assert self.config.max_seq_len >= seq_len, (
        f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      parallel_residual=True,
+  )
   config = cfg.ModelConfig(
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=True,
       lm_head_use_bias=True,
       enable_hlfb=True,
   )
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
-  config.ff_config.intermediate_size = 128
+  # Phi-2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
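
Note (not part of the diff): a minimal smoke-test sketch built only from helpers visible above (get_fake_model_config, Phi2, KVCache.from_model_config). It runs the tiny test config with randomly initialized weights, so no checkpoint is needed; it only checks output shapes.

import torch

from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Tiny test config; no checkpoint is loaded, weights stay randomly initialized.
config = phi2.get_fake_model_config(kv_cache_max_len=128)
model = phi2.Phi2(config).eval()

tokens = torch.zeros((1, 10), dtype=torch.long)
input_pos = torch.arange(0, 10)
kv = kv_utils.KVCache.from_model_config(config)

out = model.forward(tokens, input_pos, kv)
print(out["logits"].shape)          # expected (1, 10, config.vocab_size)
print(len(out["kv_cache"].caches))  # expected config.num_layers entries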

ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py
@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
 
+"""Example of converting SmalLM model to multi-signature tflite model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.gemma import gemma
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_gemma_to_tflite(
+def convert_smallm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Gemma 2B model to multi-signature
+  """Converts SmalLM model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = gemma.build_2b_model(
+  pytorch_model = smallm.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
   # Tensors used to trace the model graph during conversion.
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)