ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/t5/t5.py
@@ -52,9 +52,15 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
+    # T5 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList([
-        EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
-        for i in range(config.num_layers)
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
     ])
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
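
These t5.py hunks are part of a broader migration in this release: per-model attention/FF settings move into per-block TransformerBlockConfig objects (see model_config.py, +61 -33 in the file list), fetched through ModelConfig.block_config(idx). A hypothetical sketch of the accessor pattern these call sites imply — field names beyond block_configs and block_config are elided, and the real definitions live in ai_edge_torch/generative/layers/model_config.py:

from dataclasses import dataclass
from typing import Sequence, Union


@dataclass
class TransformerBlockConfig:
  # Per-block settings (attention, feed-forward, norm configs); elided here.
  pass


@dataclass
class ModelConfig:
  num_layers: int
  # Either a single config shared by every layer, or one config per layer.
  block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]

  def block_config(self, idx: int) -> TransformerBlockConfig:
    if isinstance(self.block_configs, TransformerBlockConfig):
      return self.block_configs  # shared across all layers
    return self.block_configs[idx]  # per-layer lookup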
@@ -73,13 +79,11 @@ class T5Stack(nn.Module):
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for i, layer_module in enumerate(self.transformer_blocks):
+    for _, layer_module in enumerate(self.transformer_blocks):
       # EncoderDecoderBlock.forward
       hidden_states, position_bias, encoder_decoder_position_bias = (
           layer_module(
@@ -111,7 +115,8 @@ class T5(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     decoder_config = copy.deepcopy(config)
@@ -137,20 +142,22 @@ class T5(nn.Module):
         device=torch.device("cpu"),
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -230,7 +237,8 @@ class T5Encoder(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     self.enc_attn_mask_cache = (
@@ -243,12 +251,14 @@ class T5Encoder(nn.Module):
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -313,12 +323,14 @@ class T5Decoder(nn.Module):
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
@@ -386,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
-  config = cfg.ModelConfig(
-      vocab_size=32128,
-      num_layers=12,
-      max_seq_len=512,
-      embedding_dim=768,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       relative_attention=True,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32128,
+      num_layers=12,
+      max_seq_len=512,
+      embedding_dim=768,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
@@ -588,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
@@ -620,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   )
   idx = get_sample_encoder_input_ids()

-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros(
       [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
   )
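
Both runners also switch token and position tensors from torch.long/torch.int64 to torch.int. torch.int is an alias for torch.int32, so the traced (and later converted) signatures now take 32-bit integer inputs; a one-line check:

import torch

tokens = torch.full((1, 512), 0, dtype=torch.int)
assert tokens.dtype == torch.int32  # torch.int == torch.int32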
ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F

 BATCH_SIZE = 1

@@ -32,13 +31,18 @@ BATCH_SIZE = 1
 class EncoderDecoderBlock(nn.Module):

   def __init__(
-      self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+      self,
+      config: cfg.TransformerBlockConfig,
+      model_config: cfg.ModelConfig,
+      has_relative_attention_bias: bool = False,
   ) -> None:
     """Initialize an instance of the EncoderDecoderBlock.

     Args:
-      config (cfg.ModelConfig): the configuration object for this transformer
-        block.
+      config (cfg.TransformerBlockConfig): the configuration object for this
+        transformer block.
+      model_config (cfg.ModelConfig): the configuration object for the model
+        this transformer block belongs to.
       has_relative_attention_bias (bool): whether the self attention block has
         relative bias.
     """
@@ -46,22 +50,22 @@ class EncoderDecoderBlock(nn.Module):
     super().__init__()
     self.atten_func = T5Attention(
         BATCH_SIZE,
-        config.embedding_dim,
+        model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        config.kv_cache_max,
-        config.enable_hlfb,
+        model_config.kv_cache_max,
+        model_config.enable_hlfb,
         has_relative_attention_bias=has_relative_attention_bias,
     )
     # For a decoder, we add a cross attention.
-    if config.is_decoder:
+    if model_config.is_decoder:
       self.cross_atten_func = T5Attention(
           BATCH_SIZE,
-          config.embedding_dim,
+          model_config.embedding_dim,
           config.attn_config,
           config.pre_attention_norm_config,
-          config.kv_cache_max,
-          config.enable_hlfb,
+          model_config.kv_cache_max,
+          model_config.enable_hlfb,
           # Cross Attention does not have relative attention bias.
           has_relative_attention_bias=False,
       )
@@ -69,9 +73,10 @@ class EncoderDecoderBlock(nn.Module):
       self.cross_atten_func = None

     self.post_atten_norm = builder.build_norm(
-        config.embedding_dim, config.post_attention_norm_config
+        model_config.embedding_dim,
+        config.post_attention_norm_config,
     )
-    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
     self.config = config

   def forward(
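
Net effect of these t5_attention.py hunks: EncoderDecoderBlock now takes its own TransformerBlockConfig plus the parent ModelConfig for model-wide values (embedding_dim, kv_cache_max, enable_hlfb, is_decoder). A construction sketch under that split, assuming the import paths match the file list above:

from ai_edge_torch.generative.examples.t5 import t5
from ai_edge_torch.generative.examples.t5 import t5_attention

config = t5.get_model_config_t5()
block = t5_attention.EncoderDecoderBlock(
    config.block_config(0),  # per-block: attention, FF, norm configs
    config,  # model-wide: embedding_dim, kv_cache_max, enable_hlfb
    has_relative_attention_bias=True,  # T5 gives only the first block the bias
)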
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -20,7 +20,6 @@ from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn

@@ -36,16 +35,16 @@ class ToySingleLayerModel(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -85,16 +84,16 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
         bias=config.lm_head_use_bias,
     )
     self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -135,15 +134,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=400,
       num_layers=1,
       max_seq_len=KV_CACHE_MAX_LEN,
       embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A toy example which has basic transformer block (w/ KV-Cache).
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple

 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers.attention import TransformerBlock
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]


-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):

   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -35,18 +38,20 @@ class ToyModelWithKV(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    # Toy model has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -57,18 +62,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config

-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    x = self.tok_embedding(idx)
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-    return self.lm_head(x)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


 def _export_stablehlo_mlir(model, args):
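
The new forward threads an explicit kv_utils.KVCache through the model and returns the updated cache next to the logits instead of mutating buffers in place. A usage sketch of that contract in the context of this module, using only names from the diff (the renamed ToyModelWithKVCache class, KVCache.from_model_config, and the 'logits'/'kv_cache' output keys):

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = get_model_config()
model = ToyModelWithKVCache(config)
model.eval()

# Prefill: run the whole prompt once and keep the returned cache.
kv = kv_utils.KVCache.from_model_config(config)
tokens = torch.unsqueeze(torch.arange(0, 16, dtype=torch.int), 0)
input_pos = torch.arange(0, 16, dtype=torch.int)
out = model.forward(tokens, input_pos, kv)
kv = out['kv_cache']

# Decode: feed one token at a time, re-threading the cache each step.
next_token = torch.tensor([[1]], dtype=torch.int)
out = model.forward(next_token, torch.tensor([16], dtype=torch.int), kv)
logits, kv = out['logits'], out['kv_cache']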
@@ -78,7 +94,10 @@ def _export_stablehlo_mlir(model, args):

 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32,
+      head_dim=4,
+      num_query_groups=4,
+      rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -86,15 +105,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=500,
-      embedding_dim=128,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
@@ -102,40 +124,59 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
-  input_pos = torch.arange(0, 100)
-  return idx, input_pos
+  tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+  input_pos = torch.arange(0, 100, dtype=torch.int)
+  return tokens, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10], dtype=torch.int64)
-  return idx, input_pos
+  tokens = torch.tensor([[1]], dtype=torch.int)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos


 def define_and_run() -> None:
   dump_mlir = False

   config = get_model_config()
-  model = ToyModelWithKV(config)
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  idx, input_pos = get_sample_prefill_inputs()
-  decode_idx, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(idx, input_pos))
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))

   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-    with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)

   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature('prefill', model, (idx, input_pos))
-      .signature('decode', model, (decode_idx, decode_input_pos))
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
-  edge_model.export('/tmp/toy_kv_cache.tflite')
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')


 if __name__ == '__main__':
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -44,24 +47,40 @@ def convert_tiny_llama_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
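
With the export now carrying two named signatures and a quantization suffix in the filename, a quick way to sanity-check the artifact is the TensorFlow Lite interpreter. A sketch; the filename assumes quantize=False with prefill_seq_len=512 and kv_cache_max_len=1024, so adjust it to whatever the converter actually wrote:

import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/tiny_llama_f32_seq512_ekv1024.tflite'
)
print(interpreter.get_signature_list())  # expect 'prefill' and 'decode' entries
prefill_runner = interpreter.get_signature_runner('prefill')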