ai-edge-torch-nightly 0.3.0.dev20240930__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

Files changed (26)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +12 -5
  12. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  13. ai_edge_torch/generative/layers/model_config.py +6 -0
  14. ai_edge_torch/generative/test/test_loader.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  16. ai_edge_torch/generative/test/test_model_conversion_large.py +7 -8
  17. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  18. ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +24 -25
  22. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  23. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  24. {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -15,28 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# Qwen re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class Qwen(tiny_llama.TinyLlama):
-  """A Qwen model built from the Edge Generative API layers.
-
-  Qwen 2.5 shares the same architecture as TinyLlama.
-  """
+from ai_edge_torch.generative.utilities import model_builder
 
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # Qwen re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
-  model = Qwen(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_3b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+def build_1_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
+def build_0_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
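
With this change, Qwen no longer subclasses TinyLlama or hand-ties the head weights; each size variant returns a plain model_builder.DecoderOnlyModel. A minimal usage sketch of the new builders (the checkpoint path is hypothetical, and kv_cache_max_len is assumed to be forwarded to get_0_5b_model_config the same way it is for the 3B variant shown above):

    from ai_edge_torch.generative.examples.qwen import qwen

    # Hypothetical local checkpoint; substitute a real Qwen 2.5 download.
    model = qwen.build_0_5b_model("/tmp/qwen_2.5_0.5b", kv_cache_max_len=1024)
    # The result is a DecoderOnlyModel in eval mode; weights are loaded
    # non-strictly because the lm_head shares the embedding weights.
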
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -15,29 +15,10 @@
 
 """Example of building a SmolLM model."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class SmolLM(tiny_llama.TinyLlama):
-  """A SmolLM model built from the Edge Generative API layers.
+from ai_edge_torch.generative.utilities import model_builder
 
-  SmolLM shares the same architecture as TinyLlama, but with different model
-  sizes.
-  """
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # SmolLM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = SmolLM(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py CHANGED
@@ -180,9 +180,13 @@ def run_tflite_pipeline(
 
   # Text embedding.
   cond_tokens = model.tokenizer.encode(prompt)
-  cond_context = model.clip(np.array(cond_tokens), signature_name='encode')
+  cond_context = model.clip(
+      np.array(cond_tokens).astype(np.int32), signature_name='encode'
+  )
   uncond_tokens = model.tokenizer.encode(uncond_prompt)
-  uncond_context = model.clip(np.array(uncond_tokens), signature_name='encode')
+  uncond_context = model.clip(
+      np.array(uncond_tokens).astype(np.int32), signature_name='encode'
+  )
   context = np.concatenate([cond_context, uncond_context], axis=0)
   noise_shape = (1, 4, height // 8, width // 8)
 
@@ -198,7 +202,7 @@ def run_tflite_pipeline(
     input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
     input_image_np = util.move_channel(input_image_np, to='first')
     encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
-    latents = model.encoder(input_image_np, encoder_noise)
+    latents = model.encoder(input_image_np.astype(np.float32), encoder_noise)
     latents_noise = np.random.normal(size=noise_shape).astype(np.float32)
     sampler.set_strength(strength=strength)
     latents += latents_noise * sampler.initial_scale
@@ -214,7 +218,10 @@ def run_tflite_pipeline(
    input_latents = latents * sampler.get_input_scale()
    input_latents = input_latents.repeat(2, axis=0)
    output = model.diffusion(
-        input_latents, context, time_embedding, signature_name='diffusion'
+        input_latents.astype(np.float32),
+        context.astype(np.float32),
+        time_embedding,
+        signature_name='diffusion',
    )
    output_cond, output_uncond = np.split(output, 2, axis=0)
    output = cfg_scale * (output_cond - output_uncond) + output_uncond
@@ -222,7 +229,7 @@ def run_tflite_pipeline(
    latents = sampler.step(latents, output)
 
  # Image decoding.
-  images = model.decoder(latents, signature_name='decode')
+  images = model.decoder(latents.astype(np.float32), signature_name='decode')
  images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
  images = util.move_channel(images, to='last')
  if not os.path.exists(output_path):
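
The new astype casts guard against silent dtype drift: np.array over a list of Python ints yields int64 on most 64-bit platforms, and NumPy arithmetic can promote floats to float64, while TFLite signature runners require inputs that match the signature's declared dtype exactly. A small sketch of the failure mode and fix, assuming a hypothetical clip.tflite export whose 'encode' signature takes an int32 'tokens' input (file, signature, and tensor names are illustrative, not taken from this codebase):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="clip.tflite")  # hypothetical file
    encode = interpreter.get_signature_runner("encode")          # illustrative name

    token_ids = [101, 2023, 2003, 102]
    # np.array(token_ids).dtype is int64 here, which the runner rejects with
    # a dtype mismatch; casting to the signature's int32 fixes it.
    outputs = encode(tokens=np.array(token_ids).astype(np.int32))
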
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,102 +15,10 @@
 
 """Example of building a TinyLlama model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
+from ai_edge_torch.generative.utilities import model_builder
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head="lm_head",
-)
-
-
-class TinyLlama(nn.Module):
-  """A TinyLlama model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # TinyLlama has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = TinyLlama(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -184,8 +184,14 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
 
+  # Scale factor of the embedding.
+  embedding_scale: Optional[float] = None
+
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
+  # Whether LLM's HEAD shares the weight of the embedding.
+  lm_head_share_weight_with_embedding: bool = True
+
   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = False
 
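The two new fields let one DecoderOnlyModel cover both head conventions: embedding_scale supports models that multiply token embeddings by a constant, and lm_head_share_weight_with_embedding controls whether the output projection reuses the embedding matrix (the default, which TinyLlama now opts out of). A standalone PyTorch sketch of what the tying flag does inside DecoderOnlyModel.__init__:

    import torch
    from torch import nn

    embedding_dim, vocab_size = 16, 100
    tok_embedding = nn.Embedding(vocab_size, embedding_dim)
    lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    # With lm_head_share_weight_with_embedding=True, the head reuses the
    # embedding storage, so checkpoints need no separate "lm_head" tensor
    # (which is why the loader runs with strict=False in that case).
    lm_head.weight.data = tok_embedding.weight.data
    assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()
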
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -19,6 +19,7 @@ import tempfile
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+from ai_edge_torch.generative.utilities import model_builder
 import safetensors.torch
 import torch
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLlama(cfg)
+    model = model_builder.DecoderOnlyModel(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -163,7 +164,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import model_builder
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = gemma1.Gemma(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = phi2.Phi2(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
@@ -131,9 +132,7 @@ class TestModelConversion(googletest.TestCase):
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
-    self._test_model(
-        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
-    )
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -141,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = smollm.SmolLM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -150,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = openelm.OpenELM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -159,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = qwen.Qwen(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
ai_edge_torch/generative/utilities/model_builder.py ADDED
@@ -0,0 +1,141 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to be used for re-authoring transformer models."""
+
+import copy
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
+
+
+class DecoderOnlyModel(nn.Module):
+  """A simple decoder-only transformer model built from the Edge Generative API.
+
+  This model is used for re-authoring. model_config is used to specify the
+  details of model architecture and parameters.
+
+  It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
+  and rotary_percentage are the same for all layers.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    if config.lm_head_share_weight_with_embedding:
+      self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    if self.config.embedding_scale is not None:
+      x = x * self.config.embedding_scale
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def build_decoder_only_model(
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    tensor_names: loading_utils.ModelLoader.TensorNames,
+) -> DecoderOnlyModel:
+  transformer = DecoderOnlyModel(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+  loader.load(
+      transformer, strict=not config.lm_head_share_weight_with_embedding
+  )
+  transformer.eval()
+  return transformer
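
Each example module now reduces to a config function plus a thin wrapper around build_decoder_only_model, and the conversion tests instantiate DecoderOnlyModel directly. A checkpoint-free smoke test in that style; KVCache.from_model_config is assumed here to construct a per-layer cache matching the config, and the token values are arbitrary:

    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.utilities import model_builder
    import torch

    # Instantiate the shared architecture from a small fake config, as the
    # conversion tests now do, instead of a per-model class like SmolLM.
    config = smollm.get_fake_model_config()
    model = model_builder.DecoderOnlyModel(config).eval()

    tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
    input_pos = torch.arange(0, 4)
    kv = kv_utils.KVCache.from_model_config(config)  # assumed constructor
    out = model(tokens, input_pos, kv)
    print(out["logits"].shape)  # (1, 4, config.vocab_size)
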
ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -250,6 +250,7 @@ def merged_bundle_to_tfl_model(
       },
   )
   # Clean up intermediate memory early.
+  del tf_functions
   del tf_module
   del tf_concrete_funcs
   gc.collect()
@@ -271,6 +272,8 @@ def merged_bundle_to_tfl_model(
   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
+  del converter
+  gc.collect()
 
   if (
       quant_config is not None
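
The added del statements extend the pattern already applied to tf_module and tf_concrete_funcs: drop the last reference to a large intermediate as soon as it is dead, then gc.collect() so memory is reclaimed before the next allocation spike. A self-contained sketch of the idea (the buffer merely stands in for converter state, which can be very large for big models):

    import gc

    def convert():
      intermediate = bytearray(100 * 1024 * 1024)  # stand-in for converter state
      result = len(intermediate)                   # stand-in for converter.convert()
      # Release eagerly instead of waiting for function exit; collect() also
      # frees any reference cycles still holding onto the buffers.
      del intermediate
      gc.collect()
      return result

    print(convert())
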
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240930"
+__version__ = "0.3.0.dev20241003"
{ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240930
+Version: 0.3.0.dev20241003
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI