ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. ai_edge_torch/_convert/conversion.py +1 -0
  2. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  3. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  4. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
  5. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
  6. ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
  8. ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
  9. ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
  10. ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
  11. ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
  12. ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
  13. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  14. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  15. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  16. ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
  17. ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
  18. ai_edge_torch/generative/examples/hammer/verify.py +5 -35
  19. ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
  20. ai_edge_torch/generative/examples/llama/llama.py +1 -4
  21. ai_edge_torch/generative/examples/llama/verify.py +5 -38
  22. ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
  23. ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
  24. ai_edge_torch/generative/examples/openelm/verify.py +4 -31
  25. ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
  26. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
  27. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
  28. ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
  29. ai_edge_torch/generative/examples/phi/phi2.py +1 -4
  30. ai_edge_torch/generative/examples/phi/phi3.py +1 -4
  31. ai_edge_torch/generative/examples/phi/phi4.py +1 -4
  32. ai_edge_torch/generative/examples/phi/verify.py +6 -24
  33. ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
  34. ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
  35. ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
  36. ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
  37. ai_edge_torch/generative/examples/qwen/verify.py +5 -35
  38. ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
  39. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  40. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
  41. ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
  42. ai_edge_torch/generative/examples/smollm/verify.py +5 -36
  43. ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
  45. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  46. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  47. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  48. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
  49. ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
  50. ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
  51. ai_edge_torch/generative/layers/model_config.py +2 -2
  52. ai_edge_torch/generative/utilities/converter.py +2 -1
  53. ai_edge_torch/generative/utilities/loader.py +11 -1
  54. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  55. ai_edge_torch/version.py +1 -1
  56. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
  57. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
  58. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
  59. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
  60. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/hammer/verify.py
+++ b/ai_edge_torch/generative/examples/hammer/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.hammer import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -48,37 +42,13 @@ _CHECKPOINT = {
     "1.5b": "MadeAgents/Hammer2.1-1.5b",
 }
 
-_BUILDER = {
-    "0.5b": hammer.build_0_5b_model,
-    "1.5b": hammer.build_1_5b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_hammer(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/hammer/verify_util.py
@@ -0,0 +1,82 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Hammer 2.1 model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_hammer(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Hammer 2.1 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
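
For reference, a minimal usage sketch of the new verify_util.verify_hammer() helper, based only on the signature added above. The local checkpoint directory below is a placeholder assumption, not something taken from this diff:

    # Sketch only: exercises the non-cached path, where weights are read from a
    # local safetensors file via loader.get_custom_loader().
    from ai_edge_torch.generative.examples.hammer import verify_util

    ok = verify_util.verify_hammer(
        model_size="0.5b",
        checkpoint_dir="/tmp/Hammer2.1-0.5b",  # hypothetical local directory
        weight_filename="model.safetensors",
        max_new_tokens=30,
        initialize_from_local=False,
        prompts=["What is the meaning of life?"],
    )
    print("Verification passed" if ok else "Verification failed")
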
--- a/ai_edge_torch/generative/examples/llama/llama.py
+++ b/ai_edge_torch/generative/examples/llama/llama.py
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
--- a/ai_edge_torch/generative/examples/llama/verify.py
+++ b/ai_edge_torch/generative/examples/llama/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Llama 3.2-1B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.llama import verify_util
 
 _MODEL_SIZE = flags.DEFINE_enum(
     "model_size",
@@ -47,40 +41,13 @@ _CHECKPOINT = {
     "3b": "meta-llama/Llama-3.2-3B-Instruct",
 }
 
-_BUILDER = {
-    "1b": llama.build_1b_model,
-    "3b": llama.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-  # available.
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_llama_3_2(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/llama/verify_util.py
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Llama 3.2-1B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_llama_3_2(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Llama 3.2 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
--- a/ai_edge_torch/generative/examples/openelm/openelm.py
+++ b/ai_edge_torch/generative/examples/openelm/openelm.py
@@ -53,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -101,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/openelm/verify.py
+++ b/ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,14 +15,9 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
-import logging
-import pathlib
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.openelm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "apple/OpenELM-3B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = openelm.build_model(str(reauthored_checkpoint))
-
-  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_openelm(
+      checkpoint_dir="apple/OpenELM-3B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/openelm/verify_util.py
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the OpenELM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_openelm(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored OpenELM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
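
The llama and openelm verify_util modules follow the same shape as the hammer one. A sketch of driving the OpenELM helper with its defaults, mirroring the flags-based call in verify.py above (not part of the diff):

    from ai_edge_torch.generative.examples.openelm import verify_util

    # Uses the default initialize_from_local=True path: the checkpoint is
    # resolved through the transformers cache, then the reauthored model is
    # built and compared against the original.
    verify_util.verify_openelm(
        checkpoint_dir="apple/OpenELM-3B",
        max_new_tokens=30,
        prompts=["What is the meaning of life?"],
    )
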
--- a/ai_edge_torch/generative/examples/paligemma/decoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -110,10 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -132,7 +129,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/paligemma/decoder2.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -93,10 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for the decoder of a PaliGemma 3B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -140,7 +137,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
--- a/ai_edge_torch/generative/examples/paligemma/image_encoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/image_encoder.py
@@ -118,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -137,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/phi/phi2.py
+++ b/ai_edge_torch/generative/examples/phi/phi2.py
@@ -66,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -85,7 +83,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       final_norm_config=norm_config,
       lm_head_use_bias=True,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/phi/phi3.py
+++ b/ai_edge_torch/generative/examples/phi/phi3.py
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -192,7 +190,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
--- a/ai_edge_torch/generative/examples/phi/phi4.py
+++ b/ai_edge_torch/generative/examples/phi/phi4.py
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -141,7 +139,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
--- a/ai_edge_torch/generative/examples/phi/verify.py
+++ b/ai_edge_torch/generative/examples/phi/verify.py
@@ -14,15 +14,10 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
-import logging
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import kagglehub
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,25 +33,12 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = phi2.build_model(checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v2",
+      checkpoint_dir="microsoft/phi-2",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-03,
+      prompts=_PROMPTS.value,
+      atol=1e-02,
   )
 
 
--- a/ai_edge_torch/generative/examples/phi/verify_phi3.py
+++ b/ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi3
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi3.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v3",
+      checkpoint_dir="microsoft/Phi-3.5-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 