ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/_convert/conversion.py +24 -0
  3. ai_edge_torch/_convert/converter.py +57 -3
  4. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  5. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  6. ai_edge_torch/_convert/test/test_convert.py +25 -0
  7. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
  8. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
  9. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
  10. ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
  11. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
  12. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
  13. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
  14. ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
  15. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
  16. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  17. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  18. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  19. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
  20. ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
  21. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
  22. ai_edge_torch/generative/examples/llama/llama.py +26 -10
  23. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
  24. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
  25. ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
  26. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
  27. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
  28. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
  29. ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
  30. ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
  31. ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
  32. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
  33. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
  34. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
  35. ai_edge_torch/generative/examples/phi/phi2.py +9 -5
  36. ai_edge_torch/generative/examples/phi/phi3.py +8 -6
  37. ai_edge_torch/generative/examples/phi/phi4.py +8 -6
  38. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
  39. ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
  40. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
  41. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
  42. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
  43. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
  44. ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
  45. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
  46. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
  47. ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
  48. ai_edge_torch/generative/examples/smollm/verify.py +2 -2
  49. ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
  50. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  51. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
  55. ai_edge_torch/generative/layers/model_config.py +2 -2
  56. ai_edge_torch/generative/utilities/converter.py +18 -5
  57. ai_edge_torch/generative/utilities/loader.py +19 -0
  58. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  59. ai_edge_torch/version.py +1 -1
  60. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
  61. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
  62. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
  63. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
  64. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
+from ai_edge_torch._convert.converter import experimental_add_compilation_backend
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
ai_edge_torch/_convert/conversion.py CHANGED
@@ -26,6 +26,9 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot import aot_compile as aot_compile_lib
+from ai_edge_litert.aot.core import types as litert_types
+
 
 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -35,6 +38,7 @@ def _run_convert_passes(
   )
 
   passes = [
+      fx_passes.EliminateDeadCodePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
@@ -153,3 +157,23 @@ def convert_signatures(
   )
 
   return model.TfLiteModel(tflite_model)
+
+
+def aot_compile(
+    compilation_configs: list[litert_types.CompilationConfig],
+    cpu_model: model.TfLiteModel,
+) -> litert_types.CompilationResult:
+  """Compiles the given CPU model.
+
+  Args:
+    compilation_configs: The list of compilation configs to use.
+    cpu_model: The CPU model to compile.
+
+  Returns:
+    The compilation result.
+  """
+  litert_model = litert_types.Model.create_from_bytes(cpu_model.tflite_model())
+  return aot_compile_lib.aot_compile(
+      litert_model,
+      config=compilation_configs,
+  )
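Note: the new conversion.aot_compile helper simply bridges to the ai_edge_litert AOT compiler. A rough sketch of calling it directly, assuming a converted model in hand (FallbackTarget is the vendor-neutral target exercised by the new test further below; torch_module and sample_args are placeholders, not part of this diff):

    import ai_edge_torch
    from ai_edge_torch._convert import conversion
    from ai_edge_litert.aot.core import types as litert_types
    from ai_edge_litert.aot.vendors import fallback_backend

    # Convert to a plain CPU TFLite model first, then compile it ahead of time.
    cpu_model = ai_edge_torch.convert(torch_module, sample_args)
    result = conversion.aot_compile(
        [litert_types.CompilationConfig(target=fallback_backend.FallbackTarget())],
        cpu_model,
    )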
ai_edge_torch/_convert/converter.py CHANGED
@@ -23,6 +23,9 @@ from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import import_vendor as vendor_lib
+
 
 class Converter:
   """A converter for converting PyTorch models to edge models.
@@ -32,6 +35,7 @@ class Converter:
 
   def __init__(self):
     self._signatures: list[signature_module.Signature] = []
+    self._compilation_configs: list[litert_types.CompilationConfig] = []
 
   def signature(
       self,
@@ -96,6 +100,31 @@ class Converter:
     )
     return self
 
+  def experimental_add_compilation_backend(
+      self,
+      target: litert_types.Target | None = None,
+      **kwargs: litert_types.Config,
+  ) -> Converter:
+    """Adds an AOT compilation target to the converter.
+
+    NOTE: This API is experimental and subject to change.
+
+    Args:
+      target: The target to compile for. If not specified, will compile to all
+        registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+        for more details. Adding a same target multiple times will be a no-op.
+      **kwargs: Additional arguments to pass to the backend compiler.
+
+    Returns:
+      The converter object itself.
+    """
+    if target is None:
+      target = vendor_lib.AllRegisteredTarget()
+    if isinstance(target, litert_types.Target):
+      target = litert_types.CompilationConfig(target=target, **kwargs)
+    self._compilation_configs.append(target)
+    return self
+
   def convert(
       self,
       module: torch.nn.Module = None,
@@ -107,7 +136,7 @@ class Converter:
       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
       _saved_model_dir: Optional[str] = None,
-  ) -> model.TfLiteModel:
+  ) -> model.TfLiteModel | litert_types.CompilationResult:
     """Finalizes the conversion and produces an edge model.
 
     This could be called with no arguments as follows:
@@ -144,7 +173,9 @@ class Converter:
         specified, a random temporary directory would be used.
 
     Returns:
-      The converted edge model.
+      The converted edge model. If compilation configs are provided, returns the
+      compilation result that contains the compiled edge models for different
+      targets.
 
     Raises:
      ValueError: If the arguments are not provided as expected. See the example
@@ -169,13 +200,16 @@ class Converter:
          "sample_args or sample_kwargs must be provided if a module is"
          " specified."
      )
-    return conversion.convert_signatures(
+    converted_model = conversion.convert_signatures(
        self._signatures,
        strict_export=strict_export,
        quant_config=quant_config,
        _tfl_converter_flags=_ai_edge_converter_flags,
        _saved_model_dir=_saved_model_dir,
    )
+    if self._compilation_configs:
+      return conversion.aot_compile(self._compilation_configs, converted_model)
+    return converted_model
 
 
 def signature(
@@ -211,6 +245,26 @@ def signature(
  )
 
 
+def experimental_add_compilation_backend(
+    target: litert_types.Target | None = None,
+    **kwargs: litert_types.Config,
+) -> Converter:
+  """Adds an AOT compilation target to the converter.
+
+  NOTE: This API is experimental and subject to change.
+
+  Args:
+    target: The target to compile for. If not specified, will compile to all
+      registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+      for more details. Adding a same target multiple times will be a no-op.
+    **kwargs: Additional arguments to pass to the backend compiler.
+
+  Returns:
+    The converter object itself.
+  """
+  return Converter().experimental_add_compilation_backend(target, **kwargs)
+
+
 def convert(
    module: torch.nn.Module = None,
    sample_args=None,
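Taken together, the converter.py changes make AOT compilation a chainable step: registering a backend switches convert() from returning a TfLiteModel to returning a CompilationResult. A minimal usage sketch, mirroring the new test in test_convert.py further below (the Add module stands in for any exportable model):

    import ai_edge_torch
    import torch
    from torch import nn
    from ai_edge_litert.aot.vendors import fallback_backend

    class Add(nn.Module):

      def forward(self, a, b):
        return a + b

    # With a compilation backend registered, convert() returns a
    # CompilationResult holding one compiled model per backend.
    result = ai_edge_torch.experimental_add_compilation_backend(
        fallback_backend.FallbackTarget()
    ).convert(Add().eval(), (torch.randn(5, 10), torch.randn(5, 10)))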
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
+from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py ADDED
@@ -0,0 +1,40 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to eliminate dead code for ai-edge-torch conversion."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+class EliminateDeadCodePass(fx_infra.PassBase):
+  """Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
+
+  def call(self, graph_module: torch.fx.GraphModule):
+    def is_impure_node(node: torch.fx.Node):
+      # Starting from torch 2.7.0, random torch ops with
+      # _nondeterministic_seeded set are no longer considered pure. However,
+      # for conversion, unused random ops/tensors should still be removed.
+      if getattr(node.target, "_nondeterministic_seeded", False):
+        return False
+      return node.is_impure()
+
+    try:
+      graph_module.graph.eliminate_dead_code(is_impure_node)
+    except TypeError:
+      # eliminate_dead_code has no is_impure_node input in old torch versions.
+      pass
+
+    return fx_infra.PassResult(graph_module, True)
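The pass is a thin wrapper over torch.fx dead-code elimination with one extra rule for seeded random ops. A standalone sketch of the same behavior outside the fx_infra harness (assumes a torch version whose eliminate_dead_code accepts the callable; older versions take the TypeError branch above):

    import torch

    class WithDeadRandom(torch.nn.Module):

      def forward(self, x):
        _noise = torch.rand(4)  # unused; flagged impure on torch >= 2.7
        return x + 1

    ep = torch.export.export(WithDeadRandom().eval(), (torch.randn(4),))
    gm = ep.graph_module

    def is_impure_node(node: torch.fx.Node) -> bool:
      # Same rule as the pass: treat seeded random ops as pure so that
      # unused random tensors are still removed before conversion.
      if getattr(node.target, "_nondeterministic_seeded", False):
        return False
      return node.is_impure()

    gm.graph.eliminate_dead_code(is_impure_node)
    gm.recompile()
    print(gm.code)  # the unused aten.rand node is gone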
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -29,6 +29,8 @@ from torch.ao.quantization import quantize_pt2e
 import torchvision
 
 from absl.testing import absltest as googletest
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import fallback_backend
 from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
@@ -574,6 +576,29 @@ class TestConvert(googletest.TestCase):
         self.fail(f"Conversion failed with bloat16 inputs: {err}")
       # pylint: enable=broad-except
 
+  def test_compile_model(self):
+    """Tests AOT compilation of a simple Add module."""
+
+    class Add(nn.Module):
+
+      def forward(self, a, b):
+        return a + b
+
+    args = (
+        torch.randn((5, 10)),
+        torch.randn((5, 10)),
+    )
+    torch_module = Add().eval()
+    compilation_result = ai_edge_torch.experimental_add_compilation_backend(
+        fallback_backend.FallbackTarget()
+    ).convert(torch_module, args)
+    assert isinstance(compilation_result, litert_types.CompilationResult)
+    self.assertLen(compilation_result.models_with_backend, 1)
+    self.assertEqual(
+        compilation_result.models_with_backend[0][0].id(),
+        fallback_backend.FallbackBackend.id(),
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building AMD-Llama-135m."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -67,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -80,10 +79,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
-      model_class=AmdLlama
+      model_class=AmdLlama,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("amd-llama-135m")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = amd_llama_135m.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
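The same custom_loader plumbing recurs in every conversion script below: loader.maybe_get_custom_loader returns a loader callable only when --custom_checkpoint_loader is set, and the build functions accept any Callable[[str], Dict[str, torch.Tensor]]. A hypothetical hand-written loader satisfying that contract (safetensors is an illustrative choice, not part of this diff):

    from typing import Dict

    import torch
    from safetensors.torch import load_file

    def my_checkpoint_loader(path: str) -> Dict[str, torch.Tensor]:
      # Contract expected by build_model(custom_loader=...): map a checkpoint
      # path to a full state dict of tensor name -> tensor.
      return load_file(path)

    # pytorch_model = amd_llama_135m.build_model(
    #     checkpoint_path, custom_loader=my_checkpoint_loader)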
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py CHANGED
@@ -17,15 +17,20 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/deepseek/deepseek.py CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building DeepSeek R1 distilled models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -51,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -70,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -84,10 +83,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("gemma-2b")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma1.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
@@ -26,8 +27,13 @@ flags = converter.define_conversion_flags(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma2.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -15,9 +15,12 @@
 
 """Example of building a Gemma1 model."""
 
+from typing import Callable, Dict
+
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -62,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -84,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -99,10 +98,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for a Gemma 2B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
@@ -306,7 +302,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
@@ -314,6 +314,7 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
          config=get_model_config_2b(**kwargs),
          tensor_names=tensor_names,
          model_class=Gemma2,
+          custom_loader=custom_loader,
      )
    except KeyError as _:
      continue
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py CHANGED
@@ -25,13 +25,6 @@ flags = converter.define_conversion_flags(
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
 )
 
-_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
-    'custom_checkpoint_loader',
-    False,
-    'If true, the conversion script will use a custom checkpoint loader which'
-    ' will read a checkpoint from a remote source.',
-)
-
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
@@ -40,16 +33,14 @@ _MODEL_SIZE = flags.DEFINE_string(
 
 
 def main(_):
-  custom_loader = None
-  if flags.FLAGS.custom_checkpoint_loader:
-    # If loading from a remote source, try to get a custom loader first.
-    custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
-
+  checkpoint_path = flags.FLAGS.checkpoint_path
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
-        flags.FLAGS.checkpoint_path,
+        checkpoint_path,
+        custom_loader=loader.maybe_get_custom_loader(
+            checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+        ),
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
-        custom_loader=custom_loader,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
ai_edge_torch/generative/examples/gemma3/decoder.py CHANGED
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
          cache_len=attention_mask.shape[-1],
          sliding_window_size=sliding_window_size,
      )
-      # Combine masks using logical AND (min in this case).
-      combined_mask = torch.min(attention_mask, sliding_mask)
+      # Expand sliding_mask to match attention_mask's dimensions
+      # (e.g., [B, 1, seq_len, cache_len]).
+      # Assuming the head dimension is dim 1 for attention_mask.
+      expanded_sliding_mask = sliding_mask.unsqueeze(1)
+      # Combine masks using logical AND (min ensures -inf propagates).
+      combined_mask = torch.min(attention_mask, expanded_sliding_mask)
      return combined_mask
    return attention_mask
 
@@ -161,9 +165,9 @@ class Decoder(nn.Module):
      sliding_window_size: int,
  ) -> torch.Tensor:
    """Creates mask for sliding window attention (PyTorch)."""
-    cache_positions = torch.tensor(
-        [i for i in range(cache_len)], dtype=torch.int32
-    )
+    # Use torch.arange to create a tensor with a range of integers in a
+    # Dynamo-friendly way.
+    cache_positions = torch.arange(cache_len, dtype=torch.int32)
    cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
    segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
 
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
    The model config for a Gemma 1B model.
  """
  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
-      enable_hlfb=True,
      final_logit_softcap=None,
  )
  return config
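The mask change in decoder.py is a broadcasting fix: attention_mask is 4-D ([B, 1, seq_len, cache_len]) while the sliding-window mask is 3-D ([B, seq_len, cache_len]), so the latter needs a head dimension before the two are combined. For additive masks, torch.min acts as a logical AND because -inf in either input wins. A small shape sketch:

    import torch

    B, seq_len, cache_len = 1, 2, 4
    attention_mask = torch.zeros(B, 1, seq_len, cache_len)  # all positions visible
    neg_inf = float("-inf")
    sliding_mask = torch.tensor([[[0.0, 0.0, neg_inf, neg_inf],
                                  [0.0, 0.0, 0.0, neg_inf]]])  # [B, seq_len, cache_len]

    # unsqueeze(1) inserts the head dim; min keeps -inf wherever either mask blocks.
    combined = torch.min(attention_mask, sliding_mask.unsqueeze(1))
    print(combined.shape)  # torch.Size([1, 1, 2, 4])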
ai_edge_torch/generative/examples/gemma3/gemma3.py CHANGED
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
       mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
+          type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
       ),
       mm_extra_tokens=32,
   )
ai_edge_torch/generative/examples/gemma3/image_encoder.py CHANGED
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       num_mm_tokens_per_image=256,
   )
   return config
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('hammer')
 
@@ -36,8 +37,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,