ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
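
Most of the source hunks below share one refactor: per-layer attention, feed-forward, and normalization settings move out of cfg.ModelConfig into a cfg.TransformerBlockConfig that is attached through the new block_configs argument and read back with config.block_config(idx). A minimal sketch of the new config shape, with illustrative values pieced together from the toy-model and SmalLM hunks (not an authoritative API reference):

import ai_edge_torch.generative.layers.model_config as cfg

# Per-block settings are grouped into a TransformerBlockConfig first.
attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=4,
    num_query_groups=4,
    rotary_percentage=1.0,
)
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
block_config = cfg.TransformerBlockConfig(
    attn_config=attn_config,
    ff_config=ff_config,
    pre_attention_norm_config=norm_config,
    post_attention_norm_config=norm_config,
)

# ... and then handed to ModelConfig via block_configs instead of the old
# per-model attn_config/ff_config/norm keyword arguments.
config = cfg.ModelConfig(
    vocab_size=150,
    num_layers=2,
    max_seq_len=100,
    embedding_dim=128,
    block_configs=block_config,
    final_norm_config=norm_config,
    enable_hlfb=True,
)

# Call sites that used to read config.attn_config now go through block_config(idx).
assert config.block_config(0).attn_config.num_heads == 32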
ai_edge_torch/generative/examples/smallm/smallm.py
@@ -0,0 +1,122 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
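
A short usage sketch for the new module above, assuming a locally downloaded SmalLM checkpoint (the path here is made up) and the forward contract shown in the toy-model hunks further down, where forward returns a dict with 'logits' and 'kv_cache':

import torch

from ai_edge_torch.generative.examples.smallm import smallm
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Hypothetical checkpoint directory; substitute the real SmalLM weights.
model = smallm.build_model("/tmp/llm_data/smallm", kv_cache_max_len=1024)

# The KV cache now lives outside the model and is threaded through forward().
kv = kv_utils.KVCache.from_model_config(model.config)

tokens = torch.zeros((1, 1024), dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4])
input_pos = torch.arange(0, 1024)

output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
print(logits.shape)  # expected (1, 1024, 49152) for the 135M config above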
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -61,8 +61,10 @@ class CLIP(nn.Module):
     )

     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:

   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)

+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/t5/t5.py
@@ -52,9 +52,15 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
+    # T5 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList([
-        EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
-        for i in range(config.num_layers)
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
     ])
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -73,13 +79,11 @@ class T5Stack(nn.Module):
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for i, layer_module in enumerate(self.transformer_blocks):
+    for _, layer_module in enumerate(self.transformer_blocks):
      # EncoderDecoderBlock.forward
      hidden_states, position_bias, encoder_decoder_position_bias = (
          layer_module(
@@ -111,7 +115,8 @@ class T5(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     decoder_config = copy.deepcopy(config)
@@ -137,20 +142,22 @@
         device=torch.device("cpu"),
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -230,7 +237,8 @@ class T5Encoder(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     self.enc_attn_mask_cache = (
@@ -243,12 +251,14 @@
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -313,12 +323,14 @@ class T5Decoder(nn.Module):
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
@@ -386,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
-  config = cfg.ModelConfig(
-      vocab_size=32128,
-      num_layers=12,
-      max_seq_len=512,
-      embedding_dim=768,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       relative_attention=True,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32128,
+      num_layers=12,
+      max_seq_len=512,
+      embedding_dim=768,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F

 BATCH_SIZE = 1

@@ -32,13 +31,18 @@ BATCH_SIZE = 1
 class EncoderDecoderBlock(nn.Module):

   def __init__(
-      self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+      self,
+      config: cfg.TransformerBlockConfig,
+      model_config: cfg.ModelConfig,
+      has_relative_attention_bias: bool = False,
   ) -> None:
     """Initialize an instance of the EncoderDecoderBlock.

     Args:
-      config (cfg.ModelConfig): the configuration object for this transformer
-        block.
+      config (cfg.TransformerBlockConfig): the configuration object for this
+        transformer block.
+      model_config (cfg.ModelConfig): the configuration object for the model
+        this transformer block belongs to.
       has_relative_attention_bias (bool): whether the self attention block has
         relative bias.
     """
@@ -46,22 +50,22 @@ class EncoderDecoderBlock(nn.Module):
     super().__init__()
     self.atten_func = T5Attention(
         BATCH_SIZE,
-        config.embedding_dim,
+        model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        config.kv_cache_max,
-        config.enable_hlfb,
+        model_config.kv_cache_max,
+        model_config.enable_hlfb,
         has_relative_attention_bias=has_relative_attention_bias,
     )
     # For a decoder, we add a cross attention.
-    if config.is_decoder:
+    if model_config.is_decoder:
       self.cross_atten_func = T5Attention(
           BATCH_SIZE,
-          config.embedding_dim,
+          model_config.embedding_dim,
           config.attn_config,
           config.pre_attention_norm_config,
-          config.kv_cache_max,
-          config.enable_hlfb,
+          model_config.kv_cache_max,
+          model_config.enable_hlfb,
           # Cross Attention does not have relative attention bias.
           has_relative_attention_bias=False,
       )
@@ -69,9 +73,10 @@ class EncoderDecoderBlock(nn.Module):
       self.cross_atten_func = None

     self.post_atten_norm = builder.build_norm(
-        config.embedding_dim, config.post_attention_norm_config
+        model_config.embedding_dim,
+        config.post_attention_norm_config,
     )
-    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
     self.config = config

   def forward(
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -20,7 +20,6 @@ from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn

@@ -36,16 +35,16 @@ class ToySingleLayerModel(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -85,16 +84,16 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
         bias=config.lm_head_use_bias,
     )
     self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -135,15 +134,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=400,
       num_layers=1,
       max_seq_len=KV_CACHE_MAX_LEN,
       embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A toy example which has basic transformer block (w/ KV-Cache).
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple

 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers.attention import TransformerBlock
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]


-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):

   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -35,18 +38,20 @@ class ToyModelWithKV(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    # Toy model has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -57,18 +62,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config

-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    x = self.tok_embedding(idx)
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-    return self.lm_head(x)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


 def _export_stablehlo_mlir(model, args):
@@ -78,7 +94,10 @@ def _export_stablehlo_mlir(model, args):

 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32,
+      head_dim=4,
+      num_query_groups=4,
+      rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -86,15 +105,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=500,
-      embedding_dim=128,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
@@ -102,40 +124,59 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return idx, input_pos
+  return tokens, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10], dtype=torch.int64)
-  return idx, input_pos
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos


 def define_and_run() -> None:
   dump_mlir = False

   config = get_model_config()
-  model = ToyModelWithKV(config)
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  idx, input_pos = get_sample_prefill_inputs()
-  decode_idx, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(idx, input_pos))
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))

   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-    with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)

   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature('prefill', model, (idx, input_pos))
-      .signature('decode', model, (decode_idx, decode_input_pos))
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
-  edge_model.export('/tmp/toy_kv_cache.tflite')
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')


 if __name__ == '__main__':
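
To round out the picture, a small driving sketch for the externalized cache, using only the helpers defined in the hunk above; the single decode step mirrors the example's own sample inputs and is illustrative rather than a full generation loop:

config = get_model_config()
model = ToyModelWithKVCache(config)
model.eval()

# The cache is created from the model config and passed in explicitly.
kv = kv_utils.KVCache.from_model_config(config)

# Prefill the sample prompt; the updated cache comes back in the output dict.
tokens, input_pos = get_sample_prefill_inputs()
out = model.forward(tokens, input_pos, kv)
kv = out['kv_cache']

# A decode step then reuses that cache with the sample decode inputs.
decode_token, decode_input_pos = get_sample_decode_inputs()
out = model.forward(decode_token, decode_input_pos, kv)
print(out['logits'].shape, type(out['kv_cache']))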