ai-edge-torch-nightly 0.5.0.dev20250517__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
  2. ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
  3. ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
  4. ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
  6. ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
  7. ai_edge_torch/generative/examples/hammer/verify.py +5 -35
  8. ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +5 -38
  10. ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
  11. ai_edge_torch/generative/examples/openelm/verify.py +4 -31
  12. ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
  13. ai_edge_torch/generative/examples/phi/verify.py +6 -24
  14. ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
  15. ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
  16. ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
  17. ai_edge_torch/generative/examples/qwen/verify.py +5 -35
  18. ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
  19. ai_edge_torch/generative/examples/smollm/verify.py +5 -36
  20. ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
  21. ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
  22. ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
  23. ai_edge_torch/generative/utilities/loader.py +11 -1
  24. ai_edge_torch/version.py +1 -1
  25. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +29 -20
  27. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
  28. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
  29. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/amd_llama_135m/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "amd/AMD-Llama-135m"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_amd_llama_135m(
+      "amd/AMD-Llama-135m",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py (new)
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the AMD-Llama-135M model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Tell me a story?\nOnce upon a time"]
+
+
+def verify_amd_llama_135m(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
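
Note: each new verify_util module in this release follows the same template as the one above: load the original HuggingFace model, then either resolve the reauthored checkpoint from the HuggingFace cache (initialize_from_local=True, the default, matching the old verify.py behavior) or join checkpoint_dir with weight_filename and read it through loader.get_custom_loader. A minimal usage sketch of the new helper; the local directory path below is hypothetical:

from ai_edge_torch.generative.examples.amd_llama_135m import verify_util

# Default mode: weights are resolved through the HuggingFace cache, exactly
# as the old verify.py did.
verify_util.verify_amd_llama_135m("amd/AMD-Llama-135m")

# Local mode: the helper loads <checkpoint_dir>/model.safetensors through
# loader.get_custom_loader("", "safetensors") instead of the cached config.
verify_util.verify_amd_llama_135m(
    "/data/checkpoints/amd_llama_135m",  # hypothetical local directory
    weight_filename="model.safetensors",
    initialize_from_local=False,
    max_new_tokens=10,
    prompts=["Tell me a story?\nOnce upon a time"],
)
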
ai_edge_torch/generative/examples/deepseek/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.deepseek import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,30 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_deepseek_r1_distill_1_5b(
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/deepseek/verify_util.py (new)
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the DeepSeek R1 distilled 1.5B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_deepseek_r1_distill_1_5b(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
ai_edge_torch/generative/examples/gemma/verify_util.py
@@ -17,11 +17,13 @@
 
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -107,6 +109,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     mask_as_input: bool = False,
@@ -125,7 +128,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -144,27 +154,62 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    gemma2_model_path: str,
+    checkpoint_dir: str,
+    weight_filename: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-  logging.info("Building the reauthored model from: %s", gemma2_model_path)
-  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  logging.info("Building the reauthored model from: %s", checkpoint_path)
+  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
 
   return verify_reauthored_gemma_model(
-      checkpoint=gemma2_model_path,
+      checkpoint=checkpoint_dir,
       variant="2b-v2",
       reauthored_model=reauthored_model,
      generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
       mask_as_input=mask_as_input,
       kv_layout=kv_layout,
       atol=1e-04,
   )
+
+
+def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma1 model with a custom loader."""
+  weight_filename = "gemma-2b-it.ckpt"
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  custom_loader = loader.get_custom_loader(checkpoint_path)
+  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint_dir,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
+      generate_prompts=["What is the meaning of life?"],
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=30,
+  )
+
+
+def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma2 model with a custom loader."""
  return verify_gemma2(
+      checkpoint_dir=checkpoint_dir,
+      weight_filename="model.ckpt",
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      mask_as_input=True,
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
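
Note: the custom_loader threaded through these functions is typed Callable[[str], Dict[str, torch.Tensor]], and for the original Gemma model the hunk above indexes the loaded dict with "model_state_dict" before load_state_dict(..., strict=False), so in practice the loader returns the raw torch checkpoint dict. A sketch of a loader compatible with that call site, assuming a standard torch-serialized .ckpt; this approximates what loader.get_custom_loader("", checkpoint_format="pt") would produce, not the library's actual implementation:

from typing import Any, Dict

import torch


def pt_checkpoint_loader(checkpoint_path: str) -> Dict[str, Any]:
  # Gemma .ckpt files are torch-serialized dicts that nest the weights under
  # "model_state_dict"; verify_reauthored_gemma_model unpacks that key before
  # calling original_model.load_state_dict(..., strict=False).
  return torch.load(checkpoint_path, map_location="cpu")
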
ai_edge_torch/generative/examples/gemma3/verify_util.py
@@ -22,6 +22,7 @@ from typing import Callable, Dict, List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -260,3 +261,15 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
+
+
+def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
+  """Verifies the reauthored Gemma3 model with a custom loader."""
+  return verify_gemma3(
+      checkpoint=checkpoint,
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      variant="1b",
+      weight_filename="model.ckpt",
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
ai_edge_torch/generative/examples/hammer/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.hammer import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -48,37 +42,13 @@ _CHECKPOINT = {
     "1.5b": "MadeAgents/Hammer2.1-1.5b",
 }
 
-_BUILDER = {
-    "0.5b": hammer.build_0_5b_model,
-    "1.5b": hammer.build_1_5b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_hammer(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/hammer/verify_util.py (new)
@@ -0,0 +1,82 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Hammer 2.1 model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_hammer(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Hammer 2.1 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
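
Note: because verify_hammer dispatches through the _BUILDER dict, both released sizes can be checked from one loop. A short sketch, reusing the checkpoint names already present in verify.py:

from ai_edge_torch.generative.examples.hammer import verify_util

# This mapping mirrors the _CHECKPOINT dict kept in verify.py.
for size, checkpoint in {
    "0.5b": "MadeAgents/Hammer2.1-0.5b",
    "1.5b": "MadeAgents/Hammer2.1-1.5b",
}.items():
  verify_util.verify_hammer(
      model_size=size,
      checkpoint_dir=checkpoint,
      max_new_tokens=10,
  )
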
ai_edge_torch/generative/examples/llama/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Llama 3.2-1B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.llama import verify_util
 
 _MODEL_SIZE = flags.DEFINE_enum(
     "model_size",
@@ -47,40 +41,13 @@ _CHECKPOINT = {
     "3b": "meta-llama/Llama-3.2-3B-Instruct",
 }
 
-_BUILDER = {
-    "1b": llama.build_1b_model,
-    "3b": llama.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-  # available.
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_llama_3_2(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/llama/verify_util.py (new)
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Llama 3.2-1B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_llama_3_2(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Llama 3.2 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )