ai-edge-torch-nightly 0.5.0.dev20250517-py3-none-any.whl → 0.5.0.dev20250518-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
  2. ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
  3. ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
  4. ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
  6. ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
  7. ai_edge_torch/generative/examples/hammer/verify.py +5 -35
  8. ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +5 -38
  10. ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
  11. ai_edge_torch/generative/examples/openelm/verify.py +4 -31
  12. ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
  13. ai_edge_torch/generative/examples/phi/verify.py +6 -24
  14. ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
  15. ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
  16. ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
  17. ai_edge_torch/generative/examples/qwen/verify.py +5 -35
  18. ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
  19. ai_edge_torch/generative/examples/smollm/verify.py +5 -36
  20. ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
  21. ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
  22. ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
  23. ai_edge_torch/generative/utilities/loader.py +11 -1
  24. ai_edge_torch/version.py +1 -1
  25. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +29 -20
  27. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
  28. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
  29. {ai_edge_torch_nightly-0.5.0.dev20250517.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,14 +15,9 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
-import logging
-import pathlib
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.openelm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "apple/OpenELM-3B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = openelm.build_model(str(reauthored_checkpoint))
-
-  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_openelm(
+      checkpoint_dir="apple/OpenELM-3B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
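The script is now a thin flag wrapper; everything model-specific moves into verify_util (next file). Note one behavioral change hidden in the refactor: the old script loaded the tokenizer from meta-llama/Llama-2-7b-hf, while the new helper loads it from checkpoint_dir. As a rough sketch, main() above is equivalent to the following direct call with the verify_util defaults inlined (the prompt and max_new_tokens values below are those defaults, not the flag values):

    # Sketch only: inlines the defaults defined in verify_util.
    from ai_edge_torch.generative.examples.openelm import verify_util

    verify_util.verify_openelm(
        checkpoint_dir="apple/OpenELM-3B",
        max_new_tokens=30,
        prompts=["What is the meaning of life?"],
    )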
ai_edge_torch/generative/examples/openelm/verify_util.py (new file)
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the OpenELM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_openelm(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored OpenELM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
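The helper supports two weight sources. With the default initialize_from_local=True it resolves the directory that transformers has already cached for the checkpoint (via transformers.utils.cached_file) and builds from there; with initialize_from_local=False it joins checkpoint_dir with weight_filename and reads that file through loader.get_custom_loader. A hedged sketch of both modes; the local directory path is hypothetical:

    from ai_edge_torch.generative.examples.openelm import verify_util

    # Hub checkpoint: weights come from the local transformers cache.
    verify_util.verify_openelm(checkpoint_dir="apple/OpenELM-3B")

    # Pre-downloaded directory (hypothetical path): the safetensors file
    # is read through the custom loader instead of the cache.
    verify_util.verify_openelm(
        checkpoint_dir="/data/openelm-3b",
        weight_filename="model.safetensors",
        initialize_from_local=False,
    )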
ai_edge_torch/generative/examples/phi/verify.py
@@ -14,15 +14,10 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
-import logging
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import kagglehub
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,25 +33,12 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = phi2.build_model(checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v2",
+      checkpoint_dir="microsoft/phi-2",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-03,
+      prompts=_PROMPTS.value,
+      atol=1e-02,
   )
 
 
ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi3
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi3.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v3",
+      checkpoint_dir="microsoft/Phi-3.5-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/verify_phi4.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-4 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi4
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-4-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi4.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v4",
+      checkpoint_dir="microsoft/Phi-4-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/verify_util.py (new file)
@@ -0,0 +1,84 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Phi model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.phi import phi2, phi3, phi4
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Instruct: Write an email about the weather Output:"]
+
+_BUILDER = {
+    "v2": phi2.build_model,
+    "v3": phi3.build_model,
+    "v4": phi4.build_model,
+}
+
+
+def verify_phi(
+    version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+    atol: float = 1e-04,
+) -> bool:
+  """Verifies the reauthored Phi model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=atol,
+  )
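A single verify_phi entry point now backs verify.py, verify_phi3.py, and verify_phi4.py: version selects the builder from _BUILDER, and atol is a parameter rather than a hard-coded literal (the Phi-2 script passes 1e-02 where it previously passed 1e-03, and its checkpoint moves from a kagglehub download to the Hugging Face repo microsoft/phi-2). Collected from the updated scripts above, the three dispatches are:

    from ai_edge_torch.generative.examples.phi import verify_util

    verify_util.verify_phi(version="v2", checkpoint_dir="microsoft/phi-2", atol=1e-02)
    verify_util.verify_phi(version="v3", checkpoint_dir="microsoft/Phi-3.5-mini-instruct")
    verify_util.verify_phi(version="v4", checkpoint_dir="microsoft/Phi-4-mini-instruct")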
ai_edge_torch/generative/examples/qwen/verify.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.qwen import qwen
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.qwen import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -49,38 +44,13 @@ _CHECKPOINT = {
     "3b": "Qwen/Qwen2.5-3B-Instruct",
 }
 
-_BUILDER = {
-    "0.5b": qwen.build_0_5b_model,
-    "1.5b": qwen.build_1_5b_model,
-    "3b": qwen.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_qwen(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/qwen/verify_util.py (new file; the shipped docstring on verify_qwen reads "Llama 3.2", a copy-paste slip corrected here)
@@ -0,0 +1,83 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Qwen model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": qwen.build_0_5b_model,
+    "1.5b": qwen.build_1_5b_model,
+    "3b": qwen.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_qwen(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Qwen model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
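verify_qwen returns the boolean result of verifier.verify_reauthored_model, so a test harness can assert on it. A hedged sketch that checks all three sizes; only the 3b checkpoint name appears in this diff, and the other two are assumptions following the same naming pattern:

    from ai_edge_torch.generative.examples.qwen import verify_util

    CHECKPOINTS = {
        "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",  # assumed, not shown in the hunk
        "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",  # assumed, not shown in the hunk
        "3b": "Qwen/Qwen2.5-3B-Instruct",
    }

    # Fail fast on the first size whose outputs diverge.
    for size, ckpt in CHECKPOINTS.items():
      assert verify_util.verify_qwen(model_size=size, checkpoint_dir=ckpt), size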
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.smollm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -47,38 +41,13 @@ _CHECKPOINT = {
     "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
 }
 
-_BUILDER = {
-    "v1": smollm.build_model,
-    "v2": smollm.build_model_v2,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
-  builder = _BUILDER[_MODEL_VERSION.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = builder(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_smollm_135m(
+      model_version=_MODEL_VERSION.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_VERSION.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/smollm/verify_util.py (new file; the shipped docstrings spell the model "SmoLLM", corrected to "SmolLM" here)
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the SmolLM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_smollm_135m(
+    model_version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored SmolLM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
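The SmolLM helper follows the same template as the others. One detail worth noting across all of these utilities: loader.get_custom_loader("", "safetensors") is constructed with an empty path, so it apparently fixes only the weight format, while the actual path is handed to build_model as checkpoint_path (the supporting change lives in ai_edge_torch/generative/utilities/loader.py, +11 -1, not shown in this diff). A minimal sketch mirroring the v2 branch of the updated script:

    from ai_edge_torch.generative.examples.smollm import verify_util

    # Verify SmolLM2-135M from the Hugging Face cache (the default,
    # initialize_from_local=True, code path).
    ok = verify_util.verify_smollm_135m(
        model_version="v2",
        checkpoint_dir="HuggingFaceTB/SmolLM2-135M-Instruct",
        max_new_tokens=30,
    )
    print("match" if ok else "mismatch")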