ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240929__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +14 -7
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +17 -24
- ai_edge_torch/generative/examples/phi/verify.py +8 -9
- ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
- ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +81 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +141 -0
- ai_edge_torch/generative/examples/qwen/verify.py +88 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +20 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +117 -97
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/RECORD +40 -29
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   logging.info("Loading the original model from: %s", checkpoint)
-
-
-          checkpoint, trust_remote_code=True
-      ),
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=80,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=0.4,
       qkv_use_bias=True,
       output_proj_use_bias=True,
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int
-    condense_ratio: int
-    dtype: torch.dtype
-    device: torch.device
-    theta_factors: torch.Tensor
-    scale: float
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
-    scale (float, optional): A float used to scale the rope values.
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-
-  theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
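For reference, a standalone sketch of the RoPE precomputation this hunk converges on; the helper name and example values below are illustrative, and the sine path is inferred from the function's Tuple[cos, sin] return type rather than shown in the hunk:

    import torch

    def rope_cos_sin(size, dim, base, condense_ratio, theta_factors, scale,
                     dtype=torch.float32, device=torch.device("cpu")):
      theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
      theta = theta / theta_factors  # per-frequency long/short context factors
      seq_idx = torch.arange(size) / condense_ratio
      idx_theta = torch.outer(seq_idx, theta)
      cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
      sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
      return cos, sin

    # dim=96 rotary dims yield 48 frequencies, so theta_factors has 48 entries here.
    cos, sin = rope_cos_sin(size=1024, dim=96, base=10000, condense_ratio=1,
                            theta_factors=torch.ones(48), scale=1.0)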
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
+        base=attn_config.rotary_base,
         condense_ratio=1,
         dtype=torch.float32,
         device=torch.device("cpu"),
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=96,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=1.0,
       qkv_transpose_before_split=True,
   )
@@ -19,6 +19,7 @@ import logging
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   logging.info("Loading the original model from: %s", checkpoint)
-
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
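The same wrapper pattern repeats across the OpenELM, Phi-2, Phi-3.5, SmolLM, and Qwen verify scripts in this release. A minimal sketch of the new call shape, assembled from the Phi-2 hunks above; the checkpoint placeholder, prompt, and token budget are illustrative (taken from the flag defaults), not quoted from any single file:

    import transformers
    from ai_edge_torch.generative.examples.phi import phi2
    from ai_edge_torch.generative.utilities import transformers_verifier
    from ai_edge_torch.generative.utilities import verifier

    checkpoint = "..."  # e.g. the kagglehub download path used above
    original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    reauthored_model = phi2.build_model(checkpoint)  # reauthored Edge Generative API model
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    verifier.verify_reauthored_model(
        original_model=transformers_verifier.TransformersModelWrapper(original_model),
        reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
        tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
        atol=1e-03,
    )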
@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
   logging.info("Loading the original model from: %s", checkpoint)
-
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -0,0 +1,81 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Qwen 2.5 models to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import converter
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '3b',
+    ['0.5b', '1.5b', '3b'],
+    'The size of the model to convert.',
+)
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+_BUILDER = {
+    '0.5b': qwen.build_0_5b_model,
+    '1.5b': qwen.build_1_5b_model,
+    '3b': qwen.build_3b_model,
+}
+
+
+def main(_):
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  model_size = _MODEL_SIZE.value.replace('.', '_')
+  output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
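The new converter script is flag-driven. A hypothetical invocation (the script path and checkpoint directory below are assumptions; only the flag names and defaults come from the hunk above):

    python ai_edge_torch/generative/examples/qwen/convert_to_tflite.py \
        --model_size=0.5b \
        --checkpoint_path=/path/to/qwen_checkpoint \
        --tflite_path=/tmp/ \
        --prefill_seq_len=1024 \
        --kv_cache_max_len=1280 \
        --quantize=true

With these values the output name resolves to qwen_0_5b_q8_seq1024_ekv1280.tflite, per the output_filename format string in main().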
@@ -0,0 +1,141 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Qwen 2.5 models."""
+
+import copy
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# Qwen re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class Qwen(tiny_llama.TinyLlama):
+  """A Qwen model built from the Edge Generative API layers.
+
+  Qwen 2.5 shares the same architecture as TinyLlama.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # Qwen re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=11008,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151936,
+      num_layers=36,
+      max_seq_len=32768,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 1B model."""
+  config = get_3b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 12
+  block_config.ff_config.intermediate_size = 8960
+  config.num_layers = 28
+  config.embedding_dim = 1536
+  return config
+
+
+def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 0.5B model."""
+  config = get_3b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 14
+  block_config.attn_config.head_dim = 64
+  block_config.ff_config.intermediate_size = 4864
+  config.num_layers = 24
+  config.embedding_dim = 896
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_3b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # Qwen has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
+  model = Qwen(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+
+
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+
+
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
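A short sketch of using these builders directly, assuming a local directory with Qwen 2.5 weights in the layout the ModelLoader expects (the checkpoint path below is hypothetical):

    from ai_edge_torch.generative.examples.qwen import qwen

    # Full-size model loaded from a downloaded checkpoint (hypothetical path).
    model = qwen.build_0_5b_model(
        "/path/to/Qwen2.5-0.5B-Instruct", kv_cache_max_len=1280
    )

    # Tiny config useful for quick smoke tests without real weights.
    fake_config = qwen.get_fake_model_config(kv_cache_max_len=128)
    toy_model = qwen.Qwen(fake_config)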
@@ -0,0 +1,88 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "3b",
+    ["0.5b", "1.5b", "3b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
+    "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
+    "3b": "Qwen/Qwen2.5-3B-Instruct",
+}
+
+_BUILDER = {
+    "0.5b": qwen.build_0_5b_model,
+    "1.5b": qwen.build_1_5b_model,
+    "3b": qwen.build_3b_model,
+}
+
+
+def main(_):
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
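As with the converter, verification is selected by flags; a hypothetical run (the script path is assumed from the file listing at the top, the flag names come from the hunk above):

    python ai_edge_torch/generative/examples/qwen/verify.py \
        --model_size=0.5b \
        --prompts="What is the meaning of life?" \
        --max_new_tokens=30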
@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   logging.info("Loading the original model from: %s", checkpoint)
-
-
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
    rotary_percentage=0.0,
    qkv_transpose_before_split=True,
    qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
       qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
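The last several hunks make rotary_base an explicit per-attention-config field rather than a constant hard-coded in the RoPE helpers; attention blocks that do not use RoPE now simply pass 0. A hedged sketch of such a config, mirroring the stable-diffusion hunks above (the field values here are illustrative, not taken from any specific model):

    import ai_edge_torch.generative.layers.model_config as cfg

    # Attention block without RoPE: both rotary fields are zeroed out.
    attn_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=64,
        num_query_groups=8,
        rotary_base=0,
        rotary_percentage=0.0,
        qkv_use_bias=True,
        qkv_transpose_before_split=True,
    )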