ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
@@ -15,28 +15,10 @@
15
15
 
16
16
  """Example of building Qwen 2.5 models."""
17
17
 
18
- import copy
19
-
20
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
21
18
  import ai_edge_torch.generative.layers.model_config as cfg
22
- import ai_edge_torch.generative.utilities.loader as loading_utils
23
- from torch import nn
24
-
25
- TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
26
- # Qwen re-uses the embedding as the head projection layer.
27
- TENSOR_NAMES.lm_head = None
28
-
29
-
30
- class Qwen(tiny_llama.TinyLlama):
31
- """A Qwen model built from the Edge Generative API layers.
32
-
33
- Qwen 2.5 shares the same architecture as TinyLlama.
34
- """
19
+ from ai_edge_torch.generative.utilities import model_builder
35
20
 
36
- def __init__(self, config: cfg.ModelConfig):
37
- super().__init__(config)
38
- # Qwen re-uses the embedding as the head projection layer.
39
- self.lm_head.weight.data = self.tok_embedding.weight.data
21
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
40
22
 
41
23
 
42
24
  def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
119
101
  return config
120
102
 
121
103
 
122
- def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
123
- model = Qwen(config)
124
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
125
- # Since embedding and lm-head use the same weight, we need to set strict
126
- # to False.
127
- loader.load(model, strict=False)
128
- model.eval()
129
- return model
130
-
131
-
132
- def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
133
- return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
104
+ def build_3b_model(
105
+ checkpoint_path: str, **kwargs
106
+ ) -> model_builder.DecoderOnlyModel:
107
+ return model_builder.build_decoder_only_model(
108
+ checkpoint_path=checkpoint_path,
109
+ config=get_3b_model_config(**kwargs),
110
+ tensor_names=TENSOR_NAMES,
111
+ )
134
112
 
135
113
 
136
- def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
137
- return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
114
+ def build_1_5b_model(
115
+ checkpoint_path: str, **kwargs
116
+ ) -> model_builder.DecoderOnlyModel:
117
+ return model_builder.build_decoder_only_model(
118
+ checkpoint_path=checkpoint_path,
119
+ config=get_1_5b_model_config(**kwargs),
120
+ tensor_names=TENSOR_NAMES,
121
+ )
138
122
 
139
123
 
140
- def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
141
- return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
124
+ def build_0_5b_model(
125
+ checkpoint_path: str, **kwargs
126
+ ) -> model_builder.DecoderOnlyModel:
127
+ return model_builder.build_decoder_only_model(
128
+ checkpoint_path=checkpoint_path,
129
+ config=get_0_5b_model_config(**kwargs),
130
+ tensor_names=TENSOR_NAMES,
131
+ )
@@ -15,29 +15,10 @@
15
15
 
16
16
  """Example of building a SmolLM model."""
17
17
 
18
- import copy
19
-
20
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
21
18
  import ai_edge_torch.generative.layers.model_config as cfg
22
- import ai_edge_torch.generative.utilities.loader as loading_utils
23
- from torch import nn
24
-
25
- TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
26
- # SmolLM re-uses the embedding as the head projection layer.
27
- TENSOR_NAMES.lm_head = None
28
-
29
-
30
- class SmolLM(tiny_llama.TinyLlama):
31
- """A SmolLM model built from the Edge Generative API layers.
19
+ from ai_edge_torch.generative.utilities import model_builder
32
20
 
33
- SmolLM shares the same architecture as TinyLlama, but with different model
34
- sizes.
35
- """
36
-
37
- def __init__(self, config: cfg.ModelConfig):
38
- super().__init__(config)
39
- # SmolLM re-uses the embedding as the head projection layer.
40
- self.lm_head.weight.data = self.tok_embedding.weight.data
21
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
41
22
 
42
23
 
43
24
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
91
72
  return config
92
73
 
93
74
 
