ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. ai_edge_torch/_convert/conversion.py +1 -0
  2. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  3. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  4. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
  5. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +4 -32
  6. ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py +76 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
  8. ai_edge_torch/generative/examples/deepseek/verify.py +4 -30
  9. ai_edge_torch/generative/examples/deepseek/verify_util.py +76 -0
  10. ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
  11. ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
  12. ai_edge_torch/generative/examples/gemma/verify_util.py +51 -6
  13. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  14. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  15. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  16. ai_edge_torch/generative/examples/gemma3/verify_util.py +13 -0
  17. ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
  18. ai_edge_torch/generative/examples/hammer/verify.py +5 -35
  19. ai_edge_torch/generative/examples/hammer/verify_util.py +82 -0
  20. ai_edge_torch/generative/examples/llama/llama.py +1 -4
  21. ai_edge_torch/generative/examples/llama/verify.py +5 -38
  22. ai_edge_torch/generative/examples/llama/verify_util.py +81 -0
  23. ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
  24. ai_edge_torch/generative/examples/openelm/verify.py +4 -31
  25. ai_edge_torch/generative/examples/openelm/verify_util.py +76 -0
  26. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
  27. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
  28. ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
  29. ai_edge_torch/generative/examples/phi/phi2.py +1 -4
  30. ai_edge_torch/generative/examples/phi/phi3.py +1 -4
  31. ai_edge_torch/generative/examples/phi/phi4.py +1 -4
  32. ai_edge_torch/generative/examples/phi/verify.py +6 -24
  33. ai_edge_torch/generative/examples/phi/verify_phi3.py +5 -28
  34. ai_edge_torch/generative/examples/phi/verify_phi4.py +5 -28
  35. ai_edge_torch/generative/examples/phi/verify_util.py +84 -0
  36. ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
  37. ai_edge_torch/generative/examples/qwen/verify.py +5 -35
  38. ai_edge_torch/generative/examples/qwen/verify_util.py +83 -0
  39. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  40. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
  41. ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
  42. ai_edge_torch/generative/examples/smollm/verify.py +5 -36
  43. ai_edge_torch/generative/examples/smollm/verify_util.py +81 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
  45. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  46. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  47. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  48. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
  49. ai_edge_torch/generative/examples/tiny_llama/verify.py +4 -31
  50. ai_edge_torch/generative/examples/tiny_llama/verify_util.py +76 -0
  51. ai_edge_torch/generative/layers/model_config.py +2 -2
  52. ai_edge_torch/generative/utilities/converter.py +2 -1
  53. ai_edge_torch/generative/utilities/loader.py +11 -1
  54. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  55. ai_edge_torch/version.py +1 -1
  56. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/METADATA +1 -1
  57. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/RECORD +60 -50
  58. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/LICENSE +0 -0
  59. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/WHEEL +0 -0
  60. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250518.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,7 @@ def _run_convert_passes(
   )
 
   passes = [
+      fx_passes.EliminateDeadCodePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
+from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
@@ -0,0 +1,40 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to eliminate dead code for ai-edge-torch conversion."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+class EliminateDeadCodePass(fx_infra.PassBase):
+  """Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
+
+  def call(self, graph_module: torch.fx.GraphModule):
+    def is_impure_node(node: torch.fx.Node):
+      # Starting from torch 2.7.0, random torch ops with
+      # _nondeterministic_seeded set are no longer considered pure. However,
+      # for conversion, unused random ops/tensors should still be removed.
+      if getattr(node.target, "_nondeterministic_seeded", False):
+        return False
+      return node.is_impure()
+
+    try:
+      graph_module.graph.eliminate_dead_code(is_impure_node)
+    except TypeError:
+      # eliminate_dead_code has no is_impure_node input in old torch versions.
+      pass
+
+    return fx_infra.PassResult(graph_module, True)
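Note: the new pass above is registered first in the conversion pipeline (see the conversion.py hunk earlier). As a minimal sketch of the same predicate pattern applied outside the pipeline, assuming an already-traced torch.fx.GraphModule named gm (the helper name run_dce and the plain-DCE fallback call are illustrative, not part of this release):

import torch

def run_dce(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  """Prunes unused nodes, letting seeded random ops be treated as removable."""

  def is_impure_node(node: torch.fx.Node) -> bool:
    # Random aten ops carry _nondeterministic_seeded; reporting them as pure
    # lets dead-code elimination drop them when their results are unused.
    if getattr(node.target, "_nondeterministic_seeded", False):
      return False
    return node.is_impure()

  try:
    gm.graph.eliminate_dead_code(is_impure_node)
  except TypeError:
    # Older torch versions accept no predicate; plain DCE still runs.
    gm.graph.eliminate_dead_code()
  gm.recompile()
  return gm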
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "amd/AMD-Llama-135m"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_amd_llama_135m(
+      "amd/AMD-Llama-135m",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the AMD-Llama-135M model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Tell me a story?\nOnce upon a time"]
+
+
+def verify_amd_llama_135m(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
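A hedged usage sketch of the new entry point (the local path below is a placeholder; only the checkpoint name comes from verify.py). The default initialize_from_local=True resolves weights through the transformers cache exactly as the old verify.py did, while initialize_from_local=False instead reads weight_filename from checkpoint_dir through loader.get_custom_loader:

from ai_edge_torch.generative.examples.amd_llama_135m import verify_util

# Default flow: weights are located via the Hugging Face cache.
verify_util.verify_amd_llama_135m("amd/AMD-Llama-135m", max_new_tokens=30)

# Hypothetical pre-downloaded directory containing model.safetensors.
verify_util.verify_amd_llama_135m(
    "/tmp/amd-llama-135m",  # placeholder path
    initialize_from_local=False,
    prompts=["Tell me a story?\nOnce upon a time"],
)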
@@ -53,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -72,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.deepseek import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,30 +33,10 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 
 def main(_):
-  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_deepseek_r1_distill_1_5b(
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
 
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the DeepSeek R1 distilled 1.5B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_deepseek_r1_distill_1_5b(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
@@ -65,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -87,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for a Gemma 2B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
@@ -17,11 +17,13 @@
 
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -107,6 +109,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
    mask_as_input: bool = False,
@@ -125,7 +128,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -144,27 +154,62 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    gemma2_model_path: str,
+    checkpoint_dir: str,
+    weight_filename: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-  logging.info("Building the reauthored model from: %s", gemma2_model_path)
-  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  logging.info("Building the reauthored model from: %s", checkpoint_path)
+  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
 
   return verify_reauthored_gemma_model(
-      checkpoint=gemma2_model_path,
+      checkpoint=checkpoint_dir,
       variant="2b-v2",
       reauthored_model=reauthored_model,
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
      mask_as_input=mask_as_input,
       kv_layout=kv_layout,
       atol=1e-04,
   )
+
+
+def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma1 model with a custom loader."""
+  weight_filename = "gemma-2b-it.ckpt"
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  custom_loader = loader.get_custom_loader(checkpoint_path)
+  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint_dir,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
+      generate_prompts=["What is the meaning of life?"],
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=30,
+  )
+
+
+def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma2 model with a custom loader."""
+  return verify_gemma2(
+      checkpoint_dir=checkpoint_dir,
+      weight_filename="model.ckpt",
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      mask_as_input=True,
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
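A short usage sketch for the two new Gemma wrappers (the directory paths are placeholders; the expected filenames match the defaults hard-coded above). Both wrappers route weight loading through loader.get_custom_loader, so the original gemma_model checkpoint and the reauthored model read from the same file:

from ai_edge_torch.generative.examples.gemma import verify_util

verify_util.verify_gemma1_with_custom_loader("/path/to/gemma-2b-it")   # expects gemma-2b-it.ckpt
verify_util.verify_gemma2_with_custom_loader("/path/to/gemma2-2b-it")  # expects model.ckpt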
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
           cache_len=attention_mask.shape[-1],
           sliding_window_size=sliding_window_size,
       )
-      # Combine masks using logical AND (min in this case).
-      combined_mask = torch.min(attention_mask, sliding_mask)
+      # Expand sliding_mask to match attention_mask's dimensions
+      # (e.g., [B, 1, seq_len, cache_len]).
+      # Assuming the head dimension is dim 1 for attention_mask.
+      expanded_sliding_mask = sliding_mask.unsqueeze(1)
+      # Combine masks using logical AND (min ensures -inf propagates).
+      combined_mask = torch.min(attention_mask, expanded_sliding_mask)
       return combined_mask
     return attention_mask
 
@@ -161,9 +165,9 @@ class Decoder(nn.Module):
       sliding_window_size: int,
   ) -> torch.Tensor:
     """Creates mask for sliding window attention (PyTorch)."""
-    cache_positions = torch.tensor(
-        [i for i in range(cache_len)], dtype=torch.int32
-    )
+    # Use torch.arange to create a tensor with a range of integers in a
+    # Dynamo-friendly way.
+    cache_positions = torch.arange(cache_len, dtype=torch.int32)
     cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
     segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
 
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
     The model config for a Gemma 1B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
      lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=None,
   )
   return config
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
       mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
+          type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
       ),
       mm_extra_tokens=32,
   )
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       num_mm_tokens_per_image=256,
   )
   return config
@@ -22,6 +22,7 @@ from typing import Callable, Dict, List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -260,3 +261,15 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
+
+
+def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
+  """Verifies the reauthored Gemma3 model with a custom loader."""
+  return verify_gemma3(
+      checkpoint=checkpoint,
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      variant="1b",
+      weight_filename="model.ckpt",
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
@@ -45,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -63,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 