ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
- ai_edge_torch/generative/examples/llama/llama.py +19 -24
- ai_edge_torch/generative/examples/llama/verify.py +18 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
- ai_edge_torch/generative/examples/phi/phi2.py +10 -86
- ai_edge_torch/generative/examples/phi/phi3.py +9 -69
- ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
- ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/test/test_loader.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
- ai_edge_torch/generative/utilities/model_builder.py +141 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
- ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
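Most of this release is a refactor: the decoder-only example models (Gemma 1, Llama, OpenELM, Phi, Qwen, SmolLM, TinyLlama) move onto a new shared utility, ai_edge_torch/generative/utilities/model_builder.py, which now owns the common DecoderOnlyModel implementation and checkpoint-loading logic that each example previously duplicated. The hunks below show the pattern.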
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -15,28 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# Qwen re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class Qwen(tiny_llama.TinyLlama):
-  """A Qwen model built from the Edge Generative API layers.
-
-  Qwen 2.5 shares the same architecture as TinyLlama.
-  """
+from ai_edge_torch.generative.utilities import model_builder
 
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # Qwen re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
-  ...
-
-
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_3b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+def build_1_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
+def build_0_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
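For orientation, here is a minimal sketch of how the refactored Qwen entry points are called after this change; the checkpoint path is a hypothetical placeholder, and the keyword argument is forwarded to get_1_5b_model_config():

```python
# A usage sketch, not part of the diff; the checkpoint path is assumed.
from ai_edge_torch.generative.examples.qwen import qwen

model = qwen.build_1_5b_model(
    "/path/to/qwen2.5_1.5b_checkpoint",  # hypothetical local checkpoint
    kv_cache_max_len=1024,
)
# build_decoder_only_model() returns the model already in eval() mode.
```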
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -15,29 +15,10 @@
 
 """Example of building a SmolLM model."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class SmolLM(tiny_llama.TinyLlama):
-  """A SmolLM model built from the Edge Generative API layers.
+from ai_edge_torch.generative.utilities import model_builder
 
-  SmolLM shares the same architecture as TinyLlama, but with different model
-  sizes.
-  """
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # SmolLM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  ...
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,102 +15,10 @@
 
 """Example of building a TinyLlama model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
+from ai_edge_torch.generative.utilities import model_builder
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head="lm_head",
-)
-
-
-class TinyLlama(nn.Module):
-  """A TinyLlama model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # TinyLlama has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  ...
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
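Unlike Qwen and SmolLM, TinyLlama keeps a distinct lm_head tensor in its checkpoint, so it opts out of weight sharing (lm_head_share_weight_with_embedding=False) and maps the head explicitly via TENSOR_NAMES_WITH_SEPARATE_LM_HEAD. A small sketch of the contrast between the two presets defined in model_builder.py (shown later in this diff); treating the unmapped lm_head field as None is an assumption about the TensorNames default:

```python
# Sketch only; mirrors the preset definitions in model_builder.py.
from ai_edge_torch.generative.utilities import model_builder

tied = model_builder.TENSOR_NAMES  # no lm_head mapping: head reuses embedding
separate = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD

assert separate.lm_head == "lm_head"
assert tied.lm_head is None  # assumed TensorNames field default
```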
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -184,8 +184,14 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
 
+  # Scale factor of the embedding.
+  embedding_scale: Optional[float] = None
+
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
+  # Whether LLM's HEAD shares the weight of the embedding.
+  lm_head_share_weight_with_embedding: bool = True
+
   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = False
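Both new fields are consumed by the DecoderOnlyModel added later in this diff: embedding_scale multiplies the token embeddings in forward(), and lm_head_share_weight_with_embedding controls the embedding/head weight tie in __init__() as well as the strict flag used when loading checkpoints. A hedged sketch of setting them on an existing config; the sqrt(embedding_dim) scale is an illustrative choice, not something this diff applies to SmolLM:

```python
# Illustrative only: mutate one of the example configs, as the tests do.
from ai_edge_torch.generative.examples.smollm import smollm

config = smollm.get_fake_model_config()
config.embedding_scale = config.embedding_dim**0.5  # assumed example value
config.lm_head_share_weight_with_embedding = False  # keep a separate head
```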
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -19,6 +19,7 @@ import tempfile
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+from ai_edge_torch.generative.utilities import model_builder
 import safetensors.torch
 import torch
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLlama(cfg)
+    model = model_builder.DecoderOnlyModel(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -163,7 +164,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import model_builder
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = gemma1.Gemma(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = phi2.Phi2(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = smollm.SmolLM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = openelm.OpenELM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = qwen.Qwen(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
ai_edge_torch/generative/utilities/model_builder.py ADDED
@@ -0,0 +1,141 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to be used for re-authoring transformer models."""
+
+import copy
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
+TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
+
+
+class DecoderOnlyModel(nn.Module):
+  """A simple decoder-only transformer model built from the Edge Generative API.
+
+  This model is used for re-authoring. model_config is used to specify the
+  details of model architecture and parameters.
+
+  It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
+  and rotary_percentage are the same for all layers.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    if config.lm_head_share_weight_with_embedding:
+      self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    if self.config.embedding_scale is not None:
+      x = x * self.config.embedding_scale
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def build_decoder_only_model(
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    tensor_names: loading_utils.ModelLoader.TensorNames,
+) -> DecoderOnlyModel:
+  transformer = DecoderOnlyModel(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+  loader.load(
+      transformer, strict=not config.lm_head_share_weight_with_embedding
+  )
+  transformer.eval()
+  return transformer
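To make the new module's calling convention concrete, here is a minimal sketch of running a DecoderOnlyModel forward pass; the fake SmolLM config and the KVCache.from_model_config() helper are assumptions drawn from the test files above, and the shapes are illustrative:

```python
# A usage sketch under the assumptions stated above.
import torch

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities import model_builder

config = smollm.get_fake_model_config()
model = model_builder.DecoderOnlyModel(config).eval()

tokens = torch.zeros((1, 8), dtype=torch.int)    # (batch, seq_len)
input_pos = torch.arange(0, 8, dtype=torch.int)  # position of each token
kv = kv_utils.KVCache.from_model_config(config)  # assumed cache constructor
output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
```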
ai_edge_torch/version.py CHANGED
-__version__ = "0.3.0.dev20241002"
+__version__ = "0.3.0.dev20241003"

{ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241002
+Version: 0.3.0.dev20241003
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=…
+ai_edge_torch/version.py,sha256=WKaZCocAyLb42oFdC07BQ6qpSfohXBwt-HKGV7S2fXw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,35 +41,33 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=…
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=…
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=…
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=…
-ai_edge_torch/generative/examples/llama/llama.py,sha256=…
-ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
-ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
+ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
+ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=…
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=…
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=…
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
-ai_edge_torch/generative/examples/qwen/qwen.py,sha256=…
+ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=…
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -96,7 +94,7 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYo
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=…
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
@@ -106,7 +104,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
 ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=…
+ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
 ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -123,14 +121,15 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
-ai_edge_torch/generative/test/test_loader.py,sha256=…
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=…
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=…
+ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=-qB-JEIfPFNlpGyJA1TYo_5fawTdyf1C6ee8cP4kYOY,5530
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
+ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -181,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=…
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA,sha256=a6Q1LozCx-4NWkm1EKZJFeCJTYiTNUSigoVwRevV0oc,1897
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,