94
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
95
- config = get_model_config(**kwargs)
96
- model = SmolLM(config)
97
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
98
- # Since embedding and lm-head use the same weight, we need to set strict
99
- # to False.
100
- loader.load(model, strict=False)
101
- model.eval()
102
- return model
75
+ def build_model(
76
+ checkpoint_path: str, **kwargs
77
+ ) -> model_builder.DecoderOnlyModel:
78
+ return model_builder.build_decoder_only_model(
79
+ checkpoint_path=checkpoint_path,
80
+ config=get_model_config(**kwargs),
81
+ tensor_names=TENSOR_NAMES,
82
+ )
@@ -15,102 +15,10 @@
15
15
 
16
16
  """Example of building a TinyLlama model."""
17
17
 
18
- from ai_edge_torch.generative.layers import attention
19
- from ai_edge_torch.generative.layers import builder
20
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
22
18
  import ai_edge_torch.generative.layers.model_config as cfg
23
- import ai_edge_torch.generative.utilities.loader as loading_utils
24
- import torch
25
- from torch import nn
19
+ from ai_edge_torch.generative.utilities import model_builder
26
20
 
27
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
28
- ff_up_proj="model.layers.{}.mlp.up_proj",
29
- ff_down_proj="model.layers.{}.mlp.down_proj",
30
- ff_gate_proj="model.layers.{}.mlp.gate_proj",
31
- attn_query_proj="model.layers.{}.self_attn.q_proj",
32
- attn_key_proj="model.layers.{}.self_attn.k_proj",
33
- attn_value_proj="model.layers.{}.self_attn.v_proj",
34
- attn_output_proj="model.layers.{}.self_attn.o_proj",
35
- pre_attn_norm="model.layers.{}.input_layernorm",
36
- post_attn_norm="model.layers.{}.post_attention_layernorm",
37
- embedding="model.embed_tokens",
38
- final_norm="model.norm",
39
- lm_head="lm_head",
40
- )
41
-
42
-
43
- class TinyLlama(nn.Module):
44
- """A TinyLlama model built from the Edge Generative API layers."""
45
-
46
- def __init__(self, config: cfg.ModelConfig):
47
- super().__init__()
48
-
49
- # Construct model layers.
50
- self.lm_head = nn.Linear(
51
- config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
52
- )
53
- self.tok_embedding = nn.Embedding(
54
- config.vocab_size, config.embedding_dim, padding_idx=0
55
- )
56
- # TinyLlama has only one block config.
57
- block_config = config.block_config(0)
58
- self.transformer_blocks = nn.ModuleList(
59
- attention.TransformerBlock(block_config, config)
60
- for _ in range(config.num_layers)
61
- )
62
- self.final_norm = builder.build_norm(
63
- config.embedding_dim,
64
- config.final_norm_config,
65
- )
66
- attn_config = block_config.attn_config
67
- self.rope_cache = attn_utils.build_rope_cache(
68
- size=config.kv_cache_max,
69
- dim=int(attn_config.rotary_percentage * attn_config.head_dim),
70
- base=attn_config.rotary_base,
71
- )
72
- self.mask_cache = attn_utils.build_causal_mask_cache(
73
- size=config.kv_cache_max,
74
- )
75
- self.config = config
76
-
77
- @torch.inference_mode
78
- def forward(
79
- self,
80
- tokens: torch.Tensor,
81
- input_pos: torch.Tensor,
82
- kv_cache: kv_utils.KVCache,
83
- ) -> dict[torch.Tensor, kv_utils.KVCache]:
84
- _, seq_len = tokens.size()
85
- assert self.config.max_seq_len >= seq_len, (
86
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
87
- f" {self.config.max_seq_len}"
88
- )
89
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
90
- "The number of transformer blocks and the number of KV cache entries"
91
- " must be the same."
92
- )
93
-
94
- cos, sin = self.rope_cache
95
- cos = cos.index_select(0, input_pos)
96
- sin = sin.index_select(0, input_pos)
97
- mask = self.mask_cache.index_select(2, input_pos)
98
- mask = mask[:, :, :, : self.config.kv_cache_max]
99
-
100
- # token embeddings of shape (b, t, n_embd)
101
- x = self.tok_embedding(tokens)
102
-
103
- updated_kv_entires = []
104
- for i, block in enumerate(self.transformer_blocks):
105
- kv_entry = kv_cache.caches[i] if kv_cache else None
106
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
107
- if kv_entry:
108
- updated_kv_entires.append(kv_entry)
109
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
110
-
111
- x = self.final_norm(x)
112
- logits = self.lm_head(x) # (b, t, vocab_size)
113
- return {"logits": logits, "kv_cache": updated_kv_cache}
21
+ TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
114
22
 
