ai_edge_torch_nightly-0.3.0.dev20241002-py3-none-any.whl → ai_edge_torch_nightly-0.3.0.dev20241004-py3-none-any.whl

Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +39 -17
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -15,28 +15,10 @@
 
  """Example of building Qwen 2.5 models."""
 
- import copy
-
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- from torch import nn
-
- TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
- # Qwen re-uses the embedding as the head projection layer.
- TENSOR_NAMES.lm_head = None
-
-
- class Qwen(tiny_llama.TinyLlama):
-   """A Qwen model built from the Edge Generative API layers.
-
-   Qwen 2.5 shares the same architecture as TinyLlama.
-   """
+ from ai_edge_torch.generative.utilities import model_builder
 
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__(config)
-     # Qwen re-uses the embedding as the head projection layer.
-     self.lm_head.weight.data = self.tok_embedding.weight.data
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
  def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
    return config
 
 
- def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
-   model = Qwen(config)
-   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   # Since embedding and lm-head use the same weight, we need to set strict
-   # to False.
-   loader.load(model, strict=False)
-   model.eval()
-   return model
-
-
- def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+ def build_3b_model(
+     checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_3b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+   )
 
 
- def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+ def build_1_5b_model(
+     checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_1_5b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+   )
 
 
- def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
+ def build_0_5b_model(
+     checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_0_5b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+   )
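
All three Qwen builders now delegate to the shared utility instead of the per-model `_build_model` helper. A minimal sketch of the new call path, assuming a locally downloaded checkpoint (the path and keyword value below are illustrative placeholders, not part of this diff):

    from ai_edge_torch.generative.examples.qwen import qwen

    # Hypothetical checkpoint directory; kv_cache_max_len is forwarded to
    # get_1_5b_model_config(**kwargs) and ends up in the ModelConfig.
    model = qwen.build_1_5b_model("/tmp/qwen_2.5_1.5b", kv_cache_max_len=1024)
    # build_decoder_only_model loads the weights and returns the model in
    # eval() mode, ready for conversion or verification.
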
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -15,29 +15,10 @@
 
  """Example of building a SmolLM model."""
 
- import copy
-
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- from torch import nn
-
- TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
- # SmolLM re-uses the embedding as the head projection layer.
- TENSOR_NAMES.lm_head = None
-
-
- class SmolLM(tiny_llama.TinyLlama):
-   """A SmolLM model built from the Edge Generative API layers.
+ from ai_edge_torch.generative.utilities import model_builder
 
-   SmolLM shares the same architecture as TinyLlama, but with different model
-   sizes.
-   """
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__(config)
-     # SmolLM re-uses the embedding as the head projection layer.
-     self.lm_head.weight.data = self.tok_embedding.weight.data
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
    return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   config = get_model_config(**kwargs)
-   model = SmolLM(config)
-   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   # Since embedding and lm-head use the same weight, we need to set strict
-   # to False.
-   loader.load(model, strict=False)
-   model.eval()
-   return model
+ def build_model(
+     checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+   )
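
Both removed `__init__` bodies performed the same weight tying that `DecoderOnlyModel` (added below in model_builder.py) now applies behind a config flag. The tying is also why the old loaders passed `strict=False`: a tied checkpoint stores no separate lm_head tensor. A self-contained sketch of the pattern in plain PyTorch, with hypothetical sizes:

    import torch
    from torch import nn

    tok_embedding = nn.Embedding(49152, 576)  # illustrative vocab size / width
    lm_head = nn.Linear(576, 49152, bias=False)
    # Reuse the embedding matrix as the output projection, as the removed
    # SmolLM/Qwen classes did. A checkpoint then only stores the embedding
    # weight, so loading it into this module must not be strict.
    lm_head.weight.data = tok_embedding.weight.data
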
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,102 +15,10 @@
 
  """Example of building a TinyLlama model."""
 
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
+ from ai_edge_torch.generative.utilities import model_builder
 
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     embedding="model.embed_tokens",
-     final_norm="model.norm",
-     lm_head="lm_head",
- )
-
-
- class TinyLlama(nn.Module):
-   """A TinyLlama model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__()
-
-     # Construct model layers.
-     self.lm_head = nn.Linear(
-         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-     )
-     self.tok_embedding = nn.Embedding(
-         config.vocab_size, config.embedding_dim, padding_idx=0
-     )
-     # TinyLlama has only one block config.
-     block_config = config.block_config(0)
-     self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(block_config, config)
-         for _ in range(config.num_layers)
-     )
-     self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
-     )
-     attn_config = block_config.attn_config
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-         base=attn_config.rotary_base,
-     )
-     self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max,
-     )
-     self.config = config
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.KVCache,
-   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-     _, seq_len = tokens.size()
-     assert self.config.max_seq_len >= seq_len, (
-         f"Cannot forward sequence of length {seq_len}, max seq length is only"
-         f" {self.config.max_seq_len}"
-     )
-     assert len(self.transformer_blocks) == len(kv_cache.caches), (
-         "The number of transformer blocks and the number of KV cache entries"
-         " must be the same."
-     )
-
-     cos, sin = self.rope_cache
-     cos = cos.index_select(0, input_pos)
-     sin = sin.index_select(0, input_pos)
-     mask = self.mask_cache.index_select(2, input_pos)
-     mask = mask[:, :, :, : self.config.kv_cache_max]
-
-     # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(tokens)
-
-     updated_kv_entires = []
-     for i, block in enumerate(self.transformer_blocks):
-       kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-       if kv_entry:
-         updated_kv_entires.append(kv_entry)
-     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-     x = self.final_norm(x)
-     logits = self.lm_head(x)  # (b, t, vocab_size)
-     return {"logits": logits, "kv_cache": updated_kv_cache}
+ TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
        kv_cache_max_len=kv_cache_max_len,
        block_configs=block_config,
        final_norm_config=norm_config,
+       lm_head_share_weight_with_embedding=False,
        enable_hlfb=True,
    )
    return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
    return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   config = get_model_config(**kwargs)
-   model = TinyLlama(config)
-   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   loader.load(model)
-   model.eval()
-   return model
+ def build_model(
+     checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+   )
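
Unlike Qwen and SmolLM, TinyLlama keeps a separate lm_head tensor in its checkpoints, so it picks the second name mapping exported by the new utility, and its config turns weight sharing off. A short sketch of the distinction (assuming unset `TensorNames` fields default to None, which is what the shared-head mapping relies on):

    from ai_edge_torch.generative.utilities import model_builder

    # Shared-head mapping: no lm_head entry, so the loader skips that tensor
    # and DecoderOnlyModel ties lm_head to the embedding instead.
    shared = model_builder.TENSOR_NAMES
    # Separate-head mapping: identical, except lm_head loads from "lm_head".
    separate = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
    assert separate.lm_head == "lm_head"
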
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -184,8 +184,14 @@ class ModelConfig:
        default_factory=NormalizationConfig
    )
 
+   # Scale factor applied to the token embedding at the start of forward().
+   embedding_scale: Optional[float] = None
+
    # Use bias term within LLM's HEAD.
    lm_head_use_bias: bool = False
+   # Whether the LM head shares its weight with the token embedding.
+   lm_head_share_weight_with_embedding: bool = True
+
    # Whether to turn on high-level function boundary.
    enable_hlfb: bool = False
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -19,6 +19,7 @@ import tempfile
 
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  from ai_edge_torch.generative.utilities import loader as loading_utils
+ from ai_edge_torch.generative.utilities import model_builder
  import safetensors.torch
  import torch
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
      safetensors.torch.save_file(test_weights, file_path)
      cfg = tiny_llama.get_model_config()
      cfg.num_layers = 1
-     model = tiny_llama.TinyLlama(cfg)
+     model = model_builder.DecoderOnlyModel(cfg)
 
      loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
      # If this returns successfully, all the tensors were initialized.

ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
+ from ai_edge_torch.generative.utilities import model_builder
  import numpy as np
  import torch
 
@@ -42,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
          )
      )
 
-   def _test_model_with_kv_cache(self, config, pytorch_model):
+   def _get_params(self, enable_hlfb: bool):
+     """Returns a model, edge model and the kwargs to use for testing."""
+     config = toy_model_with_kv_cache.get_model_config()
+     config.enable_hlfb = enable_hlfb
+     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
      tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
          [10], dtype=torch.int
      )
      kv = kv_cache.KVCache.from_model_config(config)
+     kwargs = {
+         "tokens": tokens,
+         "input_pos": input_pos,
+         "kv_cache": kv,
+     }
 
      edge_model = ai_edge_torch.convert(
          pytorch_model,
-         sample_kwargs={
-             "tokens": tokens,
-             "input_pos": input_pos,
-             "kv_cache": kv,
-         },
+         sample_kwargs=kwargs,
      )
      edge_model.set_interpreter_builder(
          self._interpreter_builder(edge_model.tflite_model())
      )
+     return pytorch_model, edge_model, kwargs
+
+   def _test_model_with_kv_cache(self, enable_hlfb: bool):
+     pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
 
      self.assertTrue(
          test_utils.compare_tflite_torch(
              edge_model,
              pytorch_model,
-             tokens,
-             input_pos,
-             kv,
+             kwargs["tokens"],
+             kwargs["input_pos"],
+             kwargs["kv_cache"],
              signature_name="serving_default",
              atol=1e-5,
              rtol=1e-5,
@@ -78,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
        reason="tests with custom ops are not supported on oss",
    )
    def test_toy_model_with_kv_cache(self):
-     config = toy_model_with_kv_cache.get_model_config()
-     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-     self._test_model_with_kv_cache(config, pytorch_model)
+     self._test_model_with_kv_cache(enable_hlfb=False)
 
    @googletest.skipIf(
        ai_edge_config.Config.use_torch_xla,
        reason="tests with custom ops are not supported on oss",
    )
    def test_toy_model_with_kv_cache_with_hlfb(self):
-     config = toy_model_with_kv_cache.get_model_config()
-     config.enable_hlfb = True
-     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-     self._test_model_with_kv_cache(config, pytorch_model)
+     self._test_model_with_kv_cache(enable_hlfb=True)
+
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_toy_model_has_ekv_op(self):
+     """Tests that the model has the external kv cache op."""
+     _, edge_model, _ = self._get_params(enable_hlfb=True)
+     interpreter_ = interpreter.InterpreterWithCustomOps(
+         custom_op_registerers=["GenAIOpsRegisterer"],
+         model_content=edge_model.tflite_model(),
+         experimental_default_delegate_latest_features=True,
+     )
+
+     # pylint: disable=protected-access
+     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+     self.assertIn("odml.update_external_kv_cache", op_names)
 
    def _test_multisig_model(self, config, pytorch_model, atol, rtol):
      # prefill
@@ -163,7 +185,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_tiny_llama_multisig(self):
      config = tiny_llama.get_fake_model_config()
-     pytorch_model = tiny_llama.TinyLlama(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 

ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
  from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
  from ai_edge_torch.generative.layers import kv_cache
+ from ai_edge_torch.generative.utilities import model_builder
  from ai_edge_torch.generative.test import utils as test_utils
  import numpy as np
  import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_gemma1(self):
      config = gemma1.get_fake_model_config()
-     pytorch_model = gemma1.Gemma(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_model(
          config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
      )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_phi2(self):
      config = phi2.get_fake_model_config()
-     pytorch_model = phi2.Phi2(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_model(
          config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
      )
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_smollm(self):
      config = smollm.get_fake_model_config()
-     pytorch_model = smollm.SmolLM(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
    @googletest.skipIf(
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_openelm(self):
      config = openelm.get_fake_model_config()
-     pytorch_model = openelm.OpenELM(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
    @googletest.skipIf(
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
    )
    def test_qwen(self):
      config = qwen.get_fake_model_config()
-     pytorch_model = qwen.Qwen(config).eval()
+     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
    @googletest.skipIf(

ai_edge_torch/generative/utilities/model_builder.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utilities to be used for re-authoring transformer models."""
+
+ import copy
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+ )
+
+ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
+ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
+
+
+ class DecoderOnlyModel(nn.Module):
+   """A simple decoder-only transformer model built from the Edge Generative API.
+
+   This model is intended for re-authoring; model_config specifies the details
+   of the model architecture and parameters.
+
+   It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
+   and rotary_percentage, are the same for all layers.
+   """
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     # Construct model layers.
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     if config.lm_head_share_weight_with_embedding:
+       self.lm_head.weight.data = self.tok_embedding.weight.data
+     self.transformer_blocks = nn.ModuleList(
+         attention.TransformerBlock(config.block_config(idx), config)
+         for idx in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     # ROPE parameters for all attn_configs are the same. Take the first one.
+     attn_config = config.block_config(0).attn_config
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+         base=attn_config.rotary_base,
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[str, torch.Tensor | kv_utils.KVCache]:
+     _, seq_len = tokens.size()
+     assert self.config.max_seq_len >= seq_len, (
+         f"Cannot forward sequence of length {seq_len}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     # Token embeddings of shape (b, t, n_embd).
+     x = self.tok_embedding(tokens)
+     if self.config.embedding_scale is not None:
+       x = x * self.config.embedding_scale
+
+     updated_kv_entries = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+     x = self.final_norm(x)
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+ def build_decoder_only_model(
+     checkpoint_path: str,
+     config: cfg.ModelConfig,
+     tensor_names: loading_utils.ModelLoader.TensorNames,
+ ) -> DecoderOnlyModel:
+   transformer = DecoderOnlyModel(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+   # When the LM head shares the embedding weight, the checkpoint has no
+   # separate lm_head tensor, so loading must not be strict.
+   loader.load(
+       transformer, strict=not config.lm_head_share_weight_with_embedding
+   )
+   transformer.eval()
+   return transformer
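
The conversion tests above exercise the new class directly; condensed from that test code, a sketch of running a forward pass with a fake config and a fresh KV cache:

    import torch
    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.layers import kv_cache
    from ai_edge_torch.generative.utilities import model_builder

    config = tiny_llama.get_fake_model_config()
    model = model_builder.DecoderOnlyModel(config).eval()
    kv = kv_cache.KVCache.from_model_config(config)
    tokens = torch.tensor([[1]], dtype=torch.int)
    input_pos = torch.tensor([0], dtype=torch.int)
    # forward() returns a dict with the logits and the updated KV cache.
    output = model(tokens=tokens, input_pos=input_pos, kv_cache=kv)
    logits, kv = output["logits"], output["kv_cache"]
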
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20241002"
+ __version__ = "0.3.0.dev20241004"

{ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241002
+ Version: 0.3.0.dev20241004
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI