ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
- ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
- ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
- ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
- ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
- ai_edge_torch/generative/examples/hammer/verify.py +5 -35
- ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
- ai_edge_torch/generative/examples/llama/llama.py +1 -4
- ai_edge_torch/generative/examples/llama/verify.py +5 -38
- ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
- ai_edge_torch/generative/examples/openelm/verify.py +4 -31
- ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/phi/phi2.py +1 -4
- ai_edge_torch/generative/examples/phi/phi3.py +1 -4
- ai_edge_torch/generative/examples/phi/phi4.py +1 -4
- ai_edge_torch/generative/examples/phi/verify.py +6 -24
- ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
- ai_edge_torch/generative/examples/qwen/verify.py +5 -35
- ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
- ai_edge_torch/generative/examples/smollm/verify.py +5 -36
- ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
- ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
- ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +2 -1
- ai_edge_torch/generative/utilities/loader.py +11 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/hammer/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.hammer import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -48,37 +42,13 @@ _CHECKPOINT = {
     "1.5b": "MadeAgents/Hammer2.1-1.5b",
 }
 
-_BUILDER = {
-    "0.5b": hammer.build_0_5b_model,
-    "1.5b": hammer.build_1_5b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_hammer(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
ai_edge_torch/generative/examples/hammer/verify_util.py (new file)
@@ -0,0 +1,82 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Hammer 2.1 model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_hammer(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Hammer 2.1 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
 
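The refactored helper can be driven directly from Python as well as through verify.py. A minimal usage sketch based on the signature above (the prompt string is illustrative, not from this release):

```python
# Minimal usage sketch for the new helper; assumes the checkpoint resolves
# through the Hugging Face cache (initialize_from_local=True, the default).
from ai_edge_torch.generative.examples.hammer import verify_util

ok = verify_util.verify_hammer(
    model_size="0.5b",
    checkpoint_dir="MadeAgents/Hammer2.1-0.5b",
    max_new_tokens=30,
    prompts=["Write a one-line summary of edge AI."],  # illustrative prompt
)
assert ok, "Reauthored Hammer model diverged from the original."
```
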
ai_edge_torch/generative/examples/llama/llama.py
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
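
These two hunks are representative of a sweep across the example configs in this release: every explicit `enable_hlfb=True` is dropped from `NormalizationConfig` and `ModelConfig` call sites. A before/after sketch, with the presumption (not confirmed by anything shown here; `model_config.py` itself changes only +2/-2) that the flag now defaults to True:

```python
# Before: HLFB (high-level function boundary) enabled at every call site.
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
)

# After: the flag is omitted. Presumably enable_hlfb now defaults to True
# in model_config.py (an assumption inferred from its small +2/-2 change),
# so behavior is unchanged while every example config gets shorter.
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
```
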
ai_edge_torch/generative/examples/llama/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored Llama 3.2-1B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.llama import verify_util
 
 _MODEL_SIZE = flags.DEFINE_enum(
     "model_size",
@@ -47,40 +41,13 @@ _CHECKPOINT = {
     "3b": "meta-llama/Llama-3.2-3B-Instruct",
 }
 
-_BUILDER = {
-    "1b": llama.build_1b_model,
-    "3b": llama.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-  # available.
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_llama_3_2(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
ai_edge_torch/generative/examples/llama/verify_util.py (new file)
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Llama 3.2-1B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_llama_3_2(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Llama 3.2 model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
 
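`initialize_from_local` is slightly counter-intuitive: True resolves the checkpoint through the transformers cache, while False loads a safetensors file from `checkpoint_dir` via `loader.get_custom_loader`. A sketch of the second path (the local directory below is hypothetical):

```python
# Sketch: verify from a local directory of weights rather than the HF cache.
# "/tmp/llama-3.2-1b" is a hypothetical path; it must still contain the
# config/tokenizer files that transformers' from_pretrained expects.
from ai_edge_torch.generative.examples.llama import verify_util

verify_util.verify_llama_3_2(
    model_size="1b",
    checkpoint_dir="/tmp/llama-3.2-1b",
    weight_filename="model.safetensors",  # joined onto checkpoint_dir
    initialize_from_local=False,          # use the custom safetensors loader
)
```
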
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -53,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -101,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,14 +15,9 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
-import logging
-import pathlib
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.openelm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "apple/OpenELM-3B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = openelm.build_model(str(reauthored_checkpoint))
-
-  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_openelm(
+      checkpoint_dir="apple/OpenELM-3B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
ai_edge_torch/generative/examples/openelm/verify_util.py (new file)
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the OpenELM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_openelm(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored OpenELM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
 
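Each of these verify_util helpers forwards the boolean result of `verifier.verify_reauthored_model`, so callers can gate automation on it. A small sketch (the surrounding script is hypothetical):

```python
# Sketch: fail a hypothetical regression script when verification does not pass.
import sys

from ai_edge_torch.generative.examples.openelm import verify_util

if not verify_util.verify_openelm(checkpoint_dir="apple/OpenELM-3B"):
  sys.exit("OpenELM reauthoring verification failed.")
```
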
ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -110,10 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -132,7 +129,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -93,10 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for the decoder of a PaliGemma 3B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -140,7 +137,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
ai_edge_torch/generative/examples/paligemma/image_encoder.py
@@ -118,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -137,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/phi/phi2.py
@@ -66,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -85,7 +83,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       final_norm_config=norm_config,
       lm_head_use_bias=True,
      lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/phi/phi3.py
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -192,7 +190,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
ai_edge_torch/generative/examples/phi/phi4.py
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -141,7 +139,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
ai_edge_torch/generative/examples/phi/verify.py
@@ -14,15 +14,10 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
-import logging
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import kagglehub
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -38,25 +33,12 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = phi2.build_model(checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v2",
+      checkpoint_dir="microsoft/phi-2",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-02,
+      prompts=_PROMPTS.value,
+      atol=1e-02,
   )
 
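The new `phi/verify_util.py` (+84 lines) is not expanded in this diff, but the call sites above and in verify_phi3.py imply it dispatches builders on a version string the way the other verify_util modules dispatch on model size. A hypothetical sketch of that dispatch table (the real module may differ, and phi4's builder name is assumed):

```python
# Hypothetical shape of phi/verify_util.py's builder dispatch, inferred from
# the version="v2"/"v3" call sites in this diff; the module itself is not shown.
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.examples.phi import phi3
from ai_edge_torch.generative.examples.phi import phi4

_BUILDER = {
    "v2": phi2.build_model,
    "v3": phi3.build_model,
    "v4": phi4.build_model,  # assumed; verify_phi4.py also moved to verify_util
}
```
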
ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi3
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi3.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v3",
+      checkpoint_dir="microsoft/Phi-3.5-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
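Note the tolerance difference between models: Phi-2 passes `atol=1e-02` where Hammer, Llama, and OpenELM use `1e-04`. The value bounds the element-wise absolute difference allowed between the original and reauthored outputs, roughly in the sense of torch.allclose:

```python
# Rough illustration of what the atol values permit; the verifier's actual
# comparison logic is not shown in this diff and may differ in detail.
import torch

original = torch.tensor([1.0000, 2.0000])
reauthored = torch.tensor([1.0050, 2.0003])  # small numeric drift

print(torch.allclose(original, reauthored, atol=1e-02))  # True
print(torch.allclose(original, reauthored, atol=1e-04))  # False
```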