ai-edge-torch-nightly 0.5.0.dev20250517__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
- ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
- ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
- ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
- ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
- ai_edge_torch/generative/examples/hammer/verify.py +5 -35
- ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
- ai_edge_torch/generative/examples/llama/verify.py +5 -38
- ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
- ai_edge_torch/generative/examples/openelm/verify.py +4 -31
- ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
- ai_edge_torch/generative/examples/phi/verify.py +6 -24
- ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
- ai_edge_torch/generative/examples/qwen/verify.py +5 -35
- ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
- ai_edge_torch/generative/examples/smollm/verify.py +5 -36
- ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
- ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
- ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
- ai_edge_torch/generative/utilities/loader.py +11 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +29 -20
- {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
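
Every example in this release gets the same treatment: the verification flow that used to live inline in each `verify.py` moves into a new per-model `verify_util.py` exposing a single `verify_<model>()` helper, with optional custom-loader support from `ai_edge_torch.generative.utilities.loader`. A minimal sketch of the new call shape, using only names and defaults visible in the diffs below:

```python
# Sketch of the new entry point, assuming access to the HuggingFace Hub.
# initialize_from_local=True (the default) resolves the snapshot that
# transformers has cached for the repo id.
from ai_edge_torch.generative.examples.amd_llama_135m import verify_util

passed = verify_util.verify_amd_llama_135m(
    "amd/AMD-Llama-135m",  # checkpoint_dir: HF repo id or a local directory
    max_new_tokens=30,
    prompts=["Tell me a story?\nOnce upon a time"],
)
print("verification passed:", passed)
```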
ai_edge_torch/generative/examples/amd_llama_135m/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "amd/AMD-Llama-135m"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_amd_llama_135m(
+      "amd/AMD-Llama-135m",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the AMD-Llama-135M model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Tell me a story?\nOnce upon a time"]
+
+
+def verify_amd_llama_135m(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
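
The `initialize_from_local` switch above is the core of every new helper: `True` resolves the snapshot directory that `transformers` has already cached for the repo id, while `False` points at a weights file inside `checkpoint_dir` and loads it through a safetensors custom loader. The same resolution step in isolation, as a hedged sketch (the standalone helper name is ours, not part of the package):

```python
import os
import pathlib

import transformers


def resolve_reauthored_checkpoint(  # illustrative helper, not in the package
    checkpoint_dir: str,
    weight_filename: str = "model.safetensors",
    initialize_from_local: bool = True,
) -> str:
  if initialize_from_local:
    # Ask transformers where config.json lives in its cache; the snapshot
    # directory containing it is what the reauthored model is built from.
    cached_config_file = transformers.utils.cached_file(
        checkpoint_dir, transformers.utils.CONFIG_NAME
    )
    return str(pathlib.Path(cached_config_file).parent)
  # Otherwise point straight at the weights file in a local directory.
  return os.path.join(checkpoint_dir, weight_filename)
```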
ai_edge_torch/generative/examples/deepseek/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.deepseek import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,30 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_deepseek_r1_distill_1_5b(
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/deepseek/verify_util.py
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the DeepSeek R1 distilled 1.5B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_deepseek_r1_distill_1_5b(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
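
Each helper returns the bool result of `verifier.verify_reauthored_model`, so callers can assert on it in tests. A sketch of the non-cached path, assuming a hypothetical local directory that already holds `model.safetensors` alongside the usual transformers config and tokenizer files:

```python
from ai_edge_torch.generative.examples.deepseek import verify_util

# Hypothetical local path; initialize_from_local=False swaps the
# transformers-cache lookup for loader.get_custom_loader("", "safetensors").
ok = verify_util.verify_deepseek_r1_distill_1_5b(
    "/path/to/deepseek-r1-distill-qwen-1.5b",
    weight_filename="model.safetensors",
    initialize_from_local=False,
    max_new_tokens=10,
)
assert ok, "reauthored model diverged from the original"
```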
ai_edge_torch/generative/examples/gemma/verify_util.py
@@ -17,11 +17,13 @@
 
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -107,6 +109,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     mask_as_input: bool = False,
@@ -125,7 +128,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -144,27 +154,62 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    gemma2_model_path: str,
+    checkpoint_dir: str,
+    weight_filename: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-  logging.info("Building the reauthored model from: %s", gemma2_model_path)
-  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  logging.info("Building the reauthored model from: %s", checkpoint_path)
+  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
 
   return verify_reauthored_gemma_model(
-      checkpoint=gemma2_model_path,
+      checkpoint=checkpoint_dir,
       variant="2b-v2",
       reauthored_model=reauthored_model,
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
       mask_as_input=mask_as_input,
       kv_layout=kv_layout,
       atol=1e-04,
   )
+
+
+def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma1 model with a custom loader."""
+  weight_filename = "gemma-2b-it.ckpt"
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  custom_loader = loader.get_custom_loader(checkpoint_path)
+  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint_dir,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
+      generate_prompts=["What is the meaning of life?"],
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=30,
+  )
+
+
+def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma2 model with a custom loader."""
+  return verify_gemma2(
+      checkpoint_dir=checkpoint_dir,
+      weight_filename="model.ckpt",
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      mask_as_input=True,
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
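
The `custom_loader` contract is spelled out in the new signatures: a callable from a checkpoint path to a dict of tensors, and for the upstream Gemma reference model the dict must carry a `"model_state_dict"` entry. A minimal hand-rolled stand-in for `loader.get_custom_loader("", checkpoint_format="pt")`, assuming the `.ckpt` was saved in that layout:

```python
from typing import Any, Dict

import torch


def gemma_ckpt_loader(checkpoint_path: str) -> Dict[str, Any]:
  # verify_reauthored_gemma_model indexes ["model_state_dict"] on the
  # result, so the .ckpt must have been saved as
  # {"model_state_dict": state_dict}.
  return torch.load(checkpoint_path, map_location="cpu")
```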
ai_edge_torch/generative/examples/gemma3/verify_util.py
@@ -22,6 +22,7 @@ from typing import Callable, Dict, List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -260,3 +261,15 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
+
+
+def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
+  """Verifies the reauthored Gemma3 model with a custom loader."""
+  return verify_gemma3(
+      checkpoint=checkpoint,
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      variant="1b",
+      weight_filename="model.ckpt",
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
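
Usage mirrors the Gemma1/Gemma2 wrappers above; the path below is a placeholder for a directory holding the Gemma3 1B checkpoint as `model.ckpt`:

```python
from ai_edge_torch.generative.examples.gemma3 import verify_util

ok = verify_util.verify_gemma3_with_custom_loader("/path/to/gemma3-1b")
print("PASS" if ok else "FAIL")
```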
ai_edge_torch/generative/examples/hammer/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.hammer import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -48,37 +42,13 @@ _CHECKPOINT = {
     "1.5b": "MadeAgents/Hammer2.1-1.5b",
 }
 
-_BUILDER = {
-    "0.5b": hammer.build_0_5b_model,
-    "1.5b": hammer.build_1_5b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_hammer(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/hammer/verify_util.py
@@ -0,0 +1,82 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Hammer 2.1 model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_hammer(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Hammer 2.1 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
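
`model_size` is just a key into `_BUILDER`, mapping the CLI's two flag values onto `hammer.build_0_5b_model` and `hammer.build_1_5b_model`. A direct-call sketch with the 1.5B repo id listed in verify.py:

```python
from ai_edge_torch.generative.examples.hammer import verify_util

verify_util.verify_hammer(
    model_size="1.5b",
    checkpoint_dir="MadeAgents/Hammer2.1-1.5b",
    max_new_tokens=10,
)
```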
ai_edge_torch/generative/examples/llama/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Llama 3.2-1B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.llama import verify_util
 
 _MODEL_SIZE = flags.DEFINE_enum(
     "model_size",
@@ -47,40 +41,13 @@ _CHECKPOINT = {
     "3b": "meta-llama/Llama-3.2-3B-Instruct",
 }
 
-_BUILDER = {
-    "1b": llama.build_1b_model,
-    "3b": llama.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-  # available.
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_llama_3_2(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/llama/verify_util.py
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Llama 3.2-1B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_llama_3_2(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Llama 3.2 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
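
As in the other helpers, leaving `prompts` unset falls back to the module-level `DEFAULT_PROMPTS` instead of requiring flag plumbing. A final sketch with the 3B repo id from verify.py:

```python
from ai_edge_torch.generative.examples.llama import verify_util

# generate_prompts falls back to DEFAULT_PROMPTS
# (["What is the meaning of life?"]) when prompts is None.
verify_util.verify_llama_3_2(
    model_size="3b",
    checkpoint_dir="meta-llama/Llama-3.2-3B-Instruct",
)
```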