ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +1 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
- ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
- ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
- ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
- ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
- ai_edge_torch/generative/examples/hammer/verify.py +5 -35
- ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
- ai_edge_torch/generative/examples/llama/llama.py +1 -4
- ai_edge_torch/generative/examples/llama/verify.py +5 -38
- ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
- ai_edge_torch/generative/examples/openelm/verify.py +4 -31
- ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/phi/phi2.py +1 -4
- ai_edge_torch/generative/examples/phi/phi3.py +1 -4
- ai_edge_torch/generative/examples/phi/phi4.py +1 -4
- ai_edge_torch/generative/examples/phi/verify.py +6 -24
- ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
- ai_edge_torch/generative/examples/qwen/verify.py +5 -35
- ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
- ai_edge_torch/generative/examples/smollm/verify.py +5 -36
- ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
- ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
- ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +2 -1
- ai_edge_torch/generative/utilities/loader.py +11 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/phi/verify_phi4.py
+++ b/ai_edge_torch/generative/examples/phi/verify_phi4.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Phi-4 model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.phi import phi4
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.phi import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,29 +34,11 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "microsoft/Phi-4-mini-instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = phi4.build_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_phi(
+      version="v4",
+      checkpoint_dir="microsoft/Phi-4-mini-instruct",
       max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/phi/verify_util.py
@@ -0,0 +1,84 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Phi model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.phi import phi2, phi3, phi4
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Instruct: Write an email about the weather Output:"]
+
+_BUILDER = {
+    "v2": phi2.build_model,
+    "v3": phi3.build_model,
+    "v4": phi4.build_model,
+}
+
+
+def verify_phi(
+    version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+    atol: float = 1e-04,
+) -> bool:
+  """Verifies the reauthored Phi model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=atol,
+  )
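The refactor moves checkpoint resolution, model building, tokenizer loading, and the verifier call behind a single entry point. A minimal usage sketch, assuming the behavior shown in the hunk above; the local checkpoint directory is a hypothetical placeholder, not part of the diff:

```python
# Hypothetical usage of the new verify_util entry point.
from ai_edge_torch.generative.examples.phi import verify_util

ok = verify_util.verify_phi(
    version="v4",
    # Assumed directory: a valid HF checkpoint (config + tokenizer) that
    # also contains model.safetensors for the custom-loader path.
    checkpoint_dir="/tmp/phi4_checkpoint",
    initialize_from_local=False,  # read weights via the safetensors loader
    max_new_tokens=30,
)
assert ok, "reauthored model diverged from the original"
```

Note that `checkpoint_dir` is also passed to `transformers.AutoModelForCausalLM.from_pretrained`, so it still has to be resolvable as a Hugging Face model id or local model directory.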
--- a/ai_edge_torch/generative/examples/qwen/qwen.py
+++ b/ai_edge_torch/generative/examples/qwen/qwen.py
@@ -53,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -71,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
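The same pattern repeats across the model files below: `enable_hlfb=True` disappears from `NormalizationConfig` and `ModelConfig` call sites, while `ai_edge_torch/generative/layers/model_config.py` changes by `+2 -2`. A plausible reading is that the defaults flipped to `True` in this release, making the explicit flag redundant; this is an inference from the call-site changes, not confirmed library code. A sketch of that assumption with a simplified stand-in class:

```python
# Simplified stand-in for cfg.NormalizationConfig; the enable_hlfb=True
# default is an assumption, not the actual model_config.py source.
import dataclasses

@dataclasses.dataclass
class NormalizationConfig:
  type: str = "RMS_NORM"
  enable_hlfb: bool = True  # presumed new default in this release
  epsilon: float = 1e-5

before = NormalizationConfig(type="RMS_NORM", epsilon=1e-06, enable_hlfb=True)
after = NormalizationConfig(type="RMS_NORM", epsilon=1e-06)
assert before == after  # dropping the flag preserves the config
```

The few files that instead gain an explicit `enable_hlfb=False` (the CLIP, T5, and toy models below) are consistent with this reading: they previously relied on the old `False` default and now have to opt out.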
--- a/ai_edge_torch/generative/examples/qwen/verify.py
+++ b/ai_edge_torch/generative/examples/qwen/verify.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.qwen import qwen
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.qwen import verify_util
 
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -49,38 +44,13 @@ _CHECKPOINT = {
     "3b": "Qwen/Qwen2.5-3B-Instruct",
 }
 
-_BUILDER = {
-    "0.5b": qwen.build_0_5b_model,
-    "1.5b": qwen.build_1_5b_model,
-    "3b": qwen.build_3b_model,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_qwen(
+      model_size=_MODEL_SIZE.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/qwen/verify_util.py
@@ -0,0 +1,83 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the Qwen model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_BUILDER = {
+    "0.5b": qwen.build_0_5b_model,
+    "1.5b": qwen.build_1_5b_model,
+    "3b": qwen.build_3b_model,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_qwen(
+    model_size: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored Qwen model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_size](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
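The `loader.get_custom_loader("", "safetensors")` helper itself is not shown in this diff (see the `utilities/loader.py +11 -1` entry above). As a rough mental model only, a safetensors state-dict loader reduces to something like the following stand-in, which is an assumption and not the library's code:

```python
# Stand-in sketch, assuming the real loader returns a callable of a
# similar shape: path in, dict of tensors out.
from safetensors import safe_open

def load_safetensors(path: str) -> dict:
  """Reads every tensor in a .safetensors file into a dict."""
  tensors = {}
  with safe_open(path, framework="pt") as f:  # "pt" -> torch tensors
    for name in f.keys():
      tensors[name] = f.get_tensor(name)
  return tensors
```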
--- a/ai_edge_torch/generative/examples/qwen_vl/decoder.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/decoder.py
@@ -97,7 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -113,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/smollm/smollm.py
+++ b/ai_edge_torch/generative/examples/smollm/smollm.py
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -68,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/smollm/verify.py
+++ b/ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.smollm import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -47,38 +41,13 @@ _CHECKPOINT = {
     "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
 }
 
-_BUILDER = {
-    "v1": smollm.build_model,
-    "v2": smollm.build_model_v2,
-}
-
 
 def main(_):
-  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
-  builder = _BUILDER[_MODEL_VERSION.value]
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = builder(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_smollm_135m(
+      model_version=_MODEL_VERSION.value,
+      checkpoint_dir=_CHECKPOINT[_MODEL_VERSION.value],
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
--- /dev/null
+++ b/ai_edge_torch/generative/examples/smollm/verify_util.py
@@ -0,0 +1,81 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the SmolLM model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_smollm_135m(
+    model_version: str,
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored SmolLM model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[model_version](
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
--- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -113,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -129,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
@@ -164,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -180,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
--- a/ai_edge_torch/generative/examples/t5/t5.py
+++ b/ai_edge_torch/generative/examples/t5/t5.py
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
   )
   # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
+      enable_hlfb=False,
   )
   return config
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/tiny_llama/verify.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,15 +15,10 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
-import logging
-import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.tiny_llama import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +34,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = tiny_llama.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_tiny_llama(
+      checkpoint_dir="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 