115
23
 
116
24
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
150
58
  kv_cache_max_len=kv_cache_max_len,
151
59
  block_configs=block_config,
152
60
  final_norm_config=norm_config,
61
+ lm_head_share_weight_with_embedding=False,
153
62
  enable_hlfb=True,
154
63
  )
155
64
  return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
164
73
  return config
165
74
 
166
75
 
167
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
168
- config = get_model_config(**kwargs)
169
- model = TinyLlama(config)
170
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
171
- loader.load(model)
172
- model.eval()
173
- return model
76
+ def build_model(
77
+ checkpoint_path: str, **kwargs
78
+ ) -> model_builder.DecoderOnlyModel:
79
+ return model_builder.build_decoder_only_model(
80
+ checkpoint_path=checkpoint_path,
81
+ config=get_model_config(**kwargs),
82
+ tensor_names=TENSOR_NAMES,
83
+ )
@@ -184,8 +184,14 @@ class ModelConfig:
184
184
  default_factory=NormalizationConfig
185
185
  )
186
186
 
187
+ # Scale factor of the embedding.
188
+ embedding_scale: Optional[float] = None
189
+
187
190
  # Use bias term within LLM's HEAD.
188
191
  lm_head_use_bias: bool = False
192
+ # Whether LLM's HEAD shares the weight of the embedding.
193
+ lm_head_share_weight_with_embedding: bool = True
194
+
189
195
  # Whether to turn on high-level function boundary.
190
196
  enable_hlfb: bool = False
191
197
 
@@ -19,6 +19,7 @@ import tempfile
19
19
 
20
20
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
21
21
  from ai_edge_torch.generative.utilities import loader as loading_utils
22
+ from ai_edge_torch.generative.utilities import model_builder
22
23
  import safetensors.torch
23
24
  import torch
24
25
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
71
72
  safetensors.torch.save_file(test_weights, file_path)
72
73
  cfg = tiny_llama.get_model_config()
73
74
  cfg.num_layers = 1
74
- model = tiny_llama.TinyLlama(cfg)
75
+ model = model_builder.DecoderOnlyModel(cfg)
75
76
 
76
77
  loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
77
78
  # if returns successfully, it means all the tensors were initiallized.
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
21
21
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
22
22
  from ai_edge_torch.generative.layers import kv_cache
23
23
  from ai_edge_torch.generative.test import utils as test_utils
24
+ from ai_edge_torch.generative.utilities import model_builder
24
25
  import numpy as np
25
26
  import torch
26
27
 
@@ -163,7 +164,7 @@ class TestModelConversion(googletest.TestCase):
163
164
  )
164
165
  def test_tiny_llama_multisig(self):
165
166
  config = tiny_llama.get_fake_model_config()
166
- pytorch_model = tiny_llama.TinyLlama(config).eval()
167
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
167
168
  self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
168
169
 
169
170
 
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
29
29
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
30
30
  from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
31
31
  from ai_edge_torch.generative.layers import kv_cache
32
+ from ai_edge_torch.generative.utilities import model_builder
32
33
  from ai_edge_torch.generative.test import utils as test_utils
33
34
  import numpy as np
34
35
  import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
90
91
  )
91
92
  def test_gemma1(self):
92
93
  config = gemma1.get_fake_model_config()
93
- pytorch_model = gemma1.Gemma(config).eval()
94
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
94
95
  self._test_model(
95
96
  config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
96
97
  )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
119
120
  )
120
121
  def test_phi2(self):
121
122
  config = phi2.get_fake_model_config()
122
- pytorch_model = phi2.Phi2(config).eval()
123
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
123
124
  self._test_model(
124
125
  config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
125
126
  )
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
139
140
  )
140
141
  def test_smollm(self):
141
142
  config = smollm.get_fake_model_config()
142
- pytorch_model = smollm.SmolLM(config).eval()
143
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
143
144
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
144
145
 
