ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. ai_edge_torch/_convert/conversion.py +1 -0
  2. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  3. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  4. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
  5. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
  6. ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
  8. ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
  9. ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
  10. ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
  11. ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
  12. ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
  13. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  14. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  15. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  16. ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
  17. ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
  18. ai_edge_torch/generative/examples/hammer/verify.py +5 -35
  19. ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
  20. ai_edge_torch/generative/examples/llama/llama.py +1 -4
  21. ai_edge_torch/generative/examples/llama/verify.py +5 -38
  22. ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
  23. ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
  24. ai_edge_torch/generative/examples/openelm/verify.py +4 -31
  25. ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
  26. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
  27. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
  28. ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
  29. ai_edge_torch/generative/examples/phi/phi2.py +1 -4
  30. ai_edge_torch/generative/examples/phi/phi3.py +1 -4
  31. ai_edge_torch/generative/examples/phi/phi4.py +1 -4
  32. ai_edge_torch/generative/examples/phi/verify.py +6 -24
  33. ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
  34. ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
  35. ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
  36. ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
  37. ai_edge_torch/generative/examples/qwen/verify.py +5 -35
  38. ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
  39. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  40. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
  41. ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
  42. ai_edge_torch/generative/examples/smollm/verify.py +5 -36
  43. ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
  45. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  46. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  47. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  48. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
  49. ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
  50. ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
  51. ai_edge_torch/generative/layers/model_config.py +2 -2
  52. ai_edge_torch/generative/utilities/converter.py +2 -1
  53. ai_edge_torch/generative/utilities/loader.py +11 -1
  54. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  55. ai_edge_torch/version.py +1 -1
  56. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
  57. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
  58. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
  59. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
  60. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/verify_phi4.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-4 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi4
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-4-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi4.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v4",
+      checkpoint_dir="microsoft/Phi-4-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/verify_util.py (new file)
@@ -0,0 +1,84 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Phi model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.phi import phi2, phi3, phi4
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Instruct: Write an email about the weather Output:"]
+
+_BUILDER = {
+    "v2": phi2.build_model,
+    "v3": phi3.build_model,
+    "v4": phi4.build_model,
+}
+
+
+def verify_phi(
+    version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+    atol: float = 1e-04,
+) -> bool:
+  """Verifies the reauthored Phi model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=atol,
+  )
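
The new verify_util.verify_phi entry point absorbs the boilerplate that verify_phi4.py (and its Phi-2/Phi-3 siblings) used to carry inline. A minimal usage sketch follows; the local directory is hypothetical, and note that despite its name, initialize_from_local=True resolves the checkpoint through the Hugging Face cache, while False reads weight_filename from checkpoint_dir via the safetensors custom loader.

from ai_edge_torch.generative.examples.phi import verify_util

# Default path: resolve the checkpoint through the Hugging Face cache,
# as the rewritten verify_phi4.py main() above does.
verify_util.verify_phi(
    version="v4",
    checkpoint_dir="microsoft/Phi-4-mini-instruct",
    max_new_tokens=30,
)

# Alternative path: read weights from a local safetensors file.
# "/tmp/phi4-checkpoint" is a hypothetical directory that holds both the
# Hugging Face model files (for the baseline) and model.safetensors.
verify_util.verify_phi(
    version="v4",
    checkpoint_dir="/tmp/phi4-checkpoint",
    weight_filename="model.safetensors",
    initialize_from_local=False,
)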
ai_edge_torch/generative/examples/qwen/qwen.py
@@ -53,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -71,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
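This hunk is the template for most of the remaining config changes in this release: explicit enable_hlfb=True arguments are dropped from NormalizationConfig and ModelConfig. The removals only preserve behavior if HLFB (high-level function boundary) lowering is now enabled by default; the model_config.py +2 -2 entry in the file list is consistent with such a default flip, but that hunk is not included in this section, so the sketch below is an inference rather than the actual change. Models that need the old behavior, such as the CLIP, T5, and toy-model examples further down, now pass enable_hlfb=False explicitly.

# Inferred shape of the ai_edge_torch/generative/layers/model_config.py
# change (+2 -2): the enable_hlfb defaults presumably flip from False to
# True, letting the per-model configs above drop their explicit opt-ins.
import dataclasses
import enum


class NormalizationType(enum.Enum):
  NONE = 0
  RMS_NORM = 1
  LAYER_NORM = 2


@dataclasses.dataclass
class NormalizationConfig:
  type: NormalizationType = NormalizationType.NONE
  enable_hlfb: bool = True  # hypothesized new default
  epsilon: float = 1e-5
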
ai_edge_torch/generative/examples/qwen/verify.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.qwen import qwen
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.qwen import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -49,38 +44,13 @@ _CHECKPOINT = {
     "3b": "Qwen/Qwen2.5-3B-Instruct",
 }
 
-_BUILDER = {
-    "0.5b": qwen.build_0_5b_model,
-    "1.5b": qwen.build_1_5b_model,
-    "3b": qwen.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_qwen(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/qwen/verify_util.py (new file)
@@ -0,0 +1,83 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Qwen model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": qwen.build_0_5b_model,
+    "1.5b": qwen.build_1_5b_model,
+    "3b": qwen.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_qwen(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Qwen model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
ai_edge_torch/generative/examples/qwen_vl/decoder.py
@@ -97,7 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -113,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -68,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.smollm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -47,38 +41,13 @@ _CHECKPOINT = {
     "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
 }
 
-_BUILDER = {
-    "v1": smollm.build_model,
-    "v2": smollm.build_model_v2,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
-  builder = _BUILDER[_MODEL_VERSION.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = builder(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_smollm_135m(
+      model_version=_MODEL_VERSION.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_VERSION.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/smollm/verify_util.py (new file)
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the SmolLM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_smollm_135m(
+    model_version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored SmolLM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -113,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -129,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
@@ -164,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -180,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
ai_edge_torch/generative/examples/t5/t5.py
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
   )
   # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
+      enable_hlfb=False,
   )
   return config
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.tiny_llama import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +34,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = tiny_llama.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_tiny_llama(
+      checkpoint_dir="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
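
As with the other examples, the TinyLlama script now simply forwards its flags to a verify_util module (a new +76-line file per the list above; its hunk is not shown in this section). A programmatic equivalent of the rewritten main(), with an illustrative prompt in place of the _PROMPTS flag default:

from ai_edge_torch.generative.examples.tiny_llama import verify_util

# Mirrors the flag-driven call in verify.py; the prompt is illustrative,
# not the flag's default value, and max_new_tokens matches the other
# verify_util helpers' default of 30.
verify_util.verify_tiny_llama(
    checkpoint_dir="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_new_tokens=30,
    prompts=["Show me the program to add 2 and 3."],
)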