ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- ai_edge_torch/_convert/conversion.py +1 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
- ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
- ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
- ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
- ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
- ai_edge_torch/generative/examples/hammer/verify.py +5 -35
- ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
- ai_edge_torch/generative/examples/llama/llama.py +1 -4
- ai_edge_torch/generative/examples/llama/verify.py +5 -38
- ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
- ai_edge_torch/generative/examples/openelm/verify.py +4 -31
- ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/phi/phi2.py +1 -4
- ai_edge_torch/generative/examples/phi/phi3.py +1 -4
- ai_edge_torch/generative/examples/phi/phi4.py +1 -4
- ai_edge_torch/generative/examples/phi/verify.py +6 -24
- ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
- ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
- ai_edge_torch/generative/examples/qwen/verify.py +5 -35
- ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
- ai_edge_torch/generative/examples/smollm/verify.py +5 -36
- ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
- ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
- ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +2 -1
- ai_edge_torch/generative/utilities/loader.py +11 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
+from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
@@ -0,0 +1,40 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to eliminate dead code for ai-edge-torch conversion."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+class EliminateDeadCodePass(fx_infra.PassBase):
+  """Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
+
+  def call(self, graph_module: torch.fx.GraphModule):
+    def is_impure_node(node: torch.fx.Node):
+      # Starting from torch 2.7.0, random torch ops with
+      # _nondeterministic_seeded set are no longer considered pure. However,
+      # for conversion, unused random ops/tensors should still be removed.
+      if getattr(node.target, "_nondeterministic_seeded", False):
+        return False
+      return node.is_impure()
+
+    try:
+      graph_module.graph.eliminate_dead_code(is_impure_node)
+    except TypeError:
+      # eliminate_dead_code has no is_impure_node input in old torch versions.
+      pass
+
+    return fx_infra.PassResult(graph_module, True)
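The new pass treats seeded random ops as eliminable even though recent torch releases mark them impure. Below is a minimal standalone sketch (not part of this package) of the same predicate applied to a toy torch.fx graph; the module and variable names are illustrative only.

import torch


class ToyModel(torch.nn.Module):

  def forward(self, x):
    _unused = torch.rand(4)  # random value that never reaches the output
    return x + 1


gm = torch.fx.symbolic_trace(ToyModel())


def is_impure_node(node: torch.fx.Node) -> bool:
  # Mirror the rule above: treat seeded random ops as pure so an unused
  # random tensor can still be dropped during conversion.
  if getattr(node.target, "_nondeterministic_seeded", False):
    return False
  return node.is_impure()


try:
  gm.graph.eliminate_dead_code(is_impure_node)
except TypeError:
  # Older torch releases do not accept the predicate argument.
  gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)  # the unused torch.rand call should no longer appear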
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "amd/AMD-Llama-135m"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_amd_llama_135m(
+      "amd/AMD-Llama-135m",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the AMD-Llama-135M model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Tell me a story?\nOnce upon a time"]
+
+
+def verify_amd_llama_135m(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
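Assuming the refactor is read correctly, the flag-driven verify.py keeps pointing at the hosted amd/AMD-Llama-135m checkpoint, while verify_util can also be aimed at weights on disk via initialize_from_local=False. A usage sketch with a placeholder local directory (the path below is not from this diff):

from ai_edge_torch.generative.examples.amd_llama_135m import verify_util

# Hypothetical local directory containing the Hugging Face config, tokenizer
# files, and model.safetensors; replace with a real path.
passed = verify_util.verify_amd_llama_135m(
    checkpoint_dir="/path/to/AMD-Llama-135m",
    weight_filename="model.safetensors",
    max_new_tokens=30,
    initialize_from_local=False,  # load weights through the safetensors loader
    prompts=["Tell me a story?\nOnce upon a time"],
)
print("verification passed:", passed)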
@@ -53,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -72,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.deepseek import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,30 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_deepseek_r1_distill_1_5b(
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the DeepSeek R1 distilled 1.5B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_deepseek_r1_distill_1_5b(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
@@ -65,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -87,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
      final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for a Gemma 2B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
@@ -17,11 +17,13 @@
 
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -107,6 +109,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     mask_as_input: bool = False,
@@ -125,7 +128,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -144,27 +154,62 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-
+    checkpoint_dir: str,
+    weight_filename: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-
-
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  logging.info("Building the reauthored model from: %s", checkpoint_path)
+  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
 
   return verify_reauthored_gemma_model(
-      checkpoint=
+      checkpoint=checkpoint_dir,
       variant="2b-v2",
       reauthored_model=reauthored_model,
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
       mask_as_input=mask_as_input,
       kv_layout=kv_layout,
       atol=1e-04,
   )
+
+
+def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma1 model with a custom loader."""
+  weight_filename = "gemma-2b-it.ckpt"
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  custom_loader = loader.get_custom_loader(checkpoint_path)
+  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint_dir,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
+      generate_prompts=["What is the meaning of life?"],
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=30,
+  )
+
+
+def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma2 model with a custom loader."""
+  return verify_gemma2(
+      checkpoint_dir=checkpoint_dir,
+      weight_filename="model.ckpt",
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      mask_as_input=True,
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
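The custom_loader hook expects a callable that maps a checkpoint path to a loaded dictionary; the new load_state_dict branch indexes its "model_state_dict" entry, and the same callable is forwarded to the reauthored build_2b_model. The simplest entry points appear to be the two new wrappers; a short usage sketch with placeholder checkpoint directories (the paths are not from this diff):

from ai_edge_torch.generative.examples.gemma import verify_util

# Placeholder directories holding locally downloaded Gemma checkpoints
# (gemma-2b-it.ckpt and model.ckpt respectively, per the defaults above).
verify_util.verify_gemma1_with_custom_loader("/path/to/gemma-2b-it")
verify_util.verify_gemma2_with_custom_loader("/path/to/gemma2-2b-it")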
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
           cache_len=attention_mask.shape[-1],
           sliding_window_size=sliding_window_size,
       )
-      #
-
+      # Expand sliding_mask to match attention_mask's dimensions
+      # (e.g., [B, 1, seq_len, cache_len]).
+      # Assuming the head dimension is dim 1 for attention_mask.
+      expanded_sliding_mask = sliding_mask.unsqueeze(1)
+      # Combine masks using logical AND (min ensures -inf propagates).
+      combined_mask = torch.min(attention_mask, expanded_sliding_mask)
       return combined_mask
     return attention_mask
 
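A standalone numeric sketch of the mask combination added above, with illustrative shapes: the sliding mask is broadcast to the attention mask's [B, 1, seq, cache] layout with unsqueeze(1), and torch.min keeps a position visible only when both masks leave it at 0, since any -inf wins.

import torch

# Illustrative shapes only: batch 1, 2 query positions, cache length 4.
attention_mask = torch.tensor(
    [[[[0.0, 0.0, float("-inf"), float("-inf")],
       [0.0, 0.0, 0.0, float("-inf")]]]]
)  # [1, 1, 2, 4], 0 = visible, -inf = masked
sliding_mask = torch.tensor(
    [[[float("-inf"), 0.0, 0.0, 0.0],
      [float("-inf"), float("-inf"), 0.0, 0.0]]]
)  # [1, 2, 4]

combined = torch.min(attention_mask, sliding_mask.unsqueeze(1))
print(combined)  # visible only where both masks are 0; -inf propagates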
@@ -161,9 +165,9 @@
       sliding_window_size: int,
   ) -> torch.Tensor:
     """Creates mask for sliding window attention (PyTorch)."""
-
-
-    )
+    # Use torch.arange to create a tensor with a range of integers in a
+    # Dynamo-friendly way.
+    cache_positions = torch.arange(cache_len, dtype=torch.int32)
     cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
     segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
 
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
     The model config for a Gemma 1B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=None,
   )
   return config
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
       mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
+          type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
       ),
       mm_extra_tokens=32,
   )
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       num_mm_tokens_per_image=256,
   )
   return config
@@ -22,6 +22,7 @@ from typing import Callable, Dict, List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -260,3 +261,15 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
+
+
+def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
+  """Verifies the reauthored Gemma3 model with a custom loader."""
+  return verify_gemma3(
+      checkpoint=checkpoint,
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      variant="1b",
+      weight_filename="model.ckpt",
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
@@ -45,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -63,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 