ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -15,28 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# Qwen re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class Qwen(tiny_llama.TinyLlama):
-  """A Qwen model built from the Edge Generative API layers.
-
-  Qwen 2.5 shares the same architecture as TinyLlama.
-  """
+from ai_edge_torch.generative.utilities import model_builder
 
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # Qwen re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
-  model = Qwen(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_3b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+def build_1_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
+def build_0_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
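Note: all three Qwen builders now delegate to the shared model_builder helper. A minimal usage sketch follows (the checkpoint path is a placeholder, not something shipped with the wheel):

    # Hedged sketch: building the re-authored Qwen 2.5 0.5B model.
    from ai_edge_torch.generative.examples.qwen import qwen

    # Returns a model_builder.DecoderOnlyModel, already set to eval() mode.
    model = qwen.build_0_5b_model("/path/to/qwen_checkpoint", kv_cache_max_len=1024)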
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -15,29 +15,10 @@
 
 """Example of building a SmolLM model."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class SmolLM(tiny_llama.TinyLlama):
-  """A SmolLM model built from the Edge Generative API layers.
+from ai_edge_torch.generative.utilities import model_builder
 
-  SmolLM shares the same architecture as TinyLlama, but with different model
-  sizes.
-  """
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # SmolLM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = SmolLM(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,102 +15,10 @@
 
 """Example of building a TinyLlama model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
+from ai_edge_torch.generative.utilities import model_builder
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head="lm_head",
-)
-
-
-class TinyLlama(nn.Module):
-  """A TinyLlama model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # TinyLlama has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = TinyLlama(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -184,8 +184,14 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
 
+  # Scale factor of the embedding.
+  embedding_scale: Optional[float] = None
+
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
+  # Whether LLM's HEAD shares the weight of the embedding.
+  lm_head_share_weight_with_embedding: bool = True
+
   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = False
 
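Both new ModelConfig fields are consumed by the new DecoderOnlyModel (see model_builder.py below): embedding_scale, when set, multiplies the token embeddings before the first transformer block, and lm_head_share_weight_with_embedding ties the output projection to the embedding table. A self-contained sketch of that weight tying in plain PyTorch (illustrative only, not the ai_edge_torch API):

    import torch
    from torch import nn

    emb = nn.Embedding(100, 16)            # (vocab_size, embedding_dim)
    head = nn.Linear(16, 100, bias=False)  # weight shape (vocab_size, embedding_dim)
    head.weight.data = emb.weight.data     # head now shares the embedding's storage
    assert head.weight.data_ptr() == emb.weight.data_ptr()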
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -19,6 +19,7 @@ import tempfile
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+from ai_edge_torch.generative.utilities import model_builder
 import safetensors.torch
 import torch
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLlama(cfg)
+    model = model_builder.DecoderOnlyModel(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -163,7 +164,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import model_builder
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = gemma1.Gemma(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = phi2.Phi2(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = smollm.SmolLM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = openelm.OpenELM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = qwen.Qwen(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
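Each of these test updates follows the same pattern: build the example's fake config, then instantiate the shared DecoderOnlyModel directly instead of the removed per-example class. A hedged sketch of that pattern:

    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import model_builder

    config = smollm.get_fake_model_config()
    pytorch_model = model_builder.DecoderOnlyModel(config).eval()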
ai_edge_torch/generative/utilities/model_builder.py ADDED
@@ -0,0 +1,141 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to be used for re-authoring transformer models."""
+
+import copy
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
+
+
+class DecoderOnlyModel(nn.Module):
+  """A simple decoder-only transformer model built from the Edge Generative API.
+
+  This model is used for re-authoring. model_config is used to specify the
+  details of model architecture and parameters.
+
+  It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
+  and rotary_percentage are the same for all layers.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    if config.lm_head_share_weight_with_embedding:
+      self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    if self.config.embedding_scale is not None:
+      x = x * self.config.embedding_scale
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def build_decoder_only_model(
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    tensor_names: loading_utils.ModelLoader.TensorNames,
+) -> DecoderOnlyModel:
+  transformer = DecoderOnlyModel(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+  loader.load(
+      transformer, strict=not config.lm_head_share_weight_with_embedding
+  )
+  transformer.eval()
+  return transformer
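A usage sketch of the new module, mirroring the refactored example files above (the checkpoint path is a placeholder; TinyLlama keeps a separate lm_head, so it uses the *_WITH_SEPARATE_LM_HEAD tensor names):

    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.utilities import model_builder

    model = model_builder.build_decoder_only_model(
        checkpoint_path="/path/to/tiny_llama_checkpoint",  # placeholder
        config=tiny_llama.get_model_config(kv_cache_max_len=1024),
        tensor_names=model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD,
    )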
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241002"
+__version__ = "0.3.0.dev20241003"
{ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241002
+Version: 0.3.0.dev20241003
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=ODx8CRsxZZYlliSx6vnHxxTorI9c0WPgrVvwGY5KAQI,706
+ai_edge_torch/version.py,sha256=WKaZCocAyLb42oFdC07BQ6qpSfohXBwt-HKGV7S2fXw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,35 +41,33 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=kxWmmoVvtLP5auB3UXA2vsvZmSnpBs4SBixzYeAXzVA,6255
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=7VF5RYJ8QhROQNIlx-QovO-y6-jFp_EHgAkBNChZaqE,9066
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
-ai_edge_torch/generative/examples/llama/llama.py,sha256=5vlh2Z8vEPH8Z4LoHoFYCcuOQynx4mbVE37v3yMl1hE,7162
-ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
-ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
+ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
+ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hxbpvk0fNswzbqZfGteflqKMmkH7yzeMuW6r29s_xnQ,7374
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_XaXWLi4Zyw9Vtmj0,6075
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
-ai_edge_torch/generative/examples/qwen/qwen.py,sha256=b03q1On6JzPhJzTs1dQwT_tJjO7C9NYmyzrzV2kQ_yo,4579
+ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -96,7 +94,7 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYo
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=aSNHOAar5yPnGAeKsv8zrqYhOq9RR_7hwqHUMBb2mkM,5930
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
@@ -106,7 +104,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
 ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=Fa0eFCMlyfdwd3cM1drhP9vlXRhIguDrglsHn4ax2_w,6948
+ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
 ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -123,14 +121,15 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
-ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=ASXTeO9TxjhqcNwXwbyMUP07aqye7wD6JU6OGZCEmR4,8907
+ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=-qB-JEIfPFNlpGyJA1TYo_5fawTdyf1C6ee8cP4kYOY,5530
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
+ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -181,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=l2x0NhvSM0VtobvX6i8hXWKYdfjaRUizk42xaJrQXtw,1897
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA,sha256=a6Q1LozCx-4NWkm1EKZJFeCJTYiTNUSigoVwRevV0oc,1897
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,