145
146
  @googletest.skipIf(
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
148
149
  )
149
150
  def test_openelm(self):
150
151
  config = openelm.get_fake_model_config()
151
- pytorch_model = openelm.OpenELM(config).eval()
152
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
152
153
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
153
154
 
154
155
  @googletest.skipIf(
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
157
158
  )
158
159
  def test_qwen(self):
159
160
  config = qwen.get_fake_model_config()
160
- pytorch_model = qwen.Qwen(config).eval()
161
+ pytorch_model = model_builder.DecoderOnlyModel(config).eval()
161
162
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
162
163
 
163
164
  @googletest.skipIf(
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities to be used for re-authoring transformer models."""
17
+
18
+ import copy
19
+
20
+ from ai_edge_torch.generative.layers import attention
21
+ from ai_edge_torch.generative.layers import builder
22
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
+ import ai_edge_torch.generative.layers.model_config as cfg
25
+ import ai_edge_torch.generative.utilities.loader as loading_utils
26
+ import torch
27
+ from torch import nn
28
+
29
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
30
+ ff_up_proj="model.layers.{}.mlp.up_proj",
31
+ ff_down_proj="model.layers.{}.mlp.down_proj",
32
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
33
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
34
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
35
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
36
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
37
+ pre_attn_norm="model.layers.{}.input_layernorm",
38
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
39
+ embedding="model.embed_tokens",
40
+ final_norm="model.norm",
41
+ )
42
+
43
+ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
44
+ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
45
+
46
+
47
+ class DecoderOnlyModel(nn.Module):
48
+ """A simple decoder-only transformer model built from the Edge Generative API.
49
+
50
+ This model is used for re-authoring. model_config is used to specify the
51
+ details of model architecture and parameters.
52
+
53
+ It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
54
+ and rotary_percentage are the same for all layers.
55
+ """
56
+
57
+ def __init__(self, config: cfg.ModelConfig):
58
+ super().__init__()
59
+
60
+ # Construct model layers.
61
+ self.tok_embedding = nn.Embedding(
62
+ config.vocab_size, config.embedding_dim, padding_idx=0
63
+ )
64
+ self.lm_head = nn.Linear(
65
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
66
+ )
67
+ if config.lm_head_share_weight_with_embedding:
68
+ self.lm_head.weight.data = self.tok_embedding.weight.data
69
+ self.transformer_blocks = nn.ModuleList(
70
+ attention.TransformerBlock(config.block_config(idx), config)
71
+ for idx in range(config.num_layers)
72
+ )
73
+ self.final_norm = builder.build_norm(
74
+ config.embedding_dim,
75
+ config.final_norm_config,
76
+ )
77
+ # ROPE parameters for all attn_configs are the same. Take the first one.
78
+ attn_config = config.block_config(0).attn_config
79
+ self.rope_cache = attn_utils.build_rope_cache(
80
+ size=config.kv_cache_max,
81
+ dim=int(attn_config.rotary_percentage * attn_config.head_dim),
82
+ base=attn_config.rotary_base,
83
+ )
84
+ self.mask_cache = attn_utils.build_causal_mask_cache(
85
+ size=config.kv_cache_max,
86
+ )
87
+ self.config = config
88
+
89
+ @torch.inference_mode
90
+ def forward(
91
+ self,
92
+ tokens: torch.Tensor,
93
+ input_pos: torch.Tensor,
94
+ kv_cache: kv_utils.KVCache,
95
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
96
+ _, seq_len = tokens.size()
97
+ assert self.config.max_seq_len >= seq_len, (
98
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
99
+ f" {self.config.max_seq_len}"
100
+ )
101
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
102
+ "The number of transformer blocks and the number of KV cache entries"
103
+ " must be the same."
104
+ )
105
+
106
+ cos, sin = self.rope_cache
107
+ cos = cos.index_select(0, input_pos)
108
+ sin = sin.index_select(0, input_pos)
109
+ mask = self.mask_cache.index_select(2, input_pos)
110
+ mask = mask[:, :, :, : self.config.kv_cache_max]
111
+
112
+ # token embeddings of shape (b, t, n_embd)
113
+ x = self.tok_embedding(tokens)
114
+ if self.config.embedding_scale is not None:
115
+ x = x * self.config.embedding_scale
116
+
117
+ updated_kv_entires = []
118
+ for i, block in enumerate(self.transformer_blocks):
119
+ kv_entry = kv_cache.caches[i] if kv_cache else None
120
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
121
+ if kv_entry:
122
+ updated_kv_entires.append(kv_entry)
123
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
124
+
125
+ x = self.final_norm(x)
126
+ logits = self.lm_head(x) # (b, t, vocab_size)
127
+ return {"logits": logits, "kv_cache": updated_kv_cache}
128
+
129
+
130
+ def build_decoder_only_model(
131
+ checkpoint_path: str,
132
+ config: cfg.ModelConfig,
133
+ tensor_names: loading_utils.ModelLoader.TensorNames,
134
+ ) -> DecoderOnlyModel:
135
+ transformer = DecoderOnlyModel(config)
136
+ loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
137
+ loader.load(
138
+ transformer, strict=not config.lm_head_share_weight_with_embedding
139
+ )
140
+ transformer.eval()
141
+ return transformer
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241002"
16
+ __version__ = "0.3.0.dev20241003"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241002
3
+ Version: 0.3.0.dev20241003
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=ODx8CRsxZZYlliSx6vnHxxTorI9c0WPgrVvwGY5KAQI,706
6
+ ai_edge_torch/version.py,sha256=WKaZCocAyLb42oFdC07BQ6qpSfohXBwt-HKGV7S2fXw,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,35 +41,33 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
41
41
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
43
43
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
44
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=kxWmmoVvtLP5auB3UXA2vsvZmSnpBs4SBixzYeAXzVA,6255
45
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=7VF5RYJ8QhROQNIlx-QovO-y6-jFp_EHgAkBNChZaqE,9066
44
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
45
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
46
46
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
47
47
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
48
48
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
49
49
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
51
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
52
- ai_edge_torch/generative/examples/llama/llama.py,sha256=5vlh2Z8vEPH8Z4LoHoFYCcuOQynx4mbVE37v3yMl1hE,7162
53
- ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
54
- ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
50
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
51
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
52
+ ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
55
53
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
56
54
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
57
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hxbpvk0fNswzbqZfGteflqKMmkH7yzeMuW6r29s_xnQ,7374
55
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
58
56
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
59
57
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
60
58
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
61
59
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
62
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_XaXWLi4Zyw9Vtmj0,6075
63
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
60
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
61
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
64
62
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
63
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
66
64
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
65
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
68
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=b03q1On6JzPhJzTs1dQwT_tJjO7C9NYmyzrzV2kQ_yo,4579
66
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
69
67
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
70
68
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
71
69
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
72
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
70
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
73
71
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
74
72
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
73
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -96,7 +94,7 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYo
96
94
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
97
95
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
98
96
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
99
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=aSNHOAar5yPnGAeKsv8zrqYhOq9RR_7hwqHUMBb2mkM,5930
97
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
100
98
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
101
99
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
102
100
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
@@ -106,7 +104,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
106
104
  ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
107
105
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
108
106
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
109
- ai_edge_torch/generative/layers/model_config.py,sha256=Fa0eFCMlyfdwd3cM1drhP9vlXRhIguDrglsHn4ax2_w,6948
107
+ ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
110
108
  ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
111
109
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
112
110
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -123,14 +121,15 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
123
121
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
124
122
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
125
123
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
126
- ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
127
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
128
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=ASXTeO9TxjhqcNwXwbyMUP07aqye7wD6JU6OGZCEmR4,8907
124
+ ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
125
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=-qB-JEIfPFNlpGyJA1TYo_5fawTdyf1C6ee8cP4kYOY,5530
126
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
129
127
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
130
128
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
131
129
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
132
130
  ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
133
131
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
132
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
134
133
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
135
134
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
136
135
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -181,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
181
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
182
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
183
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
184
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=l2x0NhvSM0VtobvX6i8hXWKYdfjaRUizk42xaJrQXtw,1897
186
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA,sha256=a6Q1LozCx-4NWkm1EKZJFeCJTYiTNUSigoVwRevV0oc,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,