ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/experimental/phi/phi2.py
+++ b/ai_edge_torch/generative/examples/phi/phi2.py
@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-from pathlib import Path
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # Phi-2 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      parallel_residual=True,
+  )
   config = cfg.ModelConfig(
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=True,
       lm_head_use_bias=True,
       enable_hlfb=True,
   )
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
-  config.ff_config.intermediate_size = 128
+  # Phi-2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
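Note: the net effect of the phi2.py changes above is an API change. The model now consumes the non-experimental kv_cache.KVCache container, and forward() returns a dict with "logits" and "kv_cache" keys instead of a (logits, cache) tuple. A minimal calling sketch against the updated example follows; the checkpoint path is a placeholder, not taken from the diff.

    import torch

    from ai_edge_torch.generative.examples.phi import phi2
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    # Build the model from a local checkpoint; the path is illustrative only.
    model = phi2.build_model("/path/to/phi2", kv_cache_max_len=1024)

    # Prefill a short prompt; the remaining positions are padding.
    tokens = torch.zeros((1, 1024), dtype=torch.long)
    tokens[0, :4] = torch.tensor([1, 2, 3, 4])
    input_pos = torch.arange(0, 1024)

    # The KV cache is now an explicit input and output of forward().
    kv = kv_utils.KVCache.from_model_config(model.config)
    output = model.forward(tokens, input_pos, kv)
    logits, kv = output["logits"], output["kv_cache"]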
--- a/ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smallm/convert_to_tflite.py
@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
 
+"""Example of converting SmalLM model to multi-signature tflite model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.gemma import gemma
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_gemma_to_tflite(
+def convert_smallm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Gemma 2B model to multi-signature
+  """Converts SmalLM model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = gemma.build_2b_model(
+  pytorch_model = smallm.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)
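For reference, a sketch of driving the renamed conversion entry point; the checkpoint path mirrors the __main__ block above and is a placeholder, and the output location comes from the export call in the diff.

    import pathlib

    from ai_edge_torch.generative.examples.smallm import convert_to_tflite

    # With the defaults shown above this writes
    # /tmp/smallm_q8_seq512_ekv1024.tflite; quantize=False would produce a
    # *_f32_* file instead.
    convert_to_tflite.convert_smallm_to_tflite(
        str(pathlib.Path.home() / "Downloads/llm_data/smallm"),
        prefill_seq_len=512,
        kv_cache_max_len=1024,
        quantize=True,
    )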
--- /dev/null
+++ b/ai_edge_torch/generative/examples/smallm/smallm.py
@@ -0,0 +1,122 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
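The new SmalLM class ties the output projection to the token embedding, which is why TENSOR_NAMES.lm_head is set to None and the loader runs with strict=False. A small standalone sketch of what that tying amounts to; the dimensions are taken from the config above, everything else is illustrative.

    from torch import nn

    embedding_dim, vocab_size = 576, 49152

    tok_embedding = nn.Embedding(vocab_size, embedding_dim)
    lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    # Share the embedding matrix with the head projection, as SmalLM.__init__
    # does; the checkpoint then only needs to store one copy of the weight.
    lm_head.weight.data = tok_embedding.weight.data

    assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()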
--- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -61,8 +61,10 @@ class CLIP(nn.Module):
     )
 
     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:
 
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
 
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
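The clip.py hunks follow the same refactoring pattern as the other models in this release: per-block settings (attention, feed-forward, norms) move from ModelConfig onto cfg.TransformerBlockConfig, are passed in as block_configs, and are read back per layer through config.block_config(idx). A schematic sketch of consuming a config under the new layout; the helper name is hypothetical.

    import ai_edge_torch.generative.layers.model_config as cfg


    def describe_blocks(config: cfg.ModelConfig) -> None:
      """Prints per-layer settings under the new block_configs layout."""
      for idx in range(config.num_layers):
        block = config.block_config(idx)
        print(
            f"layer {idx}: heads={block.attn_config.num_heads},"
            f" ff={block.ff_config.intermediate_size}"
        )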
--- a/ai_edge_torch/generative/examples/t5/t5.py
+++ b/ai_edge_torch/generative/examples/t5/t5.py
@@ -52,9 +52,15 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
+    # T5 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList([
-        EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
-        for i in range(config.num_layers)
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
     ])
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -73,13 +79,11 @@ class T5Stack(nn.Module):
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for i, layer_module in enumerate(self.transformer_blocks):
+    for _, layer_module in enumerate(self.transformer_blocks):
       # EncoderDecoderBlock.forward
       hidden_states, position_bias, encoder_decoder_position_bias = (
           layer_module(
@@ -111,7 +115,8 @@ class T5(nn.Module):
 
     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)
 
     decoder_config = copy.deepcopy(config)
@@ -137,20 +142,22 @@ class T5(nn.Module):
         device=torch.device("cpu"),
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
   @torch.inference_mode
@@ -230,7 +237,8 @@ class T5Encoder(nn.Module):
 
     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)
 
     self.enc_attn_mask_cache = (
@@ -243,12 +251,14 @@ class T5Encoder(nn.Module):
         .unsqueeze(0)
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
   @torch.inference_mode
@@ -313,12 +323,14 @@ class T5Decoder(nn.Module):
         .unsqueeze(0)
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
@@ -386,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
-  config = cfg.ModelConfig(
-      vocab_size=32128,
-      num_layers=12,
-      max_seq_len=512,
-      embedding_dim=768,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       relative_attention=True,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32128,
+      num_layers=12,
+      max_seq_len=512,
+      embedding_dim=768,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
--- a/ai_edge_torch/generative/examples/t5/t5_attention.py
+++ b/ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F
 
 BATCH_SIZE = 1
 
@@ -32,13 +31,18 @@ BATCH_SIZE = 1
 class EncoderDecoderBlock(nn.Module):
 
   def __init__(
-      self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+      self,
+      config: cfg.TransformerBlockConfig,
+      model_config: cfg.ModelConfig,
+      has_relative_attention_bias: bool = False,
   ) -> None:
     """Initialize an instance of the EncoderDecoderBlock.
 
     Args:
-      config (cfg.ModelConfig): the configuration object for this transformer
-        block.
+      config (cfg.TransformerBlockConfig): the configuration object for this
+        transformer block.
+      model_config (cfg.ModelConfig): the configuration object for the model
+        this transformer block belongs to.
       has_relative_attention_bias (bool): whether the self attention block has
         relative bias.
     """
@@ -46,22 +50,22 @@ class EncoderDecoderBlock(nn.Module):
     super().__init__()
     self.atten_func = T5Attention(
         BATCH_SIZE,
-        config.embedding_dim,
+        model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        config.kv_cache_max,
-        config.enable_hlfb,
+        model_config.kv_cache_max,
+        model_config.enable_hlfb,
         has_relative_attention_bias=has_relative_attention_bias,
     )
     # For a decoder, we add a cross attention.
-    if config.is_decoder:
+    if model_config.is_decoder:
       self.cross_atten_func = T5Attention(
           BATCH_SIZE,
-          config.embedding_dim,
+          model_config.embedding_dim,
          config.attn_config,
          config.pre_attention_norm_config,
-          config.kv_cache_max,
-          config.enable_hlfb,
+          model_config.kv_cache_max,
+          model_config.enable_hlfb,
          # Cross Attention does not have relative attention bias.
          has_relative_attention_bias=False,
      )
@@ -69,9 +73,10 @@ class EncoderDecoderBlock(nn.Module):
       self.cross_atten_func = None
 
     self.post_atten_norm = builder.build_norm(
-        config.embedding_dim, config.post_attention_norm_config
+        model_config.embedding_dim,
+        config.post_attention_norm_config,
     )
-    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
     self.config = config
 
   def